source: TI12-security/trunk/python/ndg.security.common/ndg/security/common/saml/xml/etree.py @ 5513

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/python/ndg.security.common/ndg/security/common/saml/xml/etree.py@5513
Revision 5513, 11.5 KB checked in by pjkersha, 11 years ago (diff)

Split SAML unit tests into separate ones for SAML and SAML Attribute Authority interface.

Line 
1"""Implementation of SAML 2.0 for NDG Security - ElementTree module for
2ElementTree representation of SAML objects
3
4NERC DataGrid Project
5"""
6__author__ = "P J Kershaw"
7__date__ = "23/07/09"
8__copyright__ = "(C) 2009 Science and Technology Facilities Council"
9__contact__ = "Philip.Kershaw@stfc.ac.uk"
10__license__ = "BSD - see LICENSE file in top-level directory"
11__contact__ = "Philip.Kershaw@stfc.ac.uk"
12__revision__ = "$Id$"
13import logging
14log = logging.getLogger(__name__)
15
16try: # python 2.5
17    from xml.etree import cElementTree as ElementTree
18except ImportError:
19    # if you've installed it yourself it comes this way
20    import cElementTree as ElementTree
21   
22from ndg.security.common.utils import canonicalize
23
24# XML signature module based on M2Crypto
25from ndg.security.common.xmlsec.etree import XMLSecDoc, InvalidSignature
26
27from ndg.security.common.saml import SAMLObject, Assertion, Attribute, \
28    AttributeStatement, AttributeValue, XSStringAttributeValue, \
29    XSGroupRoleAttributeValue
30   
31from ndg.security.common.saml.xml import XMLObject, AssertionXMLObject, \
32    XMLConstants
33
34
35class ETreeObject(XMLObject):
36    """Implement methods generic to all ElementTree SAML object representations
37    """
38    def __init__(self):
39        self._elem = None
40       
41    def parse(self, source):
42        """Read in the XML from source
43        @type source: basestring/file
44        @param source: file path to XML file or file object
45        """
46        tree = ElementTree.parse(source)
47        self._elem = tree.getroot()
48       
49        return self._elem
50   
51    def serialize(self):
52        """Serialise element tree into string"""
53        return canonicalize(self._elem)
54   
55    def prettyPrint(self):
56        """Basic pretty printing separating each element on to a new line"""
57        xml = self.serialize()
58        xml = ">\n".join(xml.split(">"))
59        xml = "\n<".join(xml.split("<"))
60        xml = '\n'.join(xml.split('\n\n'))
61        return xml
62   
63                         
64class AssertionETreeObject(ETreeObject, AssertionXMLObject):
65    """ElementTree based XML representation of Assertion class
66    """
67    def __init__(self):
68        ETreeObject.__init__(self)
69        AssertionXMLObject.__init__(self)
70   
71    def create(self, assertion, makeNsDeclaration=True, 
72               **attributeValueETreeObjectFactoryKw):
73        """Make a tree of a XML elements based on the assertion"""
74        if not isinstance(assertion, Assertion):
75            raise TypeError("Expecting %r type got: %r"%(Assertion, assertion))
76       
77        issueInstant = AssertionXMLObject.issueInstantDatetime2Str(
78                                                        assertion.issueInstant)
79        attrib = {
80            Assertion.ID_ATTRIB_NAME: assertion.id,
81            Assertion.ISSUE_INSTANT_ATTRIB_NAME: issueInstant,
82           
83            # Nb. Version is a SAMLVersion instance and requires explicit cast
84            Assertion.VERSION_ATTRIB_NAME: str(assertion.version)
85        }
86        self._elem = ElementTree.Element(str(Assertion.DEFAULT_ELEMENT_NAME), 
87                                         **attrib)
88        if makeNsDeclaration:
89            self._elem.set("xmlns:%s" % Assertion.DEFAULT_ELEMENT_NAME.prefix, 
90                           Assertion.DEFAULT_ELEMENT_NAME.namespaceURI)
91       
92        attributeStatementETreeObject = AttributeStatementETreeObject()
93       
94        for attributeStatement in assertion.attributeStatements:
95            attributeStatementElem = attributeStatementETreeObject.create(
96                                        attributeStatement,
97                                        makeNsDeclaration=False,
98                                        **attributeValueETreeObjectFactoryKw)
99            self._elem.append(attributeStatementElem)
100       
101        log.debug("Created Assertion:\n\n%s" % 
102                  ">\n".join(self.serialize().split(">")))
103        return self._elem
104
105 
106class AttributeStatementETreeObject(ETreeObject):
107    """ElementTree XML representation of AttributeStatement"""
108   
109    def create(self, attributeStatement, makeNsDeclaration=True, 
110               **attributeValueETreeObjectFactoryKw):
111        if not isinstance(attributeStatement, AttributeStatement):
112            raise TypeError("Expecting %r type got: %r" % (AttributeStatement, 
113                                                           attributeStatement))
114           
115        self._elem = ElementTree.Element(
116                                str(AttributeStatement.DEFAULT_ELEMENT_NAME))
117        if makeNsDeclaration:
118            self._elem.set(
119                "xmlns:%s" % AttributeStatement.DEFAULT_ELEMENT_NAME.prefix, 
120                AttributeStatement.DEFAULT_ELEMENT_NAME.namespaceURI)
121
122        attributeETreeObject = AttributeETreeObject()
123
124        for attribute in attributeStatement.attributes:
125            # Factory enables support for multiple attribute types
126            attributeElem = attributeETreeObject.create(attribute,
127                                        makeNsDeclaration=False,
128                                        **attributeValueETreeObjectFactoryKw)
129            self._elem.append(attributeElem)
130       
131        return self._elem
132   
133
134class AttributeETreeObject(ETreeObject):
135    """ElementTree XML representation of SAML Attribute object.  Extend
136    to make Attribute types""" 
137
138    def create(self, attribute, makeNsDeclaration=True,
139               **attributeValueETreeObjectFactoryKw):
140        """Make 'Attribute' element"""
141       
142        if not isinstance(attribute, Attribute):
143            raise TypeError("Expecting %r type got: %r"%(Attribute, attribute))
144           
145        self._elem = ElementTree.Element(str(Attribute.DEFAULT_ELEMENT_NAME))
146        if makeNsDeclaration:
147            self._elem.set("xmlns:%s" % Attribute.DEFAULT_ELEMENT_NAME.prefix, 
148                           Attribute.DEFAULT_ELEMENT_NAME.namespaceURI)
149           
150        if attribute.friendlyName:
151            self._elem.set(Attribute.FRIENDLY_NAME_ATTRIB_NAME,
152                           attribute.friendlyName) 
153        if attribute.name:
154            self._elem.set(Attribute.NAME_ATTRIB_NAME, attribute.name)
155       
156        if attribute.nameFormat:
157            self._elem.set(Attribute.NAME_FORMAT_ATTRIB_NAME,
158                           attribute.nameFormat)
159
160        for attributeValue in attribute.attributeValues:
161            factory = AttributeValueETreeObjectFactory(
162                                        **attributeValueETreeObjectFactoryKw)
163           
164            attributeValueETreeObject = factory(attributeValue)
165           
166            attributeValueElem = attributeValueETreeObject.create(
167                                                    attributeValue, 
168                                                    makeNsDeclaration=False)
169            self._elem.append(attributeValueElem)
170           
171        return self._elem
172 
173   
174class AttributeValueETreeObjectBase(ETreeObject):
175    """Base class ElementTree XML representation of SAML Attribute Value""" 
176   
177    def create(self, attributeValue, makeNsDeclaration=True):
178        """Make 'Attribute' XML element"""
179
180        if not isinstance(attributeValue, AttributeValue):
181            raise TypeError("Expecting %r type got: %r" % (AttributeValue, 
182                                                           attributeValue))
183           
184        self._elem = ElementTree.Element(
185                                    str(AttributeValue.DEFAULT_ELEMENT_NAME))
186        if makeNsDeclaration:
187            self._elem.set(
188                    "xmlns:%s" % AttributeValue.DEFAULT_ELEMENT_NAME.prefix, 
189                    AttributeValue.DEFAULT_ELEMENT_NAME.namespaceURI)
190
191        return self._elem
192
193
194class XSStringAttributeValueETreeObject(AttributeValueETreeObjectBase):
195    """ElementTree XML representation of SAML String type Attribute Value""" 
196   
197    def create(self, attributeValue, makeNsDeclaration=True):
198        """Create an XML representation of the input SAML Attribute Value"""
199        super(XSStringAttributeValueETreeObject, self).create(
200                                        attributeValue, 
201                                        makeNsDeclaration=makeNsDeclaration)
202       
203        if not isinstance(attributeValue, XSStringAttributeValue):
204            raise TypeError("Expecting %r type got: %r" % 
205                            (XSStringAttributeValue, attributeValue)) 
206                   
207        self._elem.set("%s:%s" % (XMLConstants.XMLNS_PREFIX, 
208                                  XMLConstants.XSD_PREFIX), 
209                       XMLConstants.XSD_NS)
210       
211        self._elem.set("%s:%s" % (XMLConstants.XMLNS_PREFIX,
212                                  XMLConstants.XSI_PREFIX), 
213                       XMLConstants.XSI_NS)
214       
215        self._elem.set("%s:%s" % (XMLConstants.XSI_PREFIX, 'type'), 
216                       "%s:%s" % (XMLConstants.XSD_PREFIX, 
217                                  XSStringAttributeValue.TYPE_LOCAL_NAME))
218
219        self._elem.text = attributeValue.value
220
221        return self._elem
222
223
224class XSGroupRoleAttributeValueETreeObject(AttributeValueETreeObjectBase):
225    """ElementTree XML representation of Earth System Grid custom Group/Role
226    Attribute Value""" 
227
228    def create(self, attributeValue, makeNsDeclaration=True):
229        """Create an XML representation of the input SAML Attribute Value"""
230        super(XSGroupRoleAttributeValueETreeObject, self).create(
231                                        attributeValue, 
232                                        makeNsDeclaration=makeNsDeclaration)
233       
234        if not isinstance(attributeValue, XSGroupRoleAttributeValue):
235            raise TypeError("Expecting %r type; got: %r" % 
236                            (XSGroupRoleAttributeValue, attributeValue))
237           
238        self._elem.set("%s:%s" % (XMLConstants.XMLNS_PREFIX, 
239                                  attributeValue.namespacePrefix), 
240                       attributeValue.namespaceURI)
241       
242        self._elem.set(XSGroupRoleAttributeValue.GROUP_ATTRIB_NAME, 
243                       attributeValue.group)
244       
245        self._elem.set(XSGroupRoleAttributeValue.ROLE_ATTRIB_NAME, 
246                       attributeValue.role)
247
248        return self._elem
249
250class AttributeValueETreeObjectFactory(object):
251    """Factory for creating ElementTree representations of SAML Attribute
252    value types
253    """
254    classMap = {
255        XSStringAttributeValue: XSStringAttributeValueETreeObject
256    }
257   
258    def __init__(self, customClassMap={}): 
259        """Set-up a SAML class to ElementTree mapping
260        """
261        self.__classMap = AttributeValueETreeObjectFactory.classMap
262        for samlClass, etreeClass in customClassMap.items(): 
263            if not issubclass(samlClass, AttributeValue):
264                raise TypeError("Input custom class must be derived from %r, "
265                                "got %r instead" % (Attribute, samlClass))
266               
267            self.__classMap[samlClass] = etreeClass
268           
269    def __call__(self, samlObject):
270        """Create an ElementTree object based on the Attribute class type
271        passed in
272        """
273        if not isinstance(samlObject, AttributeValue):
274            raise TypeError("Expecting %r class got %r" % (AttributeValue, 
275                                                           samlObject))
276           
277        xmlObjectClass = self.__classMap.get(samlObject.__class__)
278        if xmlObjectClass is None:
279            raise TypeError("no matching XMLObject class representation for "
280                            "SAML class %r" % samlObject.__class__)
281           
282        return xmlObjectClass()
Note: See TracBrowser for help on using the repository browser.