source: TI12-security/trunk/ndg_security_saml/saml/common/__init__.py @ 6602

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/ndg_security_saml/saml/common/__init__.py@6602
Revision 6602, 7.2 KB checked in by pjkersha, 11 years ago (diff)
  • Enabling pickling by adding getstate and setstate methods
  • fixes for correct inheritance from SAMLObject
Line 
1"""SAML 2.0 common package
2
3Implementation of SAML 2.0 for NDG Security
4
5NERC DataGrid Project
6
7This implementation is adapted from the Java OpenSAML implementation.  The
8copyright and licence information are included here:
9
10Copyright [2005] [University Corporation for Advanced Internet Development, Inc.]
11
12Licensed under the Apache License, Version 2.0 (the "License");
13you may not use this file except in compliance with the License.
14You may obtain a copy of the License at
15
16http://www.apache.org/licenses/LICENSE-2.0
17
18Unless required by applicable law or agreed to in writing, software
19distributed under the License is distributed on an "AS IS" BASIS,
20WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21See the License for the specific language governing permissions and
22limitations under the License.
23"""
24__author__ = "P J Kershaw"
25__date__ = "11/08/09"
26__copyright__ = "(C) 2009 Science and Technology Facilities Council"
27__contact__ = "Philip.Kershaw@stfc.ac.uk"
28__license__ = "BSD - see LICENSE file in top-level directory"
29__contact__ = "Philip.Kershaw@stfc.ac.uk"
30__revision__ = "$Id: $"
31from saml.common.xml import SAMLConstants, QName
32 
33
class SAMLObject(object):
    """Base class for all SAML types

    @cvar DEFAULT_ELEMENT_LOCAL_NAME: local name used for the qualified name
    when the initialiser is not given one explicitly.  It is looked up on the
    instance's own class, so a value set on a derived class is picked up.
    """
    DEFAULT_ELEMENT_LOCAL_NAME = None
    __slots__ = ('__qname',)

    def __init__(self,
                 namespaceURI=SAMLConstants.SAML20_NS,
                 elementLocalName=None,
                 namespacePrefix=SAMLConstants.SAML20_PREFIX):
        '''@param namespaceURI: the namespace the element is in
        @param elementLocalName: the local name of the XML element this Object
        represents.  If None, the class's DEFAULT_ELEMENT_LOCAL_NAME is used
        @param namespacePrefix: the prefix for the given namespace
        '''
        if elementLocalName is None:
            elementLocalName = self.__class__.DEFAULT_ELEMENT_LOCAL_NAME

        self.__qname = QName(namespaceURI,
                             elementLocalName,
                             namespacePrefix)

    @property
    def qname(self):
        "Qualified Name for this type"
        return self.__qname

    @classmethod
    def fromXML(cls, xmlObject):
        '''Parse from an XML representation into a SAML object.  Not
        implemented in this base class.

        @type xmlObject: XML class e.g. ElementTree or 4Suite XML
        @param xmlObject: XML representation of SAML Object
        @rtype: saml.saml2.common.SAMLObject derived type
        @return: SAML object
        @raise NotImplementedError: always, in this base class
        '''
        raise NotImplementedError()

    @classmethod
    def toXML(cls, samlObject):
        '''Convert the input SAML object into an XML representation.  Not
        implemented in this base class.

        @type samlObject: saml.saml2.common.SAMLObject derived type
        @param samlObject: SAML object
        @rtype: XML class e.g. ElementTree or 4Suite XML
        @return: XML representation of SAML Object
        @raise NotImplementedError: always, in this base class
        '''
        raise NotImplementedError()

    def __getstate__(self):
        '''Enable pickling: the class defines __slots__ and so has no
        instance __dict__ for pickle to fall back on.

        @rtype: dict
        @return: slot values keyed by their name-mangled attribute name;
        slots that have not been assigned are left out
        '''
        _dict = {}
        for attrName in SAMLObject.__slots__:
            # Slot names starting with a double underscore are stored under
            # their name-mangled form, so the mangled name is built by hand
            # to make getattr/setattr work from outside the class namespace.
            if attrName.startswith('__'):
                attrName = "_SAMLObject" + attrName

            try:
                _dict[attrName] = getattr(self, attrName)
            except AttributeError:
                # Slot never assigned - omit it rather than fail.  Only
                # AttributeError is expected here; anything else (including
                # KeyboardInterrupt) is no longer silently discarded.
                pass

        return _dict

    def __setstate__(self, attrDict):
        '''Enable pickling: restore the attributes saved by __getstate__

        @type attrDict: dict
        @param attrDict: attribute values keyed by attribute name
        '''
        for attrName, val in attrDict.items():
            setattr(self, attrName, val)
