source: TI12-security/trunk/python/ndg.security.common/ndg/security/common/saml/xml/etree.py @ 5510

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/python/ndg.security.common/ndg/security/common/saml/xml/etree.py@5510
Revision 5510, 11.5 KB checked in by pjkersha, 11 years ago (diff)

Working SAML Assertion code with ESG specific Group/Role? attributes

Line 
1"""Implementation of SAML 2.0 for NDG Security - ElementTree module for
2ElementTree representation of SAML objects
3
4NERC DataGrid Project
5"""
6__author__ = "P J Kershaw"
7__date__ = "23/07/09"
8__copyright__ = "(C) 2009 Science and Technology Facilities Council"
9__contact__ = "Philip.Kershaw@stfc.ac.uk"
10__license__ = "BSD - see LICENSE file in top-level directory"
11__contact__ = "Philip.Kershaw@stfc.ac.uk"
12__revision__ = "$Id$"
13import logging
14log = logging.getLogger(__name__)
15
16try: # python 2.5
17    from xml.etree import cElementTree as ElementTree
18except ImportError:
19    # if you've installed it yourself it comes this way
20    import cElementTree as ElementTree
21   
22from ndg.security.common.utils import canonicalize
23
24# XML signature module based on M2Crypto
25from ndg.security.common.xmlsec.etree import XMLSecDoc, InvalidSignature
26
27from ndg.security.common.saml import SAMLObject, Assertion, Attribute, \
28    AttributeStatement, AttributeValue, XSStringAttributeValue, \
29    XSGroupRoleAttributeValue
30   
31from ndg.security.common.saml.xml import XMLObject, AssertionXMLObject, \
32    XMLConstants
33
34
35class ETreeObject(XMLObject):
36    """Implement methods generic to all ElementTree SAML object representations
37    """
38    def __init__(self):
39        self._elem = None
40       
41    def parse(self, source):
42        """Read in the XML from source
43        @type source: basestring/file
44        @param source: file path to XML file or file object
45        """
46        tree = ElementTree.parse(source)
47        self._elem = tree.getroot()
48       
49        return self._elem
50   
51    def serialize(self):
52        """Serialise element tree into string"""
53        return canonicalize(self._elem)
54   
55    def prettyPrint(self):
56        """Basic pretty printing separating each element on to a new line"""
57        xml = self.serialize()
58        xml = ">\n".join(xml.split(">"))
59        xml = "\n</".join(xml.split("<"))
60        xml = '\n'.join(xml.split('\n\n'))
61        return xml
62   
63                         
64class AssertionETreeObject(ETreeObject, AssertionXMLObject):
65    """ElementTree based XML representation of Assertion class
66    """
67    def __init__(self):
68        ETreeObject.__init__(self)
69        AssertionXMLObject.__init__(self)
70   
71    def create(self, assertion, makeNsDeclaration=True, 
72               **attributeValueETreeObjectFactoryKw):
73        """Make a tree of a XML elements based on the assertion"""
74        if not isinstance(assertion, Assertion):
75            raise TypeError("Expecting %r type got: %r"%(Assertion, assertion))
76       
77        issueInstant = AssertionXMLObject.issueInstantDatetime2Str(
78                                                        assertion.issueInstant)
79        attrib = {
80            Assertion.ID_ATTRIB_NAME: assertion.id,
81            Assertion.ISSUE_INSTANT_ATTRIB_NAME: issueInstant,
82           
83            # Nb. Version is a SAMLVersion instance and requires explicit cast
84            Assertion.VERSION_ATTRIB_NAME: str(assertion.version)
85        }
86        self._elem = ElementTree.Element(str(Assertion.DEFAULT_ELEMENT_NAME), 
87                                         **attrib)
88        if makeNsDeclaration:
89            self._elem.set("xmlns:%s" % Assertion.DEFAULT_ELEMENT_NAME.prefix, 
90                           Assertion.DEFAULT_ELEMENT_NAME.namespaceURI)
91       
92        attributeStatementETreeObject = AttributeStatementETreeObject()
93       
94        for attributeStatement in assertion.attributeStatements:
95            attributeStatementElem = attributeStatementETreeObject.create(
96                                        attributeStatement,
97                                        makeNsDeclaration=False,
98                                        **attributeValueETreeObjectFactoryKw)
99            self._elem.append(attributeStatementElem)
100       
101        log.debug("Created Assertion:\n\n%s" % 
102                  ">\n".join(self.serialize().split(">")))
103        return self._elem
104
105 
106class AttributeStatementETreeObject(ETreeObject):
107    """ElementTree XML representation of AttributeStatement"""
108   
109    def create(self, attributeStatement, makeNsDeclaration=True, 
110               **attributeValueETreeObjectFactoryKw):
111        if not isinstance(attributeStatement, AttributeStatement):
112            raise TypeError("Expecting %r type got: %r" % (AttributeStatement, 
113                                                           attributeStatement))
114           
115        self._elem = ElementTree.Element(
116                                str(AttributeStatement.DEFAULT_ELEMENT_NAME))
117        if makeNsDeclaration:
118            self._elem.set(
119                "xmlns:%s" % AttributeStatement.DEFAULT_ELEMENT_NAME.prefix, 
120                AttributeStatement.DEFAULT_ELEMENT_NAME.namespaceURI)
121
122        attributeETreeObject = AttributeETreeObject()
123
124        for attribute in attributeStatement.attributes:
125            # Factory enables support for multiple attribute types
126            attributeElem = attributeETreeObject.create(attribute,
127                                        makeNsDeclaration=False,
128                                        **attributeValueETreeObjectFactoryKw)
129            self._elem.append(attributeElem)
130       
131        return self._elem
132   
133
134class AttributeETreeObject(ETreeObject):
135    """ElementTree XML representation of SAML Attribute object.  Extend
136    to make Attribute types""" 
137
138    def create(self, attribute, makeNsDeclaration=True,
139               **attributeValueETreeObjectFactoryKw):
140        """Make 'Attribute' element"""
141       
142        if not isinstance(attribute, Attribute):
143            raise TypeError("Expecting %r type got: %r"%(Attribute, attribute))
144           
145        self._elem = ElementTree.Element(str(Attribute.DEFAULT_ELEMENT_NAME))
146        if makeNsDeclaration:
147            self._elem.set("xmlns:%s" % Attribute.DEFAULT_ELEMENT_NAME.prefix, 
148                           Attribute.DEFAULT_ELEMENT_NAME.namespaceURI)
149           
150        if attribute.friendlyName:
151            self._elem.set(Attribute.FRIENDLY_NAME_ATTRIB_NAME,
152                           attribute.friendlyName) 
153        if attribute.name:
154            self._elem.set(Attribute.NAME_ATTRIB_NAME, attribute.name)
155       
156        if attribute.nameFormat:
157            self._elem.set(Attribute.NAME_FORMAT_ATTRIB_NAME,
158                           attribute.nameFormat)
159
160        for attributeValue in attribute.attributeValues:
161            factory = AttributeValueETreeObjectFactory(
162                                        **attributeValueETreeObjectFactoryKw)
163           
164            attributeValueETreeObject = factory(attributeValue)
165           
166            attributeValueElem = attributeValueETreeObject.create(
167                                                    attributeValue, 
168                                                    makeNsDeclaration=False)
169            self._elem.append(attributeValueElem)
170           
171        return self._elem
172 
173   
174class AttributeValueETreeObjectBase(ETreeObject):
175    """Base class ElementTree XML representation of SAML Attribute Value""" 
176   
177    def create(self, attributeValue, makeNsDeclaration=True):
178        """Make 'Attribute' XML element"""
179
180        if not isinstance(attributeValue, AttributeValue):
181            raise TypeError("Expecting %r type got: %r" % (AttributeValue, 
182                                                           attributeValue))
183           
184        self._elem = ElementTree.Element(
185                                    str(AttributeValue.DEFAULT_ELEMENT_NAME))
186        if makeNsDeclaration:
187            self._elem.set(
188                    "xmlns:%s" % AttributeValue.DEFAULT_ELEMENT_NAME.prefix, 
189                    AttributeValue.DEFAULT_ELEMENT_NAME.namespaceURI)
190
191        return self._elem
192
193
194class XSStringAttributeValueETreeObject(AttributeValueETreeObjectBase):
195    """ElementTree XML representation of SAML String type Attribute Value""" 
196   
197    def create(self, attributeValue, makeNsDeclaration=True):
198        """Create an XML representation of the input SAML Attribute Value"""
199        super(XSStringAttributeValueETreeObject, self).create(
200                                        attributeValue, 
201                                        makeNsDeclaration=makeNsDeclaration)
202       
203        if not isinstance(attributeValue, XSStringAttributeValue):
204            raise TypeError("Expecting %r type got: %r" % 
205                            (XSStringAttributeValue, attributeValue)) 
206                   
207        self._elem.set("%s:%s" % (XMLConstants.XMLNS_PREFIX, 
208                                  XMLConstants.XSD_PREFIX), 
209                       XMLConstants.XSD_NS)
210       
211        self._elem.set("%s:%s" % (XMLConstants.XMLNS_PREFIX,
212                                  XMLConstants.XSI_PREFIX), 
213                       XMLConstants.XSI_NS)
214       
215        self._elem.set("%s:%s" % (XMLConstants.XSI_PREFIX, 'type'), 
216                       "%s:%s" % (XMLConstants.XSD_PREFIX, 
217                                  XSStringAttributeValue.TYPE_LOCAL_NAME))
218
219        self._elem.text = attributeValue.value
220
221        return self._elem
222
223
224class XSGroupRoleAttributeValueETreeObject(AttributeValueETreeObjectBase):
225    """ElementTree XML representation of Earth System Grid custom Group/Role
226    Attribute Value""" 
227
228    def create(self, attributeValue, makeNsDeclaration=True):
229        """Create an XML representation of the input SAML Attribute Value"""
230        super(XSGroupRoleAttributeValueETreeObject, self).create(
231                                        attributeValue, 
232                                        makeNsDeclaration=makeNsDeclaration)
233       
234        if not isinstance(attributeValue, XSGroupRoleAttributeValue):
235            raise TypeError("Expecting %r type; got: %r" % 
236                            (XSGroupRoleAttributeValue, attributeValue))
237           
238        self._elem.set("%s:%s" % (XMLConstants.XMLNS_PREFIX, 
239                                  attributeValue.namespacePrefix), 
240                       attributeValue.namespaceURI)
241       
242        self._elem.set(XSGroupRoleAttributeValue.GROUP_ATTRIB_NAME, 
243                       attributeValue.group)
244       
245        self._elem.set(XSGroupRoleAttributeValue.ROLE_ATTRIB_NAME, 
246                       attributeValue.role)
247
248        return self._elem
249
250class AttributeValueETreeObjectFactory(object):
251    """Factory for creating ElementTree representations of SAML Attribute
252    value types
253    """
254    classMap = {
255        XSStringAttributeValue: XSStringAttributeValueETreeObject
256    }
257   
258    def __init__(self, customClassMap={}): 
259        """Set-up a SAML class to ElementTree mapping
260        """
261        self.__classMap = AttributeValueETreeObjectFactory.classMap
262        for samlClass, etreeClass in customClassMap.items(): 
263            if not issubclass(samlClass, AttributeValue):
264                raise TypeError("Input custom class must be derived from %r, "
265                                "got %r instead" % (Attribute, samlClass))
266               
267            self.__classMap[samlClass] = etreeClass
268           
269    def __call__(self, samlObject):
270        """Create an ElementTree object based on the Attribute class type
271        passed in
272        """
273        if not isinstance(samlObject, AttributeValue):
274            raise TypeError("Expecting %r class got %r" % (AttributeValue, 
275                                                           samlObject))
276           
277        xmlObjectClass = self.__classMap.get(samlObject.__class__)
278        if xmlObjectClass is None:
279            raise TypeError("no matching XMLObject class representation for "
280                            "SAML class %r" % samlObject.__class__)
281           
282        return xmlObjectClass()
Note: See TracBrowser for help on using the repository browser.