source: TI12-security/trunk/esg_system_tests/esg/security/test/system/test_ssl.py @ 7553

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/esg_system_tests/esg/security/test/system/test_ssl.py@7553
Revision 7553, 3.8 KB checked in by pjkersha, 10 years ago (diff)

New package for ESG Security system tests: probes a set of endpoints checking SSL configuration.

Line 
1"""ESG Security SSL system tests
2
3"""
4__author__ = "P J Kershaw"
5__date__ = "20/09/10"
6__copyright__ = "(C) 2010 Science and Technology Facilities Council"
7__license__ = "BSD"
8__contact__ = "Philip.Kershaw@stfc.ac.uk"
9__revision__ = '$Id$'
10import logging
11logging.basicConfig(level=logging.DEBUG, 
12                    format='%(asctime)s %(levelname)-8s %(message)s',
13                    datefmt='%a, %d %b %Y %H:%M:%S')
14log = logging.getLogger(__name__)
15
16import unittest
17import traceback
18import socket
19from os import path
20from ConfigParser import SafeConfigParser
21
22from OpenSSL import SSL
23
24
25class SSLTestCase(unittest.TestCase):
26    """Test SSL endpoints in ESG federation"""
27    THIS_DIR = path.dirname(path.abspath(__file__))
28    INI_FILENAME = 'test_ssl.cfg'
29    INI_FILEPATH = path.join(THIS_DIR, INI_FILENAME)
30   
31    def __init__(self, *arg, **kw):
32        cfg = SafeConfigParser(defaults=dict(here=self.__class__.THIS_DIR))
33        cfg.optionxform = str
34        cfg.read(self.__class__.INI_FILEPATH)
35       
36        self.caCertDir = cfg.get('DEFAULT', 'caCertDir')
37        self.endpoints = []
38        for i in cfg.get('DEFAULT', 'endpoints').split():
39            fqdn, port = i.split(':') 
40            self.endpoints.append((fqdn, int(port)))
41       
42        self.ctx = SSL.Context(SSL.SSLv3_METHOD)
43        self.ctx.load_verify_locations(None, self.caCertDir)
44        self.ctx.set_verify_depth(9)
45       
46        def _callback(conn, x509, errorNum, errorDepth,  preverifyOK):
47            return preverifyOK
48       
49        self.ctx.set_verify(SSL.VERIFY_PEER, _callback)
50       
51        super(SSLTestCase, self).__init__(*arg, **kw)
52   
53    def _test_connection(self, endpoint):
54        log.info('Probing %s:%d ...' % endpoint)
55
56        conn = SSL.Connection(self.ctx, socket.socket())
57        conn.connect(endpoint)
58        try:
59            conn.do_handshake()
60        except SSL.Error:
61            log.error("Handshake error for %r: %s" %
62                      (endpoint, traceback.format_exc()))
63            return False
64           
65        except socket.error:
66            log.error("Socket error for %r: %s" %
67                      (endpoint, traceback.format_exc()))
68            return False
69   
70    def _test_m2crypto_connection(self, endpoint):
71        log.info('Probing %s:%d ...' % endpoint)
72        sock = socket.socket()
73        sock.settimeout(3.0)
74       
75        import M2Crypto
76        ctx = M2Crypto.SSL.Context("sslv3")
77       
78        ctx.load_verify_locations(capath=self.caCertDir)
79       
80        ctx.set_verify(M2Crypto.SSL.verify_peer | 
81                       M2Crypto.SSL.verify_fail_if_no_peer_cert, 
82                       depth=9)
83        conn = M2Crypto.SSL.Connection(ctx, sock)
84       
85        try:
86            conn.connect(endpoint)
87        except socket.timeout:
88            log.error("Socket timeout for %r: %s" %
89                      (endpoint, traceback.format_exc()))
90            return False
91           
92        except socket.error:
93            log.error("Socket error for %r: %s" %
94                      (endpoint, traceback.format_exc()))
95            return False
96       
97        except Exception:
98            log.error("Unknown error for %r: %s" %
99                      (endpoint, traceback.format_exc()))
100            return False
101           
102        return True
103       
104    def test01(self):
105        nFails = 0
106        for i in self.endpoints:
107            if not self._test_connection(i):
108                nFails += 1
109
110        self.failIf(nFails > 0, "%d connection failure(s)" % nFails)
111       
112    def _test_m2crypto(self):
113        nFails = 0
114        for i in self.endpoints:
115            if not self._test_m2crypto_connection(i):
116                nFails += 1
117
118        self.failIf(nFails > 0, "%d connection failure(s)" % nFails)       
119
120if __name__ == "__main__":
121    unittest.main()
Note: See TracBrowser for help on using the repository browser.