Changeset 5507


Ignore:
Timestamp:
23/07/09 16:47:09 (10 years ago)
Author:
pjkersha
Message:

Working SAML assertion code with string attributes

Location:
TI12-security/trunk/python
Files:
3 added
2 edited

Legend:

Unmodified
Added
Removed
  • TI12-security/trunk/python/ndg.security.common/ndg/security/common/saml/__init__.py

    r5504 r5507  
    1 """Implementation of SAML 1.1 Attribute Assertion for NDG Security 
    2  
    3 NERC Data Grid Project 
     1"""Implementation of SAML 2.0 for NDG Security 
     2 
     3NERC DataGrid Project 
    44""" 
    55__author__ = "P J Kershaw" 
     
    1414log = logging.getLogger(__name__) 
    1515 
    16 try: # python 2.5 
    17     from xml.etree import cElementTree as ElementTree 
    18 except ImportError: 
    19     # if you've installed it yourself it comes this way 
    20     import cElementTree as ElementTree 
    21  
    22 from elementtree import ElementC14N 
    23  
    2416# Time module for use with validity times 
    2517from time import strftime, strptime 
    2618from datetime import datetime, timedelta 
    2719    
    28 # XML signature module based on M2Crypto, ZSI Canonicalization and DOM 
    29 from ndg.security.common.xmlsec.etree import XMLSecDoc, InvalidSignature 
    30 from ndg.security.common.utils import QName, canonicalize 
    31  
     20from ndg.security.common.utils import QName 
     21 
     22class XMLConstants(object): 
     23    '''XML related constants.''' 
     24 
     25    # XML Tooling 
     26 
     27    # Configuration namespace 
     28    XMLTOOLING_CONFIG_NS = "http:#www.opensaml.org/xmltooling-config" 
     29 
     30    # Configuration namespace prefix 
     31    XMLTOOLING_CONFIG_PREFIX = "xt" 
     32     
     33    # Name of the object provider used for objects that don't have a registered 
     34    # object provider 
     35    XMLTOOLING_DEFAULT_OBJECT_PROVIDER = "DEFAULT" 
     36 
     37    # Core XML 
     38 
     39    # XML core namespace 
     40    XML_NS = "http:#www.w3.org/XML/1998/namespace" 
     41     
     42    # XML core prefix for xml attributes 
     43    XML_PREFIX = "xml" 
     44 
     45    # XML namespace for xmlns attributes 
     46    XMLNS_NS = "http://www.w3.org/2000/xmlns/" 
     47 
     48    # XML namespace prefix for xmlns attributes 
     49    XMLNS_PREFIX = "xmlns" 
     50 
     51    # XML Schema namespace 
     52    XSD_NS = "http://www.w3.org/2001/XMLSchema" 
     53 
     54    # XML Schema QName prefix 
     55    XSD_PREFIX = "xs" 
     56 
     57    # XML Schema Instance namespace 
     58    XSI_NS = "http://www.w3.org/2001/XMLSchema-instance" 
     59 
     60    # XML Schema Instance QName prefix 
     61    XSI_PREFIX = "xsi" 
     62 
     63    # XML XMLSecSignatureImpl namespace 
     64    XMLSIG_NS = "http://www.w3.org/2000/09/xmldsig#" 
     65 
     66    # XML XMLSecSignatureImpl QName prefix 
     67    XMLSIG_PREFIX = "ds" 
     68 
     69    # XML Encryption namespace 
     70    XMLENC_NS = "http://www.w3.org/2001/04/xmlenc#" 
     71 
     72    # XML Encryption QName prefix 
     73    XMLENC_PREFIX = "xenc" 
     74     
     75    # Local name of EncryptedData element 
     76    XMLENC_ENCDATA_LOCAL_NAME = "EncryptedData" 
     77     
     78    # Local name of EncryptedKey element 
     79    XMLENC_ENCKEY_LOCAL_NAME = "EncryptedKey" 
     80     
     81       
    3282class SAMLObject(object): 
    33     pass 
    34  
    35 class XMLConstants(object): 
    3683    pass 
    3784 
     
    258305        return ".".join([str(i) for i in self.__version]) 
    259306     
    260     @classmethod 
    261     def valueOf(cls, version): 
     307    @staticmethod 
     308    def valueOf(version): 
    262309        return tuple(version.split(".")) 
    263310     
    264311 
    265 class Assertion(SAMLObject, XMLSecDoc): 
     312class Assertion(SAMLObject): 
    266313    """SAML 2.0 Attribute Assertion for use with NERC DataGrid     
    267314    """     
     
    303350 
    304351        # Base class initialisation 
    305         SAMLObject.__init__(self) 
    306         XMLSecDoc.__init__(self, **xmlSecDocKw) 
     352        super(Assertion, self).__init__() 
    307353         
    308354        self._version = None 
     
    505551 
    506552 
    507     #_________________________________________________________________________     
    508     def parse(self, xmlTxt, rtnRootElem=False): 
    509         """Parse an Attribute Assertion content contained in string input 
    510  
    511         @param xmlTxt:     Attribute Assertion XML content as string 
    512         @param rtnRootElem: boolean set to True to return the ElementTree 
    513         root element 
    514          
    515         @rtype: ElementTree root element 
    516         @return: root element if rtnRootElem keyword is set to True""" 
    517          
    518         rootElem = ElementTree.XML(xmlTxt) 
    519  
    520         # Call generic ElementTree parser 
    521         self.__parse(rootElem) 
    522  
    523  
    524         # Call base class parser method to initialise DOM objects for 
    525         # signature validation 
    526         try: 
    527             XMLSecDoc.parse(self, xmlTxt) 
    528  
    529         except Exception, e: 
    530             raise AssertionError("Attribute Assertion: %s" % e) 
    531  
    532         if rtnRootElem: 
    533             return rootElem 
    534  
    535     def __parse(self, rootElem): 
    536         """Private XML parsing method accepts a ElementTree.Element type 
    537         as input 
    538  
    539         @param rootElem: root element of doc - ElementTree.Element type 
    540         """ 
    541          
    542         # Extract from acInfo tag 
    543         acInfoElem = rootElem.find("acInfo") 
    544          
    545         if not acInfoElem: 
    546             raise AssertionError("<acInfo> tag not found in \"%s\"" % \ 
    547                                self.filePath) 
    548  
    549  
    550         # Copy all acInfo tags into dictionary 
    551         for elem in acInfoElem: 
    552             if elem.tag not in self.__dat: 
    553                 raise AssertionError('%s: "<%s>" not recognised.' % \ 
    554                                     (self.filePath, elem.tag)) 
    555  
    556             # Make sure not to copy validity and attributes tags - handle  
    557             # these separately below 
    558             if not elem.getchildren(): 
    559                 self.__dat[elem.tag] = elem.text 
    560  
    561         # Convert issuer and holder into X500DN instances 
    562         try: 
    563             self.__issuerDN = X500DN(dn=self.__dat['issuer']) 
    564  
    565         except X500DNError, x500dnErr: 
    566             raise AssertionError("Issuer DN: %s" % x500dnErr) 
    567  
    568  
    569         try: 
    570             self.__holderDN = X500DN(dn=self.__dat['holder']) 
    571  
    572         except X500DNError, x500dnErr: 
    573             raise AssertionError("Holder DN: %s" % x500dnErr) 
    574          
    575                                   
    576         # Extract validity and attributes subsets 
    577         self.__dat['validity']['notBefore'] = \ 
    578                                 rootElem.findtext("acInfo/validity/notBefore") 
    579          
    580         if self.__dat['validity']['notBefore'] is None: 
    581             raise AssertionError('<notBefore> tag not found in "%s"'%\ 
    582                                                                self.filePath) 
    583         elif self.__dat['validity']['notBefore'] == '': 
    584              
    585             # Allow empty string setting but re-initialise corresponding  
    586             # datetime value 
    587             self.__dtNotBefore = None 
    588         else: 
    589             # Update datetime object equivalent 
    590             self.__dtNotBefore = self.timeStr2datetime(\ 
    591                                         self.__dat['validity']['notBefore']) 
    592  
    593          
    594         self.__dat['validity']['notAfter'] = \ 
    595                                 rootElem.findtext("acInfo/validity/notAfter") 
    596          
    597         if self.__dat['validity']['notAfter'] is None: 
    598             raise AssertionError('<notAfter> tag not found in "%s"' % \ 
    599                                self.filePath) 
    600         elif self.__dat['validity']['notBefore'] == '': 
    601              
    602             # Allow empty string setting but re-initialise corresponding  
    603             # datetime value 
    604             self.__dtNotAfter = None 
    605         else: 
    606             # Update datetime object equivalent 
    607             self.__dtNotAfter = self.timeStr2datetime(\ 
    608                                         self.__dat['validity']['notAfter']) 
    609  
    610         # set up role list 
    611         roleElem = acInfoElem.findall("attributes/roleSet/role/name") 
    612         if roleElem is None: 
    613             raise AssertionError("<role> tag not found in \"%s\"" % \ 
    614                                self.filePath) 
    615          
    616         self.__dat['attributes']['roleSet'] = \ 
    617                                 [{'role': {'name': i.text}} for i in roleElem] 
    618                      
    619          
    620         if not self.isValidVersion():            
    621             raise AssertionError('Attribute Assertion version is ' + \ 
    622                                self.__dat['version'] + ' but version ' + \ 
    623                                Assertion.version + ' expected') 
    624  
    625  
    626     #_________________________________________________________________________     
    627     def _createXML(self): 
    628         """Create XML for Attribute Assertion from current data settings and 
    629         return as a string.  The XML created is MINUS the digital signature. 
    630         To obtain the signed version, run the applyEnvelopedSignature method  
    631         (inherited from XMLSecDoc) and pass the attributeAssertion object  
    632         reference into str() 
    633  
    634         @rtype: string 
    635         @return: formatted XML for certificate as a string""" 
    636  
    637         # Nb. 
    638         # * this method is used by Assertion.read() 
    639         # * Signing by Attribute Authority is separate - see  
    640         # Assertion.sign() 
    641          
    642  
    643         # Check for valid provenance 
    644         if not self.isValidProvenance(): 
    645             raise AssertionError("Provenance must be set to \"" + \ 
    646                 "\" or \"".join(Assertion.__validProvenanceSettings)+\ 
    647                 "\"") 
    648  
    649          
    650         self._rootElem = ElementTree.Element("{%s}Assertion" %  
    651                                              Assertion.ns) 
    652         self._rootETree = ElementTree.ElementTree(element=self._rootElem) 
    653          
    654         self._rootElem.set('xmlns:%s' % Assertion.nsPfx,  
    655                            Assertion.ns) 
    656         self._rootElem.set('MajorVersion', '1') 
    657         self._rootElem.set('MinorVersion', '1') 
    658         self._rootElem.set('Issuer', self.issuerName) 
    659      
    660         conditionsElem = ElementTree.SubElement(self._rootElem, 
    661                                                 '{%s}Conditions' % \ 
    662                                                 Assertion.ns)     
    663         conditionsElem.set('NotBefore', 'x') 
    664         conditionsElem.set('NotAfter', 'y') 
    665          
    666         attributeStatementElem = ElementTree.SubElement(self._rootElem, 
    667                                                 '{%s}AttributeStatement' % \ 
    668                                                 Assertion.ns) 
    669          
    670         subjectElem = ElementTree.SubElement(attributeStatementElem, 
    671                                              '{%s}Subject' % \ 
    672                                              Assertion.ns) 
    673          
    674         attributeElem = ElementTree.SubElement(attributeStatementElem, 
    675                                                '{%s}Attribute' % \ 
    676                                                Assertion.ns) 
    677      
    678         attributeElem.set('AttributeName', 
    679                           "urn:mace:dir:attribute-def:eduPersonAffiliation") 
    680         attributeElem.set('AttributeNamespace', 
    681                           "urn:mace:shibboleth:1.0:attributeNamespace:uri") 
    682          
    683         for val in ['member', 'student']: 
    684            attributeValElem = ElementTree.SubElement(attributeElem, 
    685                                                      '{%s}AttributeValue' % \ 
    686                                                      Assertion.ns) 
    687            attributeValElem.text = val 
    688  
    689  
    690     def applyEnvelopedSignature(self, **xmlSecDocKw): 
    691         '''Override super class version to ensure settings have been parsed  
    692         into a DOM object ready for signature 
    693          
    694         @param **xmlSecDocKw: keywords applying to  
    695         XMLSecDoc.applyEnvelopedSignature() 
    696         '''        
    697         self.parse(self._createXML()) 
    698         super(Assertion, self).applyEnvelopedSignature(**xmlSecDocKw) 
    699  
    700         
    701     #_________________________________________________________________________     
    702     def setValidityTime(self, 
    703                         dtNotBefore=None,  
    704                         dtNotAfter=None,  
    705                         lifetime=None, 
    706                         notBeforeOffset=None): 
    707         """Set the notBefore and notAfter times which determine the window for 
    708         which the Attribute Assertion is valid.  These times are set as 
    709         datetime types and also the correct string format settings are made  
    710         ready for output. 
    711  
    712         Nb. use UTC time.  lifetime and notBeforeOffset are in seconds 
    713          
    714         @param dtNotBefore: not before time as datetime type.  If omitted, 
    715         it defaults to the current time 
    716          
    717         @param dtNotAfter: not after time as datetime type.  Defaults to  
    718         self.__dtNotBefore + self.__lifetime.  If dtNotAfter is set it will 
    719         reset self.__lifetime to self.__dtNotAfter - self.dtNotBefore 
    720          
    721         @param lifetime: lifetime for certificate in seconds i.e. not after 
    722         time - not before time.  If dtNotAfter is set then this keyword will 
    723         be ignored. 
    724          
    725         @param notBeforeOffset: skew the not before time by some offset.  This 
    726         is useful in cases where system clocks are not correctly synchronized 
    727         between different hosts.  Set a negative value to skew the offset 
    728         backward in time. 
    729         """ 
    730  
    731         if dtNotBefore is not None: 
    732             if not isinstance(dtNotBefore, datetime): 
    733                 raise AssertionError(\ 
    734                                 "Input not before time must be datetime type") 
    735              
    736             self.__dtNotBefore = dtNotBefore 
    737              
    738         else: 
    739             # Use current UTC +/- offset 
    740             self.__dtNotBefore = datetime.utcnow() 
    741              
    742         if notBeforeOffset is not None: 
    743             self.__dtNotBefore += timedelta(seconds=notBeforeOffset) 
    744              
    745  
    746         if dtNotAfter is not None: 
    747             if not isinstance(dtNotAfter, datetime): 
    748                 raise AssertionError(\ 
    749                                 "Input not after time must be datetime type") 
    750  
    751             # Use input Not After time to calculate a new lifetime setting 
    752             dtDeltaLifeTime = dtNotAfter - self.__dtNotBefore 
    753             if dtDeltaLifeTime < timedelta(0): 
    754                 raise AssertionError("Input Not After time is invalid %s" % \ 
    755                                    str(dtNotAfter)) 
    756  
    757             self.__lifetime = dtDeltaLifeTime.days*86400 + \ 
    758                               dtDeltaLifeTime.seconds 
    759  
    760             self.__dtNotAfter = dtNotAfter 
    761              
    762         else: 
    763             # Check for input certificate life time interval 
    764             if lifetime is not None: 
    765                 self.__lifetime = lifetime 
    766                  
    767             try: 
    768                 # Make a time delta object from the lifetime expressed in 
    769                 # seconds 
    770                 dtDeltaLifeTime = timedelta(seconds=self.__lifetime) 
    771             except Exception, e: 
    772                 raise AssertionError("Invalid Certificate lifetime set %.3f" %\ 
    773                                    self.__lifetime) 
    774              
    775             # Add certificate lifetime to calculate not after time 
    776             self.__dtNotAfter = self.__dtNotBefore + dtDeltaLifeTime 
    777  
    778          
    779         self.__dat['validity']['notBefore'] = \ 
    780                                     self.datetime2timeStr(self.__dtNotBefore) 
    781          
    782         self.__dat['validity']['notAfter'] = \ 
    783                                     self.datetime2timeStr(self.__dtNotAfter) 
    784  
    785  
    786     #_________________________________________________________________________     
    787     def datetime2timeStr(self, dtVal): 
    788         """Convert a datetime object to a notBefore/notAfter time string 
    789          
    790         @param dtVal: input datetime 
    791          
    792         @rtype: string 
    793         @return: datetime converted into correct string format for Assertion""" 
    794  
    795         if not isinstance(dtVal, datetime): 
    796             raise AssertionError("Invalid datetime object for " 
    797                                           "conversion to string") 
    798          
    799         # Convert from 1-12 to 0-11 month format used in XML file 
    800         #lDateTime = list(dtVal.utctimetuple()[0:6]) 
    801  
    802         #lDateTime[1] -= 1 
    803  
    804         # Format as a single string with no commas or brackets 
    805         #return ''.join(re.findall('[0-9 ]', str(lDateTime))) 
    806  
    807         # Use 1-12 format 
    808         # P J Kershaw 09/06/05 
    809         return dtVal.strftime("%Y %m %d %H %M %S") 
    810  
    811  
    812     #_________________________________________________________________________     
    813     def timeStr2datetime(self, sTime): 
    814         """Convert a notBefore/notAfter time string to a datetime object 
    815          
    816         @param sTime: time in string format as used by Assertion 
    817         @rtype: datetime 
    818         @return: datetime type equivalent of string input""" 
    819  
    820         try: 
    821             lTime = strptime(sTime, "%Y %m %d %H %M %S") 
    822             return datetime(*lTime[0:6]) 
    823          
    824         except Exception, e: 
    825             raise AssertionError(\ 
    826                 "Error converting time string into datetime object: %s" % e) 
    827          
    828  
    829     #_________________________________________________________________________     
    830     def isValidTime(self, dtNow=None, raiseExcep=False): 
    831         """Check Attribute Assertion for expiry.  Set raiseExcep to True 
    832         to raise an exception with a message indicating the nature of the  
    833         time error 
    834          
    835         @param dtNow: the time to test against in datetime format.  This time 
    836         must be within the range of the not before and not after times in 
    837         order for the certificate to be valid.  Defaults to the current  
    838         system time 
    839          
    840         @param raiseExcep: boolean set to True to raise an exception if the  
    841         time is invalid.  Defaults to False in which case no exception is 
    842         raised if the time is invalid, instead False is returned 
    843          
    844         @rtype: bool 
    845         @return: boolean True if time is valid, False if invalid.  Also see 
    846         raiseExcep keyword above.""" 
    847  
    848         if not isinstance(self.__dtNotBefore, datetime): 
    849             raise AssertionError("Not Before datetime is not set") 
    850  
    851         if not isinstance(self.__dtNotAfter, datetime): 
    852             raise AssertionError("Not After datetime is not set") 
    853         
    854         if dtNow is None: 
    855             dtNow = datetime.utcnow() 
    856          
    857         # Testing only 
    858         # 
    859         # P J Kershaw 02/03/06 
    860         #notBefore = self.__dtNotBefore 
    861         #notAfter = self.__dtNotAfter 
    862         #print "Valid Time? = %d" % (dtNow > notBefore and dtNow < notAfter) 
    863         if raiseExcep: 
    864             if dtNow < self.__dtNotBefore: 
    865                 raise AssertionError("Current time %s " % \ 
    866                            dtNow.strftime("%d/%m/%Y %H:%M:%S") + \ 
    867                            "is before Attribute Assertion's " + \ 
    868                            "not before time of %s" % \ 
    869                            self.__dtNotBefore.strftime("%d/%m/%Y %H:%M:%S")) 
    870              
    871             if dtNow > self.__dtNotAfter: 
    872                 raise AssertionError("Current time %s " % \ 
    873                            dtNow.strftime("%d/%m/%Y %H:%M:%S") + \ 
    874                            "is after Attribute Assertion's " + \ 
    875                            "expiry time of %s" % \ 
    876                            self.__dtNotBefore.strftime("%d/%m/%Y %H:%M:%S"))                
    877              
    878             return True         
    879         else: 
    880             return dtNow > self.__dtNotBefore and dtNow < self.__dtNotAfter 
    881          
    882          
    883     #_________________________________________________________________________     
    884     def isValidVersion(self): 
    885         """Check Attribute Assertion XML file version 
    886          
    887         @rtype: bool 
    888         @return: boolean True if certificate version matches the expected one, 
    889         False otherwise. 
    890         """ 
    891         return self.__dat['version'] == Assertion.version 
    892  
    893  
    894     #_________________________________________________________________________     
    895     def isValid(self, 
    896                 raiseExcep=False, 
    897                 chkTime=True, 
    898                 chkVersion=True, 
    899                 chkProvenance=True, 
    900                 chkSig=True, 
    901                 **xmlSecDocKw): 
    902         """Check Attribute Assertion is valid: 
    903  
    904         - Time validity is OK 
    905         - XML file version is OK 
    906         - valid provenance setting 
    907         - Signature is valid. 
    908  
    909         @param chkTime: set to True to do time validity check (default is  
    910         True) 
    911  
    912         @param chkVersion: set to True to Attribute Assertion file 
    913         version (default is True) 
    914  
    915         @param chkProvenanceset to True to check provenance value is valid 
    916         (default is True) 
    917  
    918         @param chkSig: set to True to check digital signature - for 
    919         this certFilePathList must contain the root certificate of the X.509  
    920         certificate used to sign the Assertion.  Alternatively, certFilePathList 
    921         can be set via __init__ (default chkSig value is True) 
    922                                  
    923         @param raiseExcep: set to true to raise an exception if invalid  
    924         instead of returning False.  Default is to set this flag to False. 
    925  
    926         @param **xmlSecDocKw: Also accepts keyword arguments corresponding to  
    927         XMLSecDoc.verifyEnvelopedSignature(). 
    928          
    929         @rtype: bool 
    930         @return: boolean True if certificate is valid, False otherwise.  Also 
    931         see explanation for raiseExcep keyword.                          
    932         """ 
    933  
    934         # Carry out checks in turn - Specific exception error messages are 
    935         # raised if flag is set 
    936         if chkTime and not self.isValidTime(raiseExcep=raiseExcep): 
    937             return False 
    938  
    939         if chkVersion and not self.isValidVersion(): 
    940             if raiseExcep: 
    941                 raise AssertionError('Attribute Assertion version is ' + \ 
    942                                    self.__dat['version'] + ' but version ' + \ 
    943                                    Assertion.version + ' expected') 
    944             return False 
    945  
    946         if chkProvenance and not self.isValidProvenance(): 
    947             if raiseExcep: 
    948                 raise AssertionError(\ 
    949                     "Attribute Assertion Provenance must be set to \"" + \ 
    950                     "\" or \"".join(Assertion.__validProvenanceSettings) + "\"") 
    951             return False 
    952  
    953         # Handle exception from XMLSecDocc.isValidSig() regardless of 
    954         # raiseExcep flag setting 
    955         if chkSig: 
    956             try: 
    957                 self.verifyEnvelopedSignature(**xmlSecDocKw) 
    958          
    959             except InvalidSignature, e: 
    960                  if raiseExcep: 
    961                      raise AssertionError(e) 
    962                  else: 
    963                      return False 
    964                  
    965         # All tests passed 
    966         return True 
    967  
    968  
    969 #_____________________________________________________________________________ 
    970 # Alternative Assertion constructors 
    971 def AssertionRead(filePath): 
    972     """Read an attribute assertion from a file""" 
    973      
    974     attributeAssertion = Assertion(filePath=filePath) 
    975     attributeAssertion.read() 
    976      
    977     return attributeAssertion 
    978  
    979 #_________________________________________________________________________     
    980 def AssertionParse(attributeAssertionTxt): 
    981     """Parse an attribute assertion from a string""" 
    982      
    983     attributeAssertion = Assertion() 
    984     attributeAssertion.parse(attributeAssertionTxt) 
    985      
    986     return attributeAssertion 
    987  
    988 def create(): 
    989     ns = "urn:oasis:names:tc:SAML:1.0:assertion" 
    990     nsPfx = "saml" 
    991     issuer = 'http:#badc.nerc.ac.uk' 
    992     attributeName="urn:mace:dir:attribute-def:eduPersonAffiliation" 
    993     attributeNS="urn:mace:shibboleth:1.0:attributeNamespace:uri" 
    994      
    995     assertionElem = ElementTree.Element("{%s}Assertion" % ns) 
    996     assertionElem.set('xmlns:%s' % nsPfx, ns) 
    997     assertionElem.set('MajorVersion', '1') 
    998     assertionElem.set('MinorVersion', '1') 
    999     assertionElem.set('Issuer', issuer) 
    1000  
    1001     conditionsElem = ElementTree.SubElement(assertionElem, 
    1002                                             '{%s}Conditions' % ns)     
    1003     conditionsElem.set('NotBefore', 'x') 
    1004     conditionsElem.set('NotAfter', 'y') 
    1005      
    1006     attributeStatementElem = ElementTree.SubElement(assertionElem, 
    1007                                                     '{%s}AttributeStatement' %\ 
    1008                                                     ns) 
    1009      
    1010     subjectElem = ElementTree.SubElement(attributeStatementElem, 
    1011                                          '{%s}Subject' % ns) 
    1012      
    1013     attributeElem = ElementTree.SubElement(attributeStatementElem, 
    1014                                            '{%s}Attribute' % ns) 
    1015  
    1016     attributeElem.set('AttributeName', 
    1017                       "urn:mace:dir:attribute-def:eduPersonAffiliation") 
    1018     attributeElem.set('AttributeNamespace', 
    1019                       "urn:mace:shibboleth:1.0:attributeNamespace:uri") 
    1020      
    1021     for val in ['member', 'student']: 
    1022        attributeValElem = ElementTree.SubElement(attributeElem, 
    1023                                                  '{%s}AttributeValue' % ns) 
    1024        attributeValElem.text = val 
    1025  
    1026     return tostring(assertionElem) 
    1027  
    1028553class Attribute(SAMLObject): 
    1029554    '''SAML 2.0 Core Attribute.''' 
     
    1046571 
    1047572    # Name of the Name attribute.  
    1048     NAME_ATTTRIB_NAME = "Name" 
     573    NAME_ATTRIB_NAME = "Name" 
    1049574 
    1050575    # Name for the NameFormat attribute.  
     
    1135660                                 SAMLConstants.SAML20_PREFIX) 
    1136661     
    1137 class Statement(object): 
     662     
     663class XSStringAttributeValue(AttributeValue): 
     664 
     665    # Local name of the XSI type 
     666    TYPE_LOCAL_NAME = "string" 
     667         
     668    # QName of the XSI type 
     669    TYPE_NAME = QName(XMLConstants.XSD_NS,  
     670                      TYPE_LOCAL_NAME,  
     671                      XMLConstants.XSD_PREFIX) 
     672   
     673    def __init__(self): 
     674        self.__value = None 
     675         
     676    def _getValue(self): 
     677        return self.__value 
     678         
     679    def _setValue(self, value): 
     680        if not isinstance(value, basestring): 
     681            raise TypeError("Input must be a basestring derived type, got %r" % 
     682                            value.__class__) 
     683             
     684        self.__value = value 
     685 
     686    value = property(fget=_getValue, fset=_setValue, doc="string value")   
     687    
     688class Statement(SAMLObject): 
    1138689    pass 
    1139690 
     
    1176727     
    1177728 
    1178 class XSGroupRole(object):  
     729class XSGroupRole(SAMLObject):  
    1179730     '''ESG Specific Group/Role attribute.  ESG attribute permissions are 
    1180731     organised into group/role pairs 
     
    1212763         # no children 
    1213764         return None 
    1214   
    1215  
    1216 # TODO: refactor print helper 
    1217 from StringIO import StringIO 
    1218 def elem2String(elem, **kw): 
    1219 #    f = StringIO() 
    1220 #    ElementC14N.write(ElementC14N.build_scoped_tree(elem),  
    1221 #                      f,  
    1222 #                      **kw) 
    1223     f = StringIO() 
    1224     ElementC14N.write(ElementTree.ElementTree(element=elem), f) 
    1225     return f.getvalue() 
    1226  
    1227 class AssertionMarshaller(object): 
    1228     issueInstantFmt = "%Y-%m-%dT%H:%M:%SZ" 
    1229      
    1230     @classmethod 
    1231     def issueInstantDatetime2Str(cls, dtIssueInstant): 
    1232         """Convert issue instant datetime to correct string type for output 
    1233         @type dtIssueInstant: datetime.datetime 
    1234         @param dtIssueInstant: issue instance as a datetime 
    1235         @rtype: basestring 
    1236         @return: issue instance as a string 
    1237         """ 
    1238         return dtIssueInstant.strftime(AssertionMarshaller.issueInstantFmt) 
    1239      
    1240 class AssertionETreeMarshaller(AssertionMarshaller): 
    1241     """ElementTree based marshalling of Assertion class into XML 
    1242     """ 
    1243     def __init__(self): 
    1244         pass 
    1245  
    1246     def marshall(self, assertion): 
    1247         """Make a tree of a XML elements based on the assertion""" 
    1248         if not isinstance(assertion, Assertion): 
    1249             raise TypeError("Expecting %r type got: %r"%(Assertion, assertion)) 
    1250          
    1251         issueInstant = AssertionETreeMarshaller.issueInstantDatetime2Str( 
    1252                                                         assertion.issueInstant) 
    1253         attrib = { 
    1254             Assertion.ID_ATTRIB_NAME: assertion.id, 
    1255             Assertion.ISSUE_INSTANT_ATTRIB_NAME: issueInstant, 
    1256              
    1257             # Nb. Version is a SAMLVersion instance and requires explicit cast 
    1258             Assertion.VERSION_ATTRIB_NAME: str(assertion.version) 
    1259         } 
    1260         assertionElem=ElementTree.Element(str(Assertion.DEFAULT_ELEMENT_NAME),  
    1261                                           **attrib) 
    1262          
    1263         assertionElem.set("xmlns:%s" % Assertion.DEFAULT_ELEMENT_NAME.prefix,  
    1264                           Assertion.DEFAULT_ELEMENT_NAME.namespaceURI) 
    1265          
    1266         self.marshallAttributes(assertion, assertionElem) 
    1267          
    1268         print canonicalize(assertionElem) 
    1269         return assertionElem 
    1270  
    1271     def marshallAttributes(self, assertion, assertionElem): 
    1272         attributeStatementETreeMarshaller = AttributeStatementETreeMarshaller() 
    1273          
    1274         for attributeStatement in assertion.attributeStatements: 
    1275             attributeStatementElem = attributeStatementETreeMarshaller( 
    1276                                                         attributeStatement) 
    1277             assertionElem.append(attributeStatementElem) 
    1278              
    1279  
    1280 class AttributeStatementETreeMarshaller(object): 
    1281     """Marshal AttributeStatement object into an ElementTree XML  
    1282     representation""" 
    1283     def marshall(self, attributeStatement, makeNsDeclaration=False, **attrib): 
    1284         if not isinstance(attributeStatement, AttributeStatement): 
    1285             raise TypeError("Expecting %r type got: %r" % (AttributeStatement,  
    1286                                                            attributeStatement)) 
    1287              
    1288         attributeStatementElem = ElementTree.Element( 
    1289                                 str(AttributeStatement.DEFAULT_ELEMENT_NAME),  
    1290                                 **attrib) 
    1291         if makeNsDeclaration: 
    1292             attributeStatementElem.set( 
    1293                 "xmlns:%s" % AttributeStatement.DEFAULT_ELEMENT_NAME.prefix,  
    1294                 AttributeStatement.DEFAULT_ELEMENT_NAME.namespaceURI) 
    1295          
    1296         return attributeStatementElem 
    1297  
    1298     def marshallAttributes(self, attributeStatement, attributeStatement): 
    1299         for attribute in attributeStatement.attributes: 
    1300             attributeElem = attributeStatementETreeMarshaller( 
    1301                                                         attributeStatement) 
    1302             assertionElem.append(attributeStatementElem) 
    1303  
    1304 class AttributeETreeMarshaller(object): 
    1305     """Marshal Attribute object into an ElementTree XML representation"""  
    1306  
    1307 class _AssertionTmp:         
    1308     def parse(self, source): 
    1309         """Read in the XML from source 
    1310         @type source: basestring/file 
    1311         @param source: file path to XML file or file object 
    1312         """ 
    1313         elem = ElementTree.parse(source) 
    1314         root = elem.getroot() 
    1315          
    1316         return root 
    1317          
    1318     def serialise(self): 
    1319         return txt 
  • TI12-security/trunk/python/ndg.security.test/ndg/security/test/unit/attributeauthority/saml/test_samlinterface.py

    r5504 r5507  
    1111import logging 
    1212logging.basicConfig(level=logging.DEBUG) 
    13 from ndg.security.test.unit import BaseTestCase 
    14 from ndg.security.common.saml import Attribute, AttributeValue, \ 
    15     AttributeStatement, Assertion, SAMLVersion, XSGroupRole, \ 
    16     AssertionETreeMarshaller 
     13     
    1714from datetime import datetime 
    1815import base64  
    1916import os 
    2017from uuid import uuid4 
     18 
     19from ndg.security.test.unit import BaseTestCase 
     20 
     21from ndg.security.common.saml import Assertion, Attribute, AttributeValue, \ 
     22    AttributeStatement, SAMLVersion, XSGroupRole, XSStringAttributeValue 
     23from ndg.security.common.saml.xml.etree import AssertionETreeObject 
    2124 
    2225class Request(object): 
     
    3538    def test01(self): 
    3639        samlUtil = SamlUtil() 
     40         
    3741        samlUtil.firstName = "Philip" 
    3842        samlUtil.lastName = "Kershaw" 
     43        samlUtil.emailAddress = "p.j.k@somewhere" 
     44         
    3945        samlUtil.groupRoleList = [ 
    4046            ('urn:badc:security:authz:1.0:attr:org:id:badc',  
     
    4248        ] 
    4349        assertion = samlUtil.buildAssertion() 
    44         marshaller = AssertionETreeMarshaller() 
    45         assertionElem = marshaller.marshall(assertion) 
     50        assertionETreeObject = AssertionETreeObject() 
     51        assertionElem = assertionETreeObject.create(assertion) 
     52        print(">\n".join(assertionETreeObject.serialize().split(">"))) 
    4653 
    4754 
     
    5259        self.firstName = None 
    5360        self.lastName = None 
    54      
     61        self.emailAddress = None 
     62         
    5563        self.groupRoleList = [] 
    5664        self.miscAttrList = [] 
     
    7886            fnAttribute.friendlyName = "FirstName" 
    7987 
    80             fNameStr = (AttributeValue.DEFAULT_ELEMENT_NAME,  
    81                         "string",  
    82                         self.firstName) 
    83             fnAttribute.attributeValues.append(fNameStr) 
     88            firstName = XSStringAttributeValue() 
     89            firstName.value = self.firstName 
     90            fnAttribute.attributeValues.append(firstName) 
    8491 
    8592            attributeStatement.attributes.append(fnAttribute) 
     
    93100            lnAttribute.friendlyName = "LastName" 
    94101 
    95             lNameStr = (AttributeValue.DEFAULT_ELEMENT_NAME,  
    96                         "string", 
    97                         self.lastName) 
    98             lnAttribute.attributeValues.append(lNameStr) 
     102            lastName = XSStringAttributeValue() 
     103            lastName.value = self.lastName 
     104            lnAttribute.attributeValues.append(lastName) 
    99105 
    100106            attributeStatement.attributes.append(lnAttribute) 
    101107         
    102         if len(self.groupRoleList) > 0: 
    103             # custom group/role attribute to be added to attr statement 
    104             groupRoleAttr = Attribute() 
    105             groupRoleAttr.name = "GroupRole" 
    106             groupRoleAttr.nameFormat = XSGroupRole.TYPE_LOCAL_NAME 
    107108 
    108             for group, role in self.groupRoleList: 
    109                 groupRoleValue = XSGroupRole("http://www.esg.org",  
    110                                              "groupRole",  
    111                                              "esg") 
     109        if self.emailAddress is not None: 
     110            # special case handling for 'LastName' attribute 
     111            emailAddressAttribute = Attribute() 
     112            emailAddressAttribute.name = "urn:esg:email:address" 
     113            emailAddressAttribute.nameFormat = \ 
     114                                    "http://www.w3.org/2001/XMLSchema#string" 
     115            emailAddressAttribute.friendlyName = "emailAddress" 
    112116 
    113                 groupRoleValue.group = group 
    114                 groupRoleValue.role = role 
     117            emailAddress = XSStringAttributeValue() 
     118            emailAddress.value = self.emailAddress 
     119            emailAddressAttribute.attributeValues.append(emailAddress) 
    115120 
    116                 groupRoleAttrValue = (AttributeValue.DEFAULT_ELEMENT_NAME, 
    117                                       groupRoleValue) 
    118  
    119                 groupRoleAttr.attributeValues.append(groupRoleAttrValue) 
    120              
    121             attributeStatement.attributes.append(groupRoleAttr) 
     121            attributeStatement.attributes.append(emailAddressAttribute) 
    122122         
    123         if len(self.miscAttrList) > 0: 
    124          
    125             for name, value in self.miscAttrList: 
    126                 attribute = Attribute() 
    127                 attribute.name = name 
    128                 attribute.nameFormat="http://www.w3.org/2001/XMLSchema#string" 
    129  
    130                 stringValue = (AttributeValue.DEFAULT_ELEMENT_NAME,  
    131                                "string", 
    132                                value) 
    133                 attribute.attributeValues.append(stringValue) 
    134  
    135                 # add all attributes to the attribute statement                                                                                                    
    136                 attributeStatement.attributes.append(attribute) 
     123#        if len(self.groupRoleList) > 0: 
     124#            # custom group/role attribute to be added to attr statement 
     125#            groupRoleAttr = Attribute() 
     126#            groupRoleAttr.name = "GroupRole" 
     127#            groupRoleAttr.nameFormat = XSGroupRole.TYPE_LOCAL_NAME 
     128# 
     129#            for group, role in self.groupRoleList: 
     130#                groupRoleValue = XSGroupRole("http://www.esg.org",  
     131#                                             "groupRole",  
     132#                                             "esg") 
     133# 
     134#                groupRoleValue.group = group 
     135#                groupRoleValue.role = role 
     136# 
     137#                groupRoleAttrValue = (AttributeValue.DEFAULT_ELEMENT_NAME, 
     138#                                      groupRoleValue) 
     139# 
     140#                groupRoleAttr.attributeValues.append(groupRoleAttrValue) 
     141#             
     142#            attributeStatement.attributes.append(groupRoleAttr) 
     143#         
     144#        if len(self.miscAttrList) > 0: 
     145#         
     146#            for name, value in self.miscAttrList: 
     147#                attribute = Attribute() 
     148#                attribute.name = name 
     149#                attribute.nameFormat="http://www.w3.org/2001/XMLSchema#string" 
     150# 
     151#                stringValue = (AttributeValue.DEFAULT_ELEMENT_NAME,  
     152#                               "string", 
     153#                               value) 
     154#                attribute.attributeValues.append(stringValue) 
     155# 
     156#                # add all attributes to the attribute statement                                                                                                    
     157#                attributeStatement.attributes.append(attribute) 
    137158             
    138159        assertion.attributeStatements.append(attributeStatement) 
Note: See TracChangeset for help on using the changeset viewer.