source: TI12-security/trunk/python/ndg_security_common/ndg/security/common/utils/m2crypto.py @ 6050

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/python/ndg_security_common/ndg/security/common/utils/m2crypto.py@6050
Revision 6050, 20.9 KB checked in by pjkersha, 10 years ago (diff)
  • Enhanced ndg.security.common.saml_utils.bindings for better AttributeQuery client interface
  • Refactored AttributeQuery interface out of SamlCredentialWallet - it belongs in the above only.
  • deleted ndg.security.common.credentialrepository package - relic from NDG1
Line 
"""M2Crypto SSL extensions for the NDG security package.

Provides peer certificate checking beyond M2Crypto's host name check
(HostCheck), an HTTPS connection class with a configurable post connection
check and socket read/write timeouts (HTTPSConnection) and a container of
settings from which M2Crypto SSL contexts can be made (SSLContextProxy).

NERC DataGrid Project"""
__author__ = "P J Kershaw"
__date__ = "02/07/07"
__copyright__ = "(C) 2009 Science and Technology Facilities Council"
__license__ = "BSD - see LICENSE file in top-level directory"
__contact__ = "Philip.Kershaw@stfc.ac.uk"
__revision__ = '$Id: $'
import logging

# Logger named after this module
log = logging.getLogger(__name__)
13
14import os
15import re
16
17from M2Crypto import SSL, X509
18from M2Crypto.httpslib import HTTPSConnection as _HTTPSConnection
19
20from ndg.security.common.X509 import X509Cert, X509Stack, X500DN
21
22
class InvalidCertSignature(SSL.Checker.SSLVerificationError):
    """Signals that the peer certificate could not be verified against the
    configured CA certificate(s)"""
25
26
class InvalidCertDN(SSL.Checker.SSLVerificationError):
    """Signals that the peer certificate's Distinguished Name matched none
    of the entries in the list of accepted DNs"""