100           
101
class SAMLVersion(object):
    """Version helper class

    Wraps a SAML version held as a tuple of integers, e.g. (2, 0), and
    supports comparison against version strings ("2.0"), tuples, lists and
    other SAMLVersion instances.
    """

    VERSION_10 = (1, 0)
    VERSION_11 = (1, 1)
    VERSION_20 = (2, 0)
    KNOWN_VERSIONS = (VERSION_10, VERSION_11, VERSION_20)

    # String types accepted as version input: basestring where it exists
    # (Python 2), plain str otherwise, so the class works on both.
    try:
        _STRING_TYPES = (basestring,)
    except NameError:
        _STRING_TYPES = (str,)

    __slots__ = ('__version', )

    def __init__(self, version):
        """@type version: string, tuple, list or SAMLVersion
        @param version: SAML version e.g. "2.0" or (2, 0)
        @raise TypeError: version is not one of the accepted types
        """
        self.__version = SAMLVersion._asTuple(version, "initialiser")

    @staticmethod
    def _asTuple(version, usage):
        """Normalise the supported version representations to a tuple

        @type version: string, tuple, list or SAMLVersion
        @param version: version to convert
        @type usage: string
        @param usage: word inserted into the error message to say what the
        value was being used for ("initialiser" or "comparison")
        @rtype: tuple
        @return: version as a tuple
        @raise TypeError: version is not one of the accepted types
        """
        if isinstance(version, SAMLVersion._STRING_TYPES):
            return SAMLVersion.valueOf(version)

        if isinstance(version, (tuple, list)):
            return tuple(version)

        if isinstance(version, SAMLVersion):
            return version.__version

        # The value is wrapped in a one-element tuple so that the % operator
        # always treats it as a single argument
        raise TypeError("Expecting string, tuple or list type for SAML "
                        "version %s; got %r" % (usage, version))

    def __getstate__(self):
        '''Enable pickling: the class defines __slots__ and so has no
        instance __dict__ for pickle to fall back on.

        @rtype: dict
        @return: slot values keyed by their name-mangled attribute name
        '''
        _dict = {}
        for attrName in SAMLVersion.__slots__:
            # Slot names starting with a double underscore are stored under
            # their name-mangled form, so the mangled name is built by hand
            if attrName.startswith('__'):
                attrName = "_SAMLVersion" + attrName

            _dict[attrName] = getattr(self, attrName)

        return _dict

    def __setstate__(self, attrDict):
        '''Enable pickling: restore the attributes saved by __getstate__

        @type attrDict: dict
        @param attrDict: attribute values keyed by attribute name
        '''
        for attrName, val in attrDict.items():
            setattr(self, attrName, val)

    def __str__(self):
        """@rtype: string
        @return: dotted version string e.g. "2.0"
        """
        return ".".join(str(i) for i in self.__version)

    def __hash__(self):
        """Hash on the version tuple so that the hash agrees with __eq__
        for equal SAMLVersion instances and equal tuples"""
        return hash(self.__version)

    def __eq__(self, version):
        """Test for equality against an input version string, tuple, list or
        SAMLVersion

        @raise TypeError: version is any other type.  This is kept for
        backwards compatibility rather than returning NotImplemented
        """
        return self.__version == SAMLVersion._asTuple(version, "comparison")

    def __ne__(self, version):
        return not self.__eq__(version)

    def __gt__(self, version):
        return self.__version > SAMLVersion._asTuple(version, "comparison")

    def __lt__(self, version):
        return self.__version < SAMLVersion._asTuple(version, "comparison")

    def __ge__(self, version):
        return self.__version >= SAMLVersion._asTuple(version, "comparison")

    def __le__(self, version):
        return self.__version <= SAMLVersion._asTuple(version, "comparison")

    @staticmethod
    def valueOf(version):
        """Parse input string into version tuple
        @type version: string
        @param version: SAML version as dot separated integers e.g. "2.0"
        @rtype: tuple
        @return: SAML version tuple
        @raise ValueError: a component of the string is not an integer"""
        return tuple(int(i) for i in version.split("."))
Note: See TracBrowser for help on using the repository browser.