source: TI12-security/trunk/python/ndg_security_common/ndg/security/common/utils/m2crypto.py @ 6040

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/python/ndg_security_common/ndg/security/common/utils/m2crypto.py@6040
Revision 6040, 20.3 KB checked in by pjkersha, 10 years ago (diff)
  • Added capability to thread paste SSL service for unit tests
  • Added SSL client authentication functionality to SamlCredentialWallet using ndg.security.common.utils.m2crypto.SSLContextProxy
Line 
1"""Extend M2Crypto SSL functionality for cert verification and custom
2timeout settings.
3
4NERC DataGrid Project"""
5__author__ = "P J Kershaw"
6__date__ = "02/07/07"
7__copyright__ = "(C) 2009 Science and Technology Facilities Council"
8__license__ = "BSD - see LICENSE file in top-level directory"
9__contact__ = "Philip.Kershaw@stfc.ac.uk"
10__revision__ = '$Id: $'
11import logging
12log = logging.getLogger(__name__)
13
14import os
15import re
16
17from M2Crypto import SSL, X509
18from M2Crypto.httpslib import HTTPSConnection as _HTTPSConnection
19
20from ndg.security.common.X509 import X509Cert, X509Stack, X500DN
21
22
class InvalidCertSignature(SSL.Checker.SSLVerificationError):
    """Peer certificate could not be verified against the configured CA
    certificate(s)"""
25
26
class InvalidCertDN(SSL.Checker.SSLVerificationError):
    """Peer certificate Distinguished Name is not one of the accepted DNs"""