29   
30
class HostCheck(SSL.Checker.Checker, object):
    """Override SSL.Checker.Checker to enable an alternate Common Name
    setting to match against the peer cert.  Additionally the peer cert DN
    may be checked against a whitelist and the peer cert verified against a
    set of CA certificates"""

    def __init__(self,
                 peerCertDN=None,
                 peerCertCN=None,
                 acceptedDNs=None,
                 caCertList=None,
                 caCertFilePathList=None,
                 **kw):
        """Override parent class __init__ to enable setting of the expected
        peer DN, an alternate CN, a DN whitelist and CA certificates

        @type peerCertDN: string/list
        @param peerCertDN: expected Distinguished Name of the peer in
        '/'-separated form e.g. '/O=NDG/CN=host'.  If the parent class'
        Common Name check fails, the peer is still accepted if its DN equals
        this value or, if a list/tuple is given, any of its elements.  This
        is useful where the hostname is not fully qualified.

        @type peerCertCN: string
        @param peerCertCN: alternate Common Name to check the peer cert
        against instead of the peer hostname

        @type acceptedDNs: list
        @param acceptedDNs: list of acceptable DNs.  If non-empty, the peer
        cert DN must match one of these.  Defaults to an empty list, i.e. no
        whitelist check

        @type caCertList: list type of M2Crypto.X509.X509 types
        @param caCertList: CA X.509 certificates - if set the peer cert's
        CA signature is verified against one of these.  At least one must
        verify

        @type caCertFilePathList: basestring, list or tuple
        @param caCertFilePathList: same as caCertList except input as CA
        cert file path(s).  Ignored if caCertList is set.

        @param kw: keywords passed to SSL.Checker.Checker.__init__
        """
        SSL.Checker.Checker.__init__(self, **kw)

        self.peerCertDN = peerCertDN
        self.peerCertCN = peerCertCN

        # A None default avoids one mutable default list being shared by
        # every instance
        if acceptedDNs is None:
            acceptedDNs = []
        self.acceptedDNs = acceptedDNs

        if caCertList:
            self.caCertList = caCertList
        elif caCertFilePathList:
            self.caCertFilePathList = caCertFilePathList
        else:
            # Set default to enable len() test in __call__
            self.__caCertStack = ()

    def __call__(self, peerCert, host=None):
        """Carry out checks on the peer's identity

        @type peerCert: M2Crypto.X509.X509
        @param peerCert: peer certificate
        @param host: name of host to check.  NOTE: this argument is not
        used - the parent class' check is always called with
        host=self.peerCertCN
        @rtype: bool
        @return: True if all checks pass
        @raise SSL.Checker.NoCertificate: peer returned no certificate
        @raise SSL.Checker.WrongHost: parent class host check failed and the
        peer DN doesn't match peerCertDN
        @raise InvalidCertDN: peer DN doesn't match any of acceptedDNs
        @raise InvalidCertSignature: verification against the CA
        certificate(s) failed
        """
        if peerCert is None:
            raise SSL.Checker.NoCertificate('SSL Peer did not return '
                                            'certificate')

        # Convert e.g. "C=UK, O=NDG, CN=host" into "/C=UK/O=NDG/CN=host"
        peerCertDN = '/'+peerCert.get_subject().as_text().replace(', ', '/')
        try:
            SSL.Checker.Checker.__call__(self, peerCert, host=self.peerCertCN)

        except SSL.Checker.WrongHost:
            # Try match against peerCertDN set - re-raise preserving the
            # original traceback if that fails too
            if not self._matchesPeerCertDN(peerCertDN):
                raise

        if self.acceptedDNs:
            # At least one match should be found in the list - convert to
            # NDG X500DN type to allow per field matching for DN comparison
            peerCertX500DN = X500DN(dn=peerCertDN)
            if not any(X500DN(dn=dn) == peerCertX500DN
                       for dn in self.acceptedDNs):
                raise InvalidCertDN('Peer cert DN "%s" doesn\'t match '
                                    'verification list' % peerCertDN)

        if len(self.__caCertStack) > 0:
            try:
                self.__caCertStack.verifyCertChain(
                           x509Cert2Verify=X509Cert(m2CryptoX509=peerCert))
            except Exception as e:
                raise InvalidCertSignature("Peer certificate verification "
                                           "against CA certificate failed: %s"
                                           % e)

        # All checks passed
        return True

    def _matchesPeerCertDN(self, peerCertDN):
        """Test a peer DN string against the expected DN(s) in
        self.peerCertDN, which may be a single string or a list/tuple

        @type peerCertDN: string
        @param peerCertDN: '/'-separated DN of the peer certificate
        @rtype: bool
        @return: True if the DN matches"""
        if isinstance(self.peerCertDN, (list, tuple)):
            return peerCertDN in self.peerCertDN

        return peerCertDN == self.peerCertDN

    def __setCACertList(self, caCertList):
        """Set list of CA certs - peer cert must validate against at least one
        of these

        @type caCertList: iterable of M2Crypto.X509.X509
        @param caCertList: CA certificates to push onto a new X509Stack"""
        self.__caCertStack = X509Stack()
        for caCert in caCertList:
            self.__caCertStack.push(caCert)

    caCertList = property(fset=__setCACertList,
                          doc="list of CA certificates - the peer certificate "
                              "must validate against one")

    def __setCACertsFromFileList(self, caCertFilePathList):
        '''Read CA certificates from file and add them to the X.509
        stack

        @type caCertFilePathList: basestring, list or tuple
        @param caCertFilePathList: list of file paths for CA certificates to
        be used to verify the peer certificate.  If a single string item is
        input then this is converted into a tuple
        @raise TypeError: input is not a basestring, list or tuple
        '''
        if isinstance(caCertFilePathList, basestring):
            caCertFilePathList = (caCertFilePathList,)

        elif not isinstance(caCertFilePathList, (list, tuple)):
            raise TypeError('Expecting a basestring, list or tuple type for '
                            '"caCertFilePathList"')

        self.__caCertStack = X509Stack()

        for caCertFilePath in caCertFilePathList:
            self.__caCertStack.push(X509.load_cert(caCertFilePath))

    caCertFilePathList = property(fset=__setCACertsFromFileList,
                                  doc="list of CA certificate file paths - "
                                      "peer certificate must validate against "
                                      "one")
