source: TI12-security/trunk/python/ndg_security_common_utils_openid_dbinterface/ndg/security/common/utils/openid/dbinterface.py @ 6069

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/python/ndg_security_common_utils_openid_dbinterface/ndg/security/common/utils/openid/dbinterface.py@6069
Revision 6069, 18.1 KB checked in by pjkersha, 11 years ago (diff)

Re-release as rc1

  • Property svn:executable set to *
Line 
1#!/usr/bin/env python
2"""NDG Security ndg.security.common.utils.openid package
3
4NERC DataGrid Project
5"""
6__author__ = "P J Kershaw"
7__date__ = "17/09/09"
8__copyright__ = "(C) 2009 Science and Technology Facilities Council"
9__license__ = "BSD - see LICENSE file in top-level directory"
10__contact__ = "Philip.Kershaw@stfc.ac.uk"
11__revision__ = '$Id: $'
12import re
13import urllib
14from string import Template
15import logging
16log = logging.getLogger(__name__)
17
18from ConfigParser import SafeConfigParser
19
20# Make an optional import to allow use of the UserIdentifier class alone which
21# has no database dependence
22_psycopg2Unavailable = False
23try:
24    from psycopg2 import connect
25except ImportError:
26    _psycopg2Unavailable = True
27
28
29class UserIdentifiersConfigError(Exception):
30    """Configuration error for UserIdentifiers instance"""
31
32
33class UserIdentifiers(object):
34    """Library of helper methods to parse a set of space delimited first names
35    and a surname string into a suitable OpenID indentifier of the form
36    <firstname>.<surname>
37    """
38
39    @staticmethod
40    def none2String(s):
41        """Convert None type to an empty string
42       
43        @type s: basestring/None
44        @param s: string to be converted
45        @rtype: basestring
46        @return: converted string
47        """
48        if s is None:
49            return ''
50        else:
51            return s
52
53    # Split surnames allowing for spaces, hyphenation and ' for e.g. O'Connor
54    splitPat = re.compile("[-)('\s]+")
55
56    @staticmethod
57    def sentenceCase(s, preserveCaps=False):
58        """Convert string to sentence case - first letter of first word is
59        capitalised
60        @type s: basestring
61        @param s: string to be converted
62        @type preserveCaps: bool
63        @param preserveCaps: set to True to leave uppercase strings such as
64        acronyms alone
65        @rtype: basestring
66        @return: converted string
67        """
68        if s is None:
69            return ''
70        elif s.isupper() and preserveCaps:
71            return s
72        else:
73            return s[0].upper() + s[1:].lower()
74
75    camelCase = staticmethod(lambda s: ''.join(
76                                    [UserIdentifiers.sentenceCase(i)
77                                     for i in UserIdentifiers.splitPat.split(
78                                     UserIdentifiers.none2String(s.strip()))
79                                     if i]
80                                    )
81                            )
82
83    getFirstName = staticmethod(lambda s: UserIdentifiers.camelCase(
84                            UserIdentifiers.none2String(s.strip()).split()[0]))
85
86    @staticmethod
87    def mcFilter(s):
88        '''Allow for surnames starting with Mc e.g. McDonald
89        @type s: basestring
90        @param s: string to be converted
91        @rtype: basestring
92        @return: converted string
93        '''
94        if s.startswith('Mc'):
95            return 'Mc' + s[2].upper() + s[3:]
96        else:
97            return s
98
99    convertSurname = staticmethod(lambda s: UserIdentifiers.mcFilter(
100                                                UserIdentifiers.camelCase(s)))
101
102    isUrlSafe = staticmethod(lambda s: urllib.quote(s) == s)
103
104    @staticmethod
105    def makeIdentifier(firstNames, surname):
106        """Make an OpenID identifier based on the users first name joined to
107        their last name with a dot separator"""
108        if not firstNames:
109            return UserIdentifiers.camelCase(surname)
110        else:
111            try:
112                firstName = UserIdentifiers.getFirstName(firstNames)
113            except IndexError:
114                log.warning('Error parsing first name from "%s": using '
115                            'surname only' % firstNames)
116                firstName = ''
117
118            if firstName.endswith('.'):
119                firstName = firstName[:-1]
120
121            try:
122                surname = UserIdentifiers.convertSurname(surname)
123            except IndexError:
124                log.warning('Error parsing surname from "%s": setting to null '
125                            'string' % surname)
126                surname = ''
127
128            if not firstName and not surname:
129                log.warning('Null identifier returned for firstNames="%s" and '
130                            'surname="%s"' % (firstNames, surname))
131                return ''
132
133            # Allow for surname or firstname not set
134            newId = (firstName+'.'+surname).strip('.')
135            if not UserIdentifiers.isUrlSafe(newId):
136                urlSafeId = urllib.quote(newId)
137                log.warning('Changing id "%s" to "%s" to make it URL safe' %
138                            (newId, urlSafeId))
139                return urlSafeId
140            else:
141                return newId
142
143
144class DatabaseUserIdentifiersConfigError(Exception):
145    """Configuration error for DatabaseUserIdentifiers instance"""
146
147
148class DatabaseUserIdentifiers(UserIdentifiers):
149    """Generate a list of OpenID identifiers from a query
150    to a user database"""
151
152    SECTION_NAME = "DatabaseOpenIDUserIdentifiers"
153    HOST_OPTNAME = "host"
154    DBNAME_OPTNAME = "dbName"
155    USERNAME_OPTNAME = "username"
156    PWD_OPTNAME = "pwd"
157    QUERY_GENERATE_OPTNAME = "query.generate"
158    QUERY_GENERATE_FROM_USERKEY_OPTNAME = "query.generateFromUserKey"
159    QUERY_UNIQ_IDENTIFIER_OPTNAME = "query.uniqIdentifier"
160    QUERY_GET_IDENTIFIERS_OPTNAME = "query.getIdentifiers"
161    UNIQ_IDENTIFIER_FIELD_NAME = "openIdUserComponent"
162    USERKEY_FIELD_NAME = "userKey"
163
164    def __init__(self, propertiesFilePath=None):
165        """Connect to Postgres database"""
166        super(DatabaseUserIdentifiers, self).__init__()
167       
168        if _psycopg2Unavailable:
169            log.warning("psycopg2 package is required for this class")
170
171        self.__con = None
172        self.__host = None
173        self.__dbName = None
174        self.__username = None
175        self.__pwd = None
176        self.__generateQuery = None
177        self.__generateFromUserKeyQuery = None
178        self.__getIdentifiersQuery = None
179        self.__uniqIdentifierQuery = None
180        self.__cursor = None
181        self.__db = None
182       
183        if propertiesFilePath is None:
184            raise AttributeError("No Configuration file was set")
185
186        self.readConfigFile(propertiesFilePath)
187
188    def __del__(self):
189        """Close database connection"""
190        self.close()
191
192    def readConfigFile(self, propertiesFilePath):
193        """Read the configuration for the database connection
194
195        @type propertiesFilePath: string
196        @param propertiesFilePath: file path to config file"""
197
198        if not isinstance(propertiesFilePath, basestring):
199            raise TypeError("Input Properties file path "
200                            "must be a valid string.")
201
202        cfg = SafeConfigParser()
203        cfg.read(propertiesFilePath)
204
205        self.__host = cfg.get(
206                        DatabaseUserIdentifiers.SECTION_NAME,
207                        DatabaseUserIdentifiers.HOST_OPTNAME)
208        self.__dbName = cfg.get(
209                        DatabaseUserIdentifiers.SECTION_NAME,
210                        DatabaseUserIdentifiers.DBNAME_OPTNAME)
211        self.__username = cfg.get(
212                        DatabaseUserIdentifiers.SECTION_NAME,
213                        DatabaseUserIdentifiers.USERNAME_OPTNAME)
214        self.__pwd = cfg.get(
215                        DatabaseUserIdentifiers.SECTION_NAME,
216                        DatabaseUserIdentifiers.PWD_OPTNAME)
217
218        self.__generateQuery = cfg.get(
219                        DatabaseUserIdentifiers.SECTION_NAME,
220                        DatabaseUserIdentifiers.QUERY_GENERATE_OPTNAME)
221
222        self.__generateFromUserKeyQuery = cfg.get(
223                    DatabaseUserIdentifiers.SECTION_NAME,
224                    DatabaseUserIdentifiers.QUERY_GENERATE_FROM_USERKEY_OPTNAME)
225
226        self.__getIdentifiersQuery = cfg.get(
227                    DatabaseUserIdentifiers.SECTION_NAME,
228                    DatabaseUserIdentifiers.QUERY_GET_IDENTIFIERS_OPTNAME)
229
230        self.__uniqIdentifierQuery = cfg.get(
231                    DatabaseUserIdentifiers.SECTION_NAME,
232                    DatabaseUserIdentifiers.QUERY_UNIQ_IDENTIFIER_OPTNAME)
233
234    def connect(self,
235                username=None,
236                dbName=None,
237                host=None,
238                pwd=None,
239                prompt="Database password: "):
240        """Connect to database
241
242        Values for keywords omitted are derived from the config file.  If pwd
243        is not in the config file it will be prompted for from stdin
244
245        @type username: string
246        @keyword username: database account username
247        @type dbName: string
248        @keyword dbName: name of database
249        @type host: string
250        @keyword host: database host machine
251        @type pwd: string
252        @keyword pwd: password for database account.  If omitted and not in
253        the config file it will be prompted for from stdin
254        @type prompt: string
255        @keyword prompt: override default password prompt"""
256
257        if not host:
258            host = self.__host
259
260        if not dbName:
261            dbName = self.__dbName
262
263        if not username:
264            username = self.__username
265
266        if not pwd:
267            pwd = self.__pwd
268
269            if not pwd:
270                import getpass
271                pwd = getpass.getpass(prompt)
272
273        try:
274            self.__db = connect("host=%s dbname=%s user=%s password=%s" %
275                                (host, dbName, username, pwd))
276            self.__cursor = self.__db.cursor()
277
278        except NameError, e:
279            raise DatabaseUserIdentifiersConfigError(
280                    "Error accessing connect() function - check that the "
281                    "Postgres Python package psycopg2 is installed; error is: "
282                    "%s" % e)
283
284        except Exception, e:
285            log.error("Error connecting to database \"%s\": %s" % (dbName, e))
286            raise
287
288    def close(self):
289        """Close database connection"""
290        if self.__con:
291            self.__con.close()
292
293    @staticmethod
294    def makeUniqId(ids, idBase, newId=None, counter=1):
295        """Check newId is not already assigned in the ids list,
296        if it is, make a new id based on idBase and the counter.
297        Check this new id and again if it's already assigned try
298        again but incrementing the counter to give a new id.
299        Make recursive calls until a unique id is arrived at
300       
301        @type ids: list
302        @param ids: list of existing OpenID identifiers
303        @type idBase: basestring
304        @param idBase: base string from which to construct a new identifier
305        @type newId: basestring
306        @param newId: candidate identifier to add in
307        @type counter: int
308        @param counter: number to append to an existing identifier in order to
309        make it into a new unique one
310        @rtype: basestring
311        @return: new unique identifier
312        """
313        if newId is None:
314            newId = idBase
315
316        if newId.lower() in ids:
317            newId = "%s%d" % (idBase, counter)
318            newId = DatabaseUserIdentifiers.makeUniqId(ids,
319                                                       idBase,
320                                                       newId=newId,
321                                                       counter=counter+1)
322        return newId
323
324    def generate(self):
325        """Generate a list of OpenID identifiers from the configured
326        query
327        @rtype: list
328        @return: list of username/OpenID user identifiers tuples
329        """
330
331        try:
332            self.connect()
333
334            ids = []
335            lowerCaseIds = []
336            accountIds = []
337            self.__cursor.execute(self.__generateQuery)
338            queryRes = self.__cursor.fetchall()
339
340            # Create OpenID URL snippets
341            for res in queryRes:
342                accountIds += [res[0]]
343                identifier = DatabaseUserIdentifiers.makeIdentifier(*res[1:3])
344                identifier = DatabaseUserIdentifiers.makeUniqId(lowerCaseIds, 
345                                                                identifier)
346                ids += [identifier]
347                lowerCaseIds += [identifier.lower()]
348        finally:
349            self.close()
350
351        # Associate the OpenID snippets with their account IDs
352        return zip(accountIds, ids)
353
354    def getIdentifiers(self):
355        """Get all OpenID identifiers currently held in the database
356        @rtype: list
357        @return: list of OpenID user identifiers
358        """
359        try:
360            self.connect()
361
362            self.__cursor.execute(self.__getIdentifiersQuery)
363            queryRes = self.__cursor.fetchall()
364
365            identifiers = [res[0] for res in queryRes]
366        finally:
367            self.close()
368
369        return identifiers
370
371    def generateFromUserKey(self, userKey):
372        """Generate a single OpenID identifier from a single user key configured
373        in the given query
374        @type userKey: basestring
375        @param userKey: database user table primary key
376        @rtype: string
377        @return: OpenID user identifier, None if user key is not found
378        """
379        queryTmpl = Template(self.__generateFromUserKeyQuery)
380        userKeyKw = {
381            DatabaseUserIdentifiers.USERKEY_FIELD_NAME: userKey
382        }
383        query = queryTmpl.substitute(userKeyKw)
384
385        try:
386            self.connect()
387
388            self.__cursor.execute(query)
389            queryRes = self.__cursor.fetchall()
390        finally:
391            self.close()
392
393        if len(queryRes) == 0:
394            log.debug("No userkey = %r found", userKey)
395            return None
396       
397        # Create OpenID URL snippet
398        res = queryRes[0]
399        identifier = DatabaseUserIdentifiers.makeIdentifier(*res[1:3])
400        existingIdentifiers = [i.lower() for i in self.getIdentifiers()]
401        identifier = DatabaseUserIdentifiers.makeUniqId(existingIdentifiers,
402                                                        identifier)
403
404        # Return the generated identifier
405        return identifier
406
407    def isUniqIdentifier(self, identifier):
408        """Check for the given OpenID User identifier URI snippet in the
409        database
410        @type identifier: basestring
411        @param identifier: OpenID user identifier
412        @rtype: bool
413        @return: True if input identifier is not present in the database
414        """
415        queryTmpl = Template(self.__uniqIdentifierQuery)
416        identKw = {
417            DatabaseUserIdentifiers.UNIQ_IDENTIFIER_FIELD_NAME: identifier
418        }
419        query = queryTmpl.substitute(identKw)
420
421        try:
422            self.connect()
423            self.__cursor.execute(query)
424            queryRes = self.__cursor.fetchall()
425        finally:
426            self.close()
427
428        return len(queryRes) == 0
429
430    def __getCursor(self):
431        """Return a database cursor instance"""
432        return self.__cursor
433
434    cursor = property(fget=__getCursor, doc="database cursor")
435
436
437import optparse
438import sys
439import os
440
441class Main(object):
442    """Wrapper to DatabaseUserIdentifiers class to enable call from the
443    command line
444    """
445
446    MIN_NARGS = 3
447    GENERATE_CMD_NARGS = ''
448    GENERATE_CMD_STR = 'generate-identifiers'
449    GET_IDENTIFIERS_CMD_STR = 'get-identifiers'
450    GENERATE_FROM_USERKEY_CMD_STR = 'generate-identifier-from-userkey'
451    QUERY_IDENTIFIER_CMD_STR = "is-uniq-identifier"
452
453    CMD_STRS = (
454        GENERATE_CMD_STR,
455        GET_IDENTIFIERS_CMD_STR,
456        GENERATE_FROM_USERKEY_CMD_STR,
457        QUERY_IDENTIFIER_CMD_STR
458    )
459    DEBUG_ENVVAR_NAME = 'OPENID_USER_IDENTIFIERS_DEBUG'
460
461    @classmethod
462    def run(cls):
463        """Parse command line arguments and run the query specified"""
464
465        if cls.DEBUG_ENVVAR_NAME in os.environ:
466            import pdb
467            pdb.set_trace()
468
469        parser = optparse.OptionParser()
470        parser.add_option("-c",
471                          "--command",
472                          dest="command",
473                          help="Database query to execute [%s]" %
474                               '|'.join(cls.CMD_STRS))
475
476        parser.add_option("-f",
477                          "--config-file",
478                          dest="configFilePath",
479                          help="database configuration file path")
480
481        parser.add_option("-u",
482                          "--user-identifier",
483                          dest="userIdentifier",
484                          help="user identifier for %s command only" %
485                               cls.QUERY_IDENTIFIER_CMD_STR)
486
487        parser.add_option("-k",
488                          "--user-key",
489                          dest="userKey",
490                          type="string",
491                          help="database user key for %s command only" %
492                               cls.GENERATE_FROM_USERKEY_CMD_STR)
493
494        opt = parser.parse_args()[0]
495
496        if not opt.command:
497            msg = "Error, no command set.\n\n" + parser.format_help()
498            raise SystemExit(msg)
499
500        ids = DatabaseUserIdentifiers(opt.configFilePath)
501
502        if opt.command == cls.GENERATE_CMD_STR:
503            for i in ids.generate():
504                print(i)
505        elif opt.command == cls.GET_IDENTIFIERS_CMD_STR:
506            identifiers = ids.getIdentifiers()
507            print(identifiers)
508
509        elif opt.command == cls.GENERATE_FROM_USERKEY_CMD_STR:
510            if not opt.userKey:
511                msg = "Error, no user key set for query command.\n\n" +\
512                    parser.format_help()
513                raise SystemExit(msg)
514
515            identifier = ids.generateFromUserKey(opt.userKey)
516            if identifier is None:
517                print('')
518            else:
519                print(identifier)
520
521        elif opt.command == cls.QUERY_IDENTIFIER_CMD_STR:
522            if not opt.userIdentifier:
523                msg = "Error, no user identifier set for query command.\n\n" +\
524                    parser.format_help()
525                raise SystemExit(msg)
526
527            isUniqIdent = ids.isUniqIdentifier(opt.userIdentifier)
528            print(isUniqIdent)
529
530            # In line with exit status convention 0 status means it
531            # IS NOT currently allocated, 1 means it is
532            sys.exit(not isUniqIdent)
533        else:
534            msg = "Command %s not recognised.\n\n" % opt.command + \
535                    parser.format_help()
536            raise SystemExit(msg)
537
538if __name__ == "__main__":
539    logging.basicConfig()
540    Main.run()
Note: See TracBrowser for help on using the repository browser.