29   
30
31class HostCheck(SSL.Checker.Checker, object):
32    """Override SSL.Checker.Checker to enable alternate Common Name
33    setting match for peer cert"""
34
35    def __init__(self, 
36                 peerCertDN=None, 
37                 peerCertCN=None,
38                 acceptedDNs=[], 
39                 caCertList=[],
40                 caCertFilePathList=[], 
41                 **kw):
42        """Override parent class __init__ to enable setting of myProxyServerDN
43        setting
44       
45        @type peerCertDN: string/list
46        @param peerCertDN: Set the expected Distinguished Name of the
47        server to avoid errors matching hostnames.  This is useful
48        where the hostname is not fully qualified. 
49
50        *param acceptedDNs: a list of acceptable DNs.  This enables validation
51        where the expected DN is where against a limited list of certs.
52       
53        @type peerCertCN: string
54        @param peerCertCN: enable alternate Common Name to peer
55        hostname
56       
57        @type caCertList: list type of M2Crypto.X509.X509 types
58        @param caCertList: CA X.509 certificates - if set the peer cert's
59        CA signature is verified against one of these.  At least one must
60        verify
61       
62        @type caCertFilePathList: list string types
63        @param caCertFilePathList: same as caCertList except input as list
64        of CA cert file paths"""
65       
66        SSL.Checker.Checker.__init__(self, **kw)
67       
68        self.peerCertDN = peerCertDN
69        self.peerCertCN = peerCertCN
70        self.acceptedDNs = acceptedDNs
71       
72        if caCertList:
73            self.caCertList = caCertList
74        elif caCertFilePathList:
75            self.caCertFilePathList = caCertFilePathList
76        else:
77            # Set default to enable len() test in __call__
78            self.__caCertStack = ()
79           
80    def __call__(self, peerCert, host=None):
81        """Carry out checks on server ID
82        @param peerCert: MyProxy server host certificate as M2Crypto.X509.X509
83        instance
84        @param host: name of host to check
85        """
86        if peerCert is None:
87            raise SSL.Checker.NoCertificate('SSL Peer did not return '
88                                            'certificate')
89
90        peerCertDN = '/'+peerCert.get_subject().as_text().replace(', ', '/')
91        try:
92            SSL.Checker.Checker.__call__(self, peerCert, host=self.peerCertCN)
93           
94        except SSL.Checker.WrongHost, e:
95            # Try match against peerCertDN set   
96            if peerCertDN != self.peerCertDN:
97                raise e
98
99        # At least one match should be found in the list - first convert to
100        # NDG X500DN type to allow per field matching for DN comparison
101        peerCertX500DN = X500DN(dn=peerCertDN)
102       
103        if self.acceptedDNs:
104           matchFound = False
105           for dn in self.acceptedDNs:
106               x500dn = X500DN(dn=dn)
107               if x500dn == peerCertX500DN:
108                   matchFound = True
109                   break
110               
111           if not matchFound:
112               raise InvalidCertDN('Peer cert DN "%s" doesn\'t match '
113                                   'verification list' % peerCertDN)
114
115        if len(self.__caCertStack) > 0:
116            try:
117                self.__caCertStack.verifyCertChain(
118                           x509Cert2Verify=X509Cert(m2CryptoX509=peerCert))
119            except Exception, e:
120                raise InvalidCertSignature("Peer certificate verification "
121                                           "against CA certificate failed: %s" 
122                                           % e)
123             
124        # They match - drop the exception and return all OK instead         
125        return True
126     
127    def __setCACertList(self, caCertList):
128        """Set list of CA certs - peer cert must validate against at least one
129        of these"""
130        self.__caCertStack = X509Stack()
131        for caCert in caCertList:
132            self.__caCertStack.push(caCert)
133
134    caCertList = property(fset=__setCACertList,
135                          doc="list of CA certificates - the peer certificate "
136                              "must validate against one")
137
138    def __setCACertsFromFileList(self, caCertFilePathList):
139        '''Read CA certificates from file and add them to the X.509
140        stack
141       
142        @type caCertFilePathList: basestring, list or tuple
143        @param caCertFilePathList: list of file paths for CA certificates to
144        be used to verify certificate used to sign message.  If a single
145        string item is input then this is converted into a tuple
146        '''
147        if isinstance(caCertFilePathList, basestring):
148            caCertFilePathList = (caCertFilePathList,)
149           
150        elif not isinstance(caCertFilePathList, (list, tuple)):
151            raise TypeError('Expecting a basestring, list or tuple type for '
152                            '"caCertFilePathList"')
153
154        self.__caCertStack = X509Stack()
155
156        for caCertFilePath in caCertFilePathList:
157            self.__caCertStack.push(X509.load_cert(caCertFilePath))
158       
159    caCertFilePathList = property(fset=__setCACertsFromFileList,
160                                  doc="list of CA certificate file paths - "
161                                      "peer certificate must validate against "
162                                      "one")
163
164
class HTTPSConnection(_HTTPSConnection):
    """Modified version of M2Crypto equivalent to enable custom checks with
    the peer and timeout settings

    @type defReadTimeout: M2Crypto.SSL.timeout
    @cvar defReadTimeout: default timeout for read operations
    @type defWriteTimeout: M2Crypto.SSL.timeout
    @cvar defWriteTimeout: default timeout for write operations"""
    defReadTimeout = SSL.timeout(sec=20.)
    defWriteTimeout = SSL.timeout(sec=20.)

    def __init__(self, *args, **kw):
        '''Overload to enable setting of post connection check
        callback to SSL.Connection

        @type args: tuple
        @param args: args which apply to M2Crypto.httpslib.HTTPSConnection
        @type kw: dict
        @param kw: keywords - those listed below are consumed here, the rest
        are passed to M2Crypto.httpslib.HTTPSConnection
        @type postConnectionCheck: SSL.Checker.Checker derivative instance
        or class
        @keyword postConnectionCheck: check applied to the peer after the SSL
        handshake.  If a class is given it is instantiated with no
        arguments.  Defaults to SSL.Checker.Checker.  None presumably
        disables the check (M2Crypto skips a None check) - confirm
        @type readTimeout: M2Crypto.SSL.timeout
        @keyword readTimeout: readTimeout - set timeout for read
        @type writeTimeout: M2Crypto.SSL.timeout
        @keyword writeTimeout: similar to read timeout
        @type clntCertFilePath: string
        @keyword clntCertFilePath: client certificate file path - loaded into
        the SSL context only if clntPriKeyFilePath is also set
        @type clntPriKeyFilePath: string
        @keyword clntPriKeyFilePath: client private key file path
        @raise AttributeError: readTimeout or writeTimeout is not an
        M2Crypto.SSL.timeout instance'''
        import inspect

        postConnectionCheck = kw.pop('postConnectionCheck',
                                     SSL.Checker.Checker)

        # SSL.Connection invokes the check as check(peerCert, host) and
        # treats a true result as a pass.  Invoked on a class, that call
        # merely constructs a new (always true) object so no check would ever
        # be made - instantiate classes up front instead.
        if inspect.isclass(postConnectionCheck):
            postConnectionCheck = postConnectionCheck()

        self._postConnectionCheck = postConnectionCheck

        if 'readTimeout' in kw:
            if not isinstance(kw['readTimeout'], SSL.timeout):
                raise AttributeError("readTimeout must be of type "
                                     "M2Crypto.SSL.timeout")
            self.readTimeout = kw.pop('readTimeout')
        else:
            self.readTimeout = HTTPSConnection.defReadTimeout

        if 'writeTimeout' in kw:
            if not isinstance(kw['writeTimeout'], SSL.timeout):
                raise AttributeError("writeTimeout must be of type "
                                     "M2Crypto.SSL.timeout")
            self.writeTimeout = kw.pop('writeTimeout')
        else:
            self.writeTimeout = HTTPSConnection.defWriteTimeout

        # Remove these before calling the parent, which doesn't know them
        self._clntCertFilePath = kw.pop('clntCertFilePath', None)
        self._clntPriKeyFilePath = kw.pop('clntPriKeyFilePath', None)

        _HTTPSConnection.__init__(self, *args, **kw)

        # load up certificate stuff
        if self._clntCertFilePath is not None and \
           self._clntPriKeyFilePath is not None:
            self.ssl_ctx.load_cert(self._clntCertFilePath,
                                   self._clntPriKeyFilePath)

    def connect(self):
        '''Overload M2Crypto.httpslib.HTTPSConnection to enable
        custom post connection check of peer certificate and socket timeout'''

        self.sock = SSL.Connection(self.ssl_ctx)
        self.sock.set_post_connection_check_callback(self._postConnectionCheck)

        self.sock.set_socket_read_timeout(self.readTimeout)
        self.sock.set_socket_write_timeout(self.writeTimeout)

        self.sock.connect((self.host, self.port))

    def putrequest(self, method, url, **kw):
        '''Overload to work around bug with unicode type URL - the URL is
        converted to str before being passed to the parent class'''
        url = str(url)
        _HTTPSConnection.putrequest(self, method, url, **kw)