163
164
class HTTPSConnection(_HTTPSConnection):
    """Modified version of M2Crypto equivalent to enable custom checks with
    the peer and timeout settings

    @type defReadTimeout: M2Crypto.SSL.timeout
    @cvar defReadTimeout: default timeout for read operations
    @type defWriteTimeout: M2Crypto.SSL.timeout
    @cvar defWriteTimeout: default timeout for write operations"""
    defReadTimeout = SSL.timeout(sec=20.)
    defWriteTimeout = SSL.timeout(sec=20.)

    def __init__(self, *args, **kw):
        '''Overload to enable setting of post connection check
        callback to SSL.Connection

        type *args: tuple
        param *args: args which apply to M2Crypto.httpslib.HTTPSConnection
        type **kw: dict
        param **kw: additional keywords
        @type postConnectionCheck: SSL.Checker.Checker derivative instance or
        class
        @keyword postConnectionCheck: callable set as the SSL.Connection post
        connection check callback for checking the peer.  If a class is
        given, it is instantiated with no arguments.  Defaults to an
        SSL.Checker.Checker instance
        @type readTimeout: M2Crypto.SSL.timeout
        @keyword readTimeout: readTimeout - set timeout for read
        @type writeTimeout: M2Crypto.SSL.timeout
        @keyword writeTimeout: similar to read timeout
        @type clntCertFilePath: string
        @keyword clntCertFilePath: client certificate file - loaded into the
        SSL context only if clntPriKeyFilePath is also set
        @type clntPriKeyFilePath: string
        @keyword clntPriKeyFilePath: client private key file
        @raise AttributeError: a timeout keyword is not an SSL.timeout'''
        import inspect

        postConnectionCheck = kw.pop('postConnectionCheck',
                                     SSL.Checker.Checker)

        # The callback is invoked as callback(peerCert, host).  Invoking a
        # Checker *class* that way merely constructs a (truthy) instance and
        # performs no check at all, so instantiate classes here.
        # inspect.isclass also covers Python 2 old-style classes.
        if inspect.isclass(postConnectionCheck):
            postConnectionCheck = postConnectionCheck()

        self._postConnectionCheck = postConnectionCheck

        self.readTimeout = self._popTimeout(kw, 'readTimeout',
                                            self.defReadTimeout)
        self.writeTimeout = self._popTimeout(kw, 'writeTimeout',
                                             self.defWriteTimeout)

        self._clntCertFilePath = kw.pop('clntCertFilePath', None)
        self._clntPriKeyFilePath = kw.pop('clntPriKeyFilePath', None)

        _HTTPSConnection.__init__(self, *args, **kw)

        # load up certificate stuff
        if (self._clntCertFilePath is not None and
            self._clntPriKeyFilePath is not None):
            self.ssl_ctx.load_cert(self._clntCertFilePath,
                                   self._clntPriKeyFilePath)

    @staticmethod
    def _popTimeout(kw, name, default):
        '''Remove and return timeout keyword "name" from kw, or return
        default if it's not present

        @type kw: dict
        @param kw: keywords - modified in place
        @type name: string
        @param name: timeout keyword name
        @param default: value returned if name is not in kw
        @raise AttributeError: value is not an M2Crypto.SSL.timeout'''
        if name not in kw:
            return default

        timeout = kw.pop(name)
        if not isinstance(timeout, SSL.timeout):
            raise AttributeError("%s must be of type "
                                 "M2Crypto.SSL.timeout" % name)
        return timeout

    def connect(self):
        '''Overload M2Crypto.httpslib.HTTPSConnection to enable
        custom post connection check of peer certificate and socket timeout'''

        self.sock = SSL.Connection(self.ssl_ctx)
        self.sock.set_post_connection_check_callback(self._postConnectionCheck)

        self.sock.set_socket_read_timeout(self.readTimeout)
        self.sock.set_socket_write_timeout(self.writeTimeout)

        self.sock.connect((self.host, self.port))

    def putrequest(self, method, url, **kw):
        '''Overload to work around bug with unicode type URL - the URL is
        converted to str before passing to the parent class'''
        url = str(url)
        _HTTPSConnection.putrequest(self, method, url, **kw)