238         
239             
class SSLContextProxy(object):
    """Holder for M2Crypto.SSL.Context parameters

    Stores the certificate and private key file paths, private key
    password, CA certificate file/directory and a whitelist of acceptable
    peer certificate DNs, and builds an M2Crypto.SSL.Context from them via
    createCtx.  Supports pickling (used with beaker.session).

    @cvar PRE_VERIFY_FAIL: verify callback return value stopping
    verification with a failure
    @cvar PRE_VERIFY_OK: verify callback return value to continue
    @cvar OPTNAMES: names of the configurable option properties
    @cvar VALID_DNS_PAT: pattern used to split a string of whitelisted DNs
    """
    PRE_VERIFY_FAIL, PRE_VERIFY_OK = range(2)

    SSL_CERT_FILEPATH_OPTNAME = "sslCertFilePath"
    SSL_PRIKEY_FILEPATH_OPTNAME = "sslPriKeyFilePath"
    SSL_PRIKEY_PWD_OPTNAME = "sslPriKeyPwd"
    SSL_CACERT_FILEPATH_OPTNAME = "sslCACertFilePath"
    SSL_CACERT_DIRPATH_OPTNAME = "sslCACertDir"
    SSL_VALID_DNS_OPTNAME = "sslValidDNs"

    OPTNAMES = (
        SSL_CERT_FILEPATH_OPTNAME,
        SSL_PRIKEY_FILEPATH_OPTNAME,
        SSL_PRIKEY_PWD_OPTNAME,
        SSL_CACERT_FILEPATH_OPTNAME,
        SSL_CACERT_DIRPATH_OPTNAME,
        SSL_VALID_DNS_OPTNAME
    )

    # Storage for the name-mangled private attributes behind each option
    # property.  The public option names are deliberately not listed: the
    # properties below provide them, and a slot of the same name would clash
    # (ValueError at class creation under Python 3, a shadowed unused slot
    # under Python 2).  A generator expression leaks no loop variable into
    # the class namespace.
    __slots__ = tuple("_SSLContextProxy__%s" % optName
                      for optName in OPTNAMES)

    VALID_DNS_PAT = re.compile(r',\s*')

    def __init__(self):
        self.__sslCertFilePath = None
        self.__sslPriKeyFilePath = None
        self.__sslPriKeyPwd = None
        self.__sslCACertFilePath = None
        self.__sslCACertDir = None
        self.__sslValidDNs = []

    def createCtx(self, **kw):
        """Create an M2Crypto SSL Context from this object's properties

        @param kw: keywords passed straight to M2Crypto.SSL.Context
        @rtype: M2Crypto.SSL.Context
        @return: M2Crypto SSL context object set to require and verify a peer
        certificate, applying the sslValidDNs whitelist check
        """
        ctx = SSL.Context(**kw)
        if self.sslCertFilePath and self.sslPriKeyFilePath:
            ctx.load_cert(self.sslCertFilePath,
                          self.sslPriKeyFilePath,
                          lambda *args, **kwargs: self.sslPriKeyPwd)

        if self.sslCACertFilePath or self.sslCACertDir:
            ctx.load_verify_locations(self.sslCACertFilePath,
                                      self.sslCACertDir)

        # 9 is the maximum verification depth passed to set_verify
        ctx.set_verify(SSL.verify_peer|SSL.verify_fail_if_no_peer_cert, 9,
                       callback=self.createVerifySSLPeerCertCallback())

        return ctx

    def createVerifySSLPeerCertCallback(self):
        """Create a callback function to enable the DN of the peer in an SSL
        connection to be verified against a whitelist.

        Nb. The function is defined within the scope of a method of the class
        so that it can access instance variables

        NOTE(review): with an empty sslValidDNs whitelist (the default) every
        peer is rejected - confirm this deny-by-default is intended
        """

        def _verifySSLPeerCertCallback(preVerifyOK, x509StoreCtx):
            '''SSL verify callback function used to control the behaviour when
            the SSL_VERIFY_PEER flag is set.  See:

            http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html

            This implementation applies verification in order to check the DN
            of the peer certificate against a whitelist

            @type preVerifyOK: int
            @param preVerifyOK: If a verification error is found, this
            parameter will be set to 0
            @type x509StoreCtx: M2Crypto.X509.X509_Store_Context
            @param x509StoreCtx: locate the certificate to be verified and
            perform additional verification steps as needed
            @rtype: int
            @return: controls the strategy of the further verification process.
            - If verify_callback returns 0, the verification process is
            immediately stopped with "verification failed" state. If
            SSL_VERIFY_PEER is set, a verification failure alert is sent to the
            peer and the TLS/SSL handshake is terminated.
            - If verify_callback returns 1, the verification process is
            continued.
            If verify_callback always returns 1, the TLS/SSL handshake will not
            be terminated with respect to verification failures and the
            connection will be established. The calling process can however
            retrieve the error code of the last verification error using
            SSL_get_verify_result or by maintaining its own error storage
            managed by verify_callback.
            '''
            if preVerifyOK == 0:
                # Something is wrong with the certificate don't bother
                # proceeding any further
                log.error("verifyCallback: pre-verify OK flagged an error "
                          "with the peer certificate, returning error state "
                          "to caller ...")
                return preVerifyOK

            # Pass if any certificate in the chain has a whitelisted DN
            x509CertChain = x509StoreCtx.get1_chain()
            for cert in x509CertChain:
                x509Cert = X509Cert.fromM2Crypto(cert)
                if x509Cert.dn in self.sslValidDNs:
                    return preVerifyOK

                subject = cert.get_subject()
                dn = subject.as_text()
                log.debug("verifyCallback: dn = %r", dn)

            # No match found so return fail status
            return SSLContextProxy.PRE_VERIFY_FAIL

        return _verifySSLPeerCertCallback

    def _getSSLCertFilePath(self):
        return self.__sslCertFilePath

    def _setSSLCertFilePath(self, filePath):
        """Set X.509 cert file path property method

        @type filePath: basestring or None
        @param filePath: file path; environment variables are expanded
        @raise TypeError: filePath is not a string or None"""
        if filePath is None:
            self.__sslCertFilePath = None
            return

        if not isinstance(filePath, basestring):
            raise TypeError("X.509 cert. file path must be a valid string")

        self.__sslCertFilePath = os.path.expandvars(filePath)

    sslCertFilePath = property(fset=_setSSLCertFilePath,
                               fget=_getSSLCertFilePath,
                               doc="File path to X.509 cert.")

    def _getSSLCACertFilePath(self):
        """Get file path for CA cert or certs used to validate SSL
        connections

        @rtype: basestring or None
        @return: file path to file containing concatenated PEM encoded CA
        certificates."""
        return self.__sslCACertFilePath

    def _setSSLCACertFilePath(self, value):
        """Set CA cert file path

        @type value: basestring or None
        @param value: file path to CA certificate file; environment
        variables are expanded.  None clears the setting.
        @raise TypeError: value is not a string or None"""
        if value is None:
            # Previously None passed the type check but then crashed in
            # os.path.expandvars
            self.__sslCACertFilePath = None
            return

        if not isinstance(value, basestring):
            raise TypeError("Input CA Certificate file path must be "
                            "a valid string or None type: %r" % type(value))

        self.__sslCACertFilePath = os.path.expandvars(value)

    sslCACertFilePath = property(fget=_getSSLCACertFilePath,
                                 fset=_setSSLCACertFilePath,
                                 doc="Path to file containing concatenated PEM "
                                     "encoded CA Certificates - used for "
                                     "verification of peer certs in SSL "
                                     "connection")

    def _getSSLCACertDir(self):
        """Get directory containing CA certs used to validate SSL
        connections

        @rtype: basestring or None
        @return: directory containing PEM encoded CA certificates."""
        return self.__sslCACertDir

    def _setSSLCACertDir(self, value):
        """Set directory containing CA certificates used to verify peer
        certificates in SSL connections

        @type value: basestring or None
        @param value: directory containing CA certificate files; environment
        variables are expanded.  None clears the setting.
        @raise TypeError: value is not a string or None
        """
        if value is None:
            # Previously None passed the type check but then crashed in
            # os.path.expandvars
            self.__sslCACertDir = None
            return

        if not isinstance(value, basestring):
            raise TypeError("Input CA Certificate directroy must be "
                            "a valid string or None type: %r" % type(value))

        self.__sslCACertDir = os.path.expandvars(value)

    sslCACertDir = property(fget=_getSSLCACertDir,
                            fset=_setSSLCACertDir,
                            doc="Path to directory containing PEM encoded CA "
                                "Certificates used for verification of peer "
                                "certs in SSL connection.   Files in the "
                                "directory must be named with the form "
                                "<hash>.0 where <hash> can be obtained using "
                                "openssl x509 -in cert -hash -noout or using "
                                "the c_rehash OpenSSL script")

    def _getSslValidDNs(self):
        return self.__sslValidDNs

    def _setSslValidDNs(self, value):
        """Set the peer DN whitelist

        @type value: basestring, list or tuple
        @param value: DNs as a list/tuple of strings or X500DN objects, or a
        single string of comma separated DNs
        @raise TypeError: value has any other type"""
        if isinstance(value, (tuple, list)):
            dnList = value
        elif isinstance(value, basestring):
            dnList = SSLContextProxy.VALID_DNS_PAT.split(value)
        else:
            raise TypeError('Expecting list/tuple or basestring type for "%s" '
                            'attribute; got %r' %
                            (SSLContextProxy.SSL_VALID_DNS_OPTNAME,
                             type(value)))

        # Accept ready-made X500DN objects as-is e.g. from an unpickled state
        self.__sslValidDNs = [
            dn if isinstance(dn, X500DN) else X500DN.fromString(dn)
            for dn in dnList
        ]

    sslValidDNs = property(_getSslValidDNs,
                           _setSslValidDNs,
                           doc="whitelist of acceptable certificate "
                               "Distinguished Names for peer certificates in "
                               "SSL requests")

    def _getSSLPriKeyFilePath(self):
        return self.__sslPriKeyFilePath

    def _setSSLPriKeyFilePath(self, filePath):
        """Set ssl private key file path property method

        @type filePath: basestring or None
        @param filePath: file path; environment variables are expanded
        @raise TypeError: filePath is not a string or None"""
        if filePath is None:
            self.__sslPriKeyFilePath = None
            return

        if not isinstance(filePath, basestring):
            raise TypeError("Private key file path must be a valid "
                            "string or None type")

        self.__sslPriKeyFilePath = os.path.expandvars(filePath)

    sslPriKeyFilePath = property(fget=_getSSLPriKeyFilePath,
                                 fset=_setSSLPriKeyFilePath,
                                 doc="File path to SSL private key")

    def _setSSLPriKeyPwd(self, sslPriKeyPwd):
        """Set method for ssl private key file password

        @type sslPriKeyPwd: basestring or None
        @param sslPriKeyPwd: password; None means no password
        @raise TypeError: sslPriKeyPwd is not a string or None"""
        if sslPriKeyPwd is None:
            # Keep None as None - str(None) would give the password "None"
            self.__sslPriKeyPwd = None
            return

        if not isinstance(sslPriKeyPwd, basestring):
            raise TypeError("Signing private key password must be None "
                            "or a valid string")

        # Explicitly convert to string as M2Crypto OpenSSL wrapper fails with
        # unicode type
        self.__sslPriKeyPwd = str(sslPriKeyPwd)

    def _getSSLPriKeyPwd(self):
        "Get property method for SSL private key"
        return self.__sslPriKeyPwd

    sslPriKeyPwd = property(fset=_setSSLPriKeyPwd,
                            fget=_getSSLPriKeyPwd,
                            doc="Password protecting SSL private key file")

    def __getstate__(self):
        '''Enable pickling for use with beaker.session - saves the raw
        private attribute values named in __slots__'''
        return dict((attrName, getattr(self, attrName))
                    for attrName in self.__class__.__slots__)

    def __setstate__(self, attrDict):
        '''Enable pickling for use with beaker.session

        States pickled by earlier versions of this class also hold the
        public option names.  Those are applied first, through the
        validating property setters, so that the raw private values - an
        exact copy of the pickled object's state - are set last and win.
        '''
        optNames = self.__class__.OPTNAMES
        # False sorts before True: public option names come first
        for attrName, val in sorted(attrDict.items(),
                                    key=lambda item: item[0] not in optNames):
            setattr(self, attrName, val)
Note: See TracBrowser for help on using the repository browser.