238         
239             
class SSLContextProxy(object):
    """Holder for M2Crypto.SSL.Context parameters.  createCtx makes a new
    M2Crypto.SSL.Context from the settings held.

    @cvar PRE_VERIFY_FAIL: verify callback return value which stops
    verification with a failure
    @cvar PRE_VERIFY_OK: verify callback return value which continues
    verification
    @cvar OPTNAMES: names of the settings held - used by copy
    @cvar VALID_DNS_PAT: separator pattern for a string of whitelisted DNs
    """
    PRE_VERIFY_FAIL, PRE_VERIFY_OK = range(2)

    SSL_CERT_FILEPATH_OPTNAME = "sslCertFilePath"
    SSL_PRIKEY_FILEPATH_OPTNAME = "sslPriKeyFilePath"
    SSL_PRIKEY_PWD_OPTNAME = "sslPriKeyPwd"
    SSL_CACERT_FILEPATH_OPTNAME = "sslCACertFilePath"
    SSL_CACERT_DIRPATH_OPTNAME = "sslCACertDir"
    SSL_VALID_DNS_OPTNAME = "sslValidDNs"

    OPTNAMES = (
        SSL_CERT_FILEPATH_OPTNAME,
        SSL_PRIKEY_FILEPATH_OPTNAME,
        SSL_PRIKEY_PWD_OPTNAME,
        SSL_CACERT_FILEPATH_OPTNAME,
        SSL_CACERT_DIRPATH_OPTNAME,
        SSL_VALID_DNS_OPTNAME
    )

    # Public option names followed by the name-mangled private attribute
    # names used for storage; __getstate__ pickles every name listed here.
    # A generator expression leaves no loop variable in the class namespace.
    __slots__ = OPTNAMES + tuple("_SSLContextProxy__%s" % optName
                                 for optName in OPTNAMES)

    # Separator for a string of whitelisted DNs: a comma plus any whitespace
    VALID_DNS_PAT = re.compile(r',\s*')

    def __init__(self):
        self.__sslCertFilePath = None
        self.__sslPriKeyFilePath = None
        self.__sslPriKeyPwd = None
        self.__sslCACertFilePath = None
        self.__sslCACertDir = None
        self.__sslValidDNs = []

    def createCtx(self, **kw):
        """Create an M2Crypto SSL Context from this object's properties

        @param kw: keywords passed to M2Crypto.SSL.Context
        @rtype: M2Crypto.SSL.Context
        @return: M2Crypto SSL context object
        """
        ctx = SSL.Context(**kw)
        if self.sslCertFilePath and self.sslPriKeyFilePath:
            # The callback supplies the private key password
            ctx.load_cert(self.sslCertFilePath,
                          self.sslPriKeyFilePath,
                          lambda *arg, **kwarg: self.sslPriKeyPwd)

        if self.sslCACertFilePath or self.sslCACertDir:
            ctx.load_verify_locations(self.sslCACertFilePath,
                                      self.sslCACertDir)

        # Require a peer certificate and apply the DN whitelist callback;
        # 9 is passed as the verification depth argument
        ctx.set_verify(SSL.verify_peer|SSL.verify_fail_if_no_peer_cert, 9,
                       callback=self.createVerifySSLPeerCertCallback())

        return ctx

    def copy(self, sslCtxProxy):
        """Copy settings from another context object

        @type sslCtxProxy: SSLContextProxy
        @param sslCtxProxy: object to copy settings from
        @raise TypeError: input is not an SSLContextProxy
        """
        if not isinstance(sslCtxProxy, SSLContextProxy):
            raise TypeError('Expecting %r for copy method input object; '
                            'got %r' % (SSLContextProxy, type(sslCtxProxy)))

        for name in SSLContextProxy.OPTNAMES:
            setattr(self, name, getattr(sslCtxProxy, name))

    def createVerifySSLPeerCertCallback(self):
        """Create a callback function to enable the DN of the peer in an SSL
        connection to be verified against a whitelist.

        Nb. The function is made within the scope of this method so that it
        can access instance variables

        @return: verify callback function
        """

        def _verifySSLPeerCertCallback(preVerifyOK, x509StoreCtx):
            '''SSL verify callback function used to control the behaviour when
            the SSL_VERIFY_PEER flag is set.  See:

            http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html

            This implementation applies verification in order to check the DN
            of the peer certificate against a whitelist.  The check passes if
            the DN of ANY certificate in the chain is in sslValidDNs.  Nb. if
            sslValidDNs is empty, every certificate fails.

            @type preVerifyOK: int
            @param preVerifyOK: If a verification error is found, this
            parameter will be set to 0
            @type x509StoreCtx: M2Crypto.X509.X509_Store_Context
            @param x509StoreCtx: locate the certificate to be verified and
            perform additional verification steps as needed
            @rtype: int
            @return: controls the strategy of the further verification process.
            - If verify_callback returns 0, the verification process is
            immediately stopped with "verification failed" state. If
            SSL_VERIFY_PEER is set, a verification failure alert is sent to the
            peer and the TLS/SSL handshake is terminated.
            - If verify_callback returns 1, the verification process is
            continued.
            '''
            if preVerifyOK == 0:
                # Something is wrong with the certificate don't bother
                # proceeding any further
                log.error("verifyCallback: pre-verify OK flagged an error "
                          "with the peer certificate, returning error state "
                          "to caller ...")
                return preVerifyOK

            x509CertChain = x509StoreCtx.get1_chain()
            for cert in x509CertChain:
                x509Cert = X509Cert.fromM2Crypto(cert)
                if x509Cert.dn in self.sslValidDNs:
                    return preVerifyOK

                subject = cert.get_subject()
                dn = subject.as_text()
                log.debug("verifyCallback: dn = %r", dn)

            # No match found so return fail status
            return SSLContextProxy.PRE_VERIFY_FAIL

        return _verifySSLPeerCertCallback

    def _getSSLCertFilePath(self):
        "Get X.509 cert file path property method"
        return self.__sslCertFilePath

    def _setSSLCertFilePath(self, filePath):
        """Set X.509 cert file path property method - environment variables
        in the path are expanded

        @type filePath: basestring or None
        @raise TypeError: input is not a string or None"""
        if isinstance(filePath, basestring):
            filePath = os.path.expandvars(filePath)

        elif filePath is not None:
            raise TypeError("X.509 cert. file path must be a valid string")

        self.__sslCertFilePath = filePath

    sslCertFilePath = property(fset=_setSSLCertFilePath,
                               fget=_getSSLCertFilePath,
                               doc="File path to X.509 cert.")

    def _getSSLCACertFilePath(self):
        """Get file path for CA cert or certs used to validate SSL
        connections

        @rtype: basestring or None
        @return: file path to file containing concatenated PEM encoded CA
        certificates."""
        return self.__sslCACertFilePath

    def _setSSLCACertFilePath(self, value):
        """Set CA cert file path

        @type value: basestring or None
        @param value: file path to CA certificate file - environment
        variables are expanded.  None clears the setting.
        @raise TypeError: input is not a string or None"""
        if isinstance(value, basestring):
            self.__sslCACertFilePath = os.path.expandvars(value)

        elif value is None:
            self.__sslCACertFilePath = value

        else:
            raise TypeError("Input CA Certificate file path must be "
                            "a valid string or None type: %r" % type(value))

    sslCACertFilePath = property(fget=_getSSLCACertFilePath,
                                 fset=_setSSLCACertFilePath,
                                 doc="Path to file containing concatenated PEM "
                                     "encoded CA Certificates - used for "
                                     "verification of peer certs in SSL "
                                     "connection")

    def _getSSLCACertDir(self):
        """Get directory of CA certs used to validate SSL connections

        @rtype: basestring or None
        @return: directory containing PEM encoded CA certificates."""
        return self.__sslCACertDir

    def _setSSLCACertDir(self, value):
        """Set directory containing CA certificates used to validate SSL
        peer certificates

        @type value: basestring or None
        @param value: directory containing CA certificate files -
        environment variables are expanded.  None clears the setting.
        @raise TypeError: input is not a string or None
        """
        if isinstance(value, basestring):
            self.__sslCACertDir = os.path.expandvars(value)
        elif value is None:
            self.__sslCACertDir = value
        else:
            raise TypeError("Input CA Certificate directory must be "
                            "a valid string or None type: %r" % type(value))

    sslCACertDir = property(fget=_getSSLCACertDir,
                            fset=_setSSLCACertDir,
                            doc="Path to directory containing PEM encoded CA "
                                "Certificates used for verification of peer "
                                "certs in SSL connection.   Files in the "
                                "directory must be named with the form "
                                "<hash>.0 where <hash> can be obtained using "
                                "openssl x509 -in cert -hash -noout or using "
                                "the c_rehash OpenSSL script")

    def _getSslValidDNs(self):
        "Get whitelist of peer certificate DNs as a list of X500DN"
        return self.__sslValidDNs

    def _setSslValidDNs(self, value):
        """Set whitelist of peer certificate DNs

        @type value: basestring, list or tuple
        @param value: a string of DNs separated by a comma plus optional
        whitespace, or a list/tuple of DNs.  List/tuple elements may be
        strings or X500DN instances; X500DN instances are kept as-is so that
        the output of the getter (as passed by copy and __setstate__) can be
        set again.  Empty entries in a string input are ignored.
        @raise TypeError: input is not a string, list or tuple
        """
        def toX500DN(dn):
            if isinstance(dn, X500DN):
                return dn
            return X500DN.fromString(dn)

        if isinstance(value, basestring):
            pat = SSLContextProxy.VALID_DNS_PAT
            self.__sslValidDNs = [toX500DN(dn)
                                  for dn in pat.split(value.strip()) if dn]

        elif isinstance(value, (tuple, list)):
            self.__sslValidDNs = [toX500DN(dn) for dn in value]
        else:
            raise TypeError('Expecting list/tuple or basestring type for "%s" '
                            'attribute; got %r' %
                            (SSLContextProxy.SSL_VALID_DNS_OPTNAME,
                             type(value)))

    sslValidDNs = property(_getSslValidDNs,
                           _setSslValidDNs,
                           doc="whitelist of acceptable certificate "
                               "Distinguished Names for peer certificates in "
                               "SSL requests")

    def _getSSLPriKeyFilePath(self):
        "Get SSL private key file path property method"
        return self.__sslPriKeyFilePath

    def _setSSLPriKeyFilePath(self, filePath):
        """Set ssl private key file path property method - environment
        variables in the path are expanded

        @type filePath: basestring or None
        @raise TypeError: input is not a string or None"""
        if isinstance(filePath, basestring):
            filePath = os.path.expandvars(filePath)

        elif filePath is not None:
            raise TypeError("Private key file path must be a valid "
                            "string or None type")

        self.__sslPriKeyFilePath = filePath

    sslPriKeyFilePath = property(fget=_getSSLPriKeyFilePath,
                                 fset=_setSSLPriKeyFilePath,
                                 doc="File path to SSL private key")

    def _setSSLPriKeyPwd(self, sslPriKeyPwd):
        """Set method for ssl private key file password

        @type sslPriKeyPwd: basestring or None
        @raise TypeError: input is not a string or None"""
        if not isinstance(sslPriKeyPwd, (type(None), basestring)):
            raise TypeError("Signing private key password must be None "
                            "or a valid string")

        if sslPriKeyPwd is None:
            # Keep None as None - str() would turn it into the string 'None'
            self.__sslPriKeyPwd = None
        else:
            # Explicitly convert to string as M2Crypto OpenSSL wrapper fails
            # with unicode type
            self.__sslPriKeyPwd = str(sslPriKeyPwd)

    def _getSSLPriKeyPwd(self):
        "Get property method for SSL private key"
        return self.__sslPriKeyPwd

    sslPriKeyPwd = property(fset=_setSSLPriKeyPwd,
                            fget=_getSSLPriKeyPwd,
                            doc="Password protecting SSL private key file")

    def __getstate__(self):
        '''Enable pickling for use with beaker.session'''
        return dict((attrName, getattr(self, attrName))
                    for attrName in self.__class__.__slots__)

    def __setstate__(self, attrDict):
        '''Enable pickling for use with beaker.session'''
        for attr, val in attrDict.items():
            setattr(self, attr, val)
Note: See TracBrowser for help on using the repository browser.