source: TI12-security/trunk/python/ndg_security_common_utils_openid_dbinterface/ndg/security/common/utils/openid/dbinterface.py @ 5722

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI12-security/trunk/python/ndg_security_common_utils_openid_dbinterface/ndg/security/common/utils/openid/dbinterface.py@5722
Revision 5722, 13.1 KB checked in by pjkersha, 10 years ago (diff)

Added new OpenID Postgres Database interface utility for creating OpenIDs from user name details and querying a database for existing OpenIDs. This is a separate ndg.security.common.utils optional package to help with BADC/NEODC OpenID migration.

  • Property svn:executable set to *
Line 
1#!/usr/bin/env python
2"""NDG Security ndg.security.common.utils.openid package
3
4NERC DataGrid Project
5"""
6__author__ = "P J Kershaw"
7__date__ = "17/09/09"
8__copyright__ = "(C) 2009 Science and Technology Facilities Council"
9__license__ = "BSD - see LICENSE file in top-level directory"
10__contact__ = "Philip.Kershaw@stfc.ac.uk"
11__revision__ = '$Id$'
12import re
13import urllib
14from string import Template
15import logging
16log = logging.getLogger(__name__)
17
18from ConfigParser import SafeConfigParser
19
20# Make an optional import to allow use of the UserIdentifier class alone which
21# has no database dependence
22_psycopg2Unavailable = False
23try:
24    from psycopg2 import connect
25except ImportError:
26    _psycopg2Unavailable = True
27
28
29class UserIdentifiersConfigError(Exception):
30    """Configuration error for UserIdentifiers instance"""
31
32
33class UserIdentifiers(object):
34    """Library of helper methods to parse a set of space delimited first names
35    and a surname string into a suitable OpenID indentifier of the form
36    <firstname>.<surname>
37    """
38
39    @staticmethod
40    def none2String(str):
41        if str is None:
42            return ''
43        else:
44            return str
45
46    # Split surnames allowing for spaces, hyphenation and ' for e.g. O'Connor
47    splitPat = re.compile("[-)('\s]+")
48
49    @staticmethod
50    def sentenceCase(s, preserveCaps=False):
51        if s is None:
52            return ''
53        elif s.isupper() and preserveCaps:
54            return s
55        else:
56            return s[0].upper() + s[1:].lower()
57
58    camelCase = staticmethod(lambda s: ''.join(
59                                    [UserIdentifiers.sentenceCase(i)
60                                     for i in UserIdentifiers.splitPat.split(
61                                     UserIdentifiers.none2String(s.strip()))
62                                     if i]
63                                    )
64                            )
65
66    getFirstName = staticmethod(lambda s: UserIdentifiers.camelCase(
67                                            UserIdentifiers.none2String(s.strip()).split()[0]))
68
69    @staticmethod
70    def mcFilter(s):
71        '''Allow for surnames starting with Mc e.g. McDonald'''
72        if s.startswith('Mc'):
73            return 'Mc' + s[2].upper() + s[3:]
74        else:
75            return s
76
77    convertSurname = staticmethod(lambda s: UserIdentifiers.mcFilter(
78                                                UserIdentifiers.camelCase(s)))
79
80    isUrlSafe = staticmethod(lambda s: urllib.quote(s) == s)
81
82    @staticmethod
83    def makeIdentifier(firstNames, surname):
84        """Make an OpenID identifier based on the users first name joined to
85        their last name with a dot separator"""
86        if not firstNames:
87            return UserIdentifiers.camelCase(surname)
88        else:
89            try:
90                firstName = UserIdentifiers.getFirstName(firstNames)
91            except IndexError:
92                log.warning('Error parsing first name from "%s": using '
93                            'surname only' % firstNames)
94                firstName = ''
95
96            if firstName.endswith('.'):
97                firstName = firstName[:-1]
98
99            try:
100                surname = UserIdentifiers.convertSurname(surname)
101            except IndexError:
102                log.warning('Error parsing surname from "%s": setting to null '
103                            'string' % surname)
104                surname = ''
105
106            if not firstName and not surname:
107                log.warning('Null identifier returned for firstNames="%s" and '
108                            'surname="%s"' % (firstNames, surname))
109                return ''
110
111            # Allow for surname or firstname not set
112            newId = (firstName+'.'+surname).strip('.')
113            if not UserIdentifiers.isUrlSafe(newId):
114                urlSafeId = urllib.quote(newId)
115                log.warning('Changing id "%s" to "%s" to make it URL safe' %
116                            (newId, urlSafeId))
117                return urlSafeId
118            else:
119                return newId
120
121
122class DatabaseUserIdentifiersConfigError(Exception):
123    """Configuration error for DatabaseUserIdentifiers instance"""
124
125class DatabaseUserIdentifiers(UserIdentifiers):
126    """Generate a list of OpenID identifiers from a query
127    to a user database"""
128
129    SECTION_NAME = "DatabaseOpenIDUserIdentifiers"
130    HOST_OPTION_NAME = "host"
131    DBNAME_OPTION_NAME = "dbName"
132    USERNAME_OPTION_NAME = "username"
133    PWD_OPTION_NAME = "pwd"
134    QUERY_GENERATE_OPTION_NAME = "query.generate"
135    QUERY_UNIQ_IDENTIFIER_OPTION_NAME = "query.uniqIdentifier"
136    UNIQ_IDENTIFIER_FIELD_NAME = "openIdUserComponent"
137
138    def __init__(self, propertiesFilePath=None):
139        """Connect to Postgres database"""
140        if _psycopg2Unavailable:
141            log.warning("psycopg2 package is required for this class")
142           
143        self.__con = None
144        self.__host = None
145        self.__dbName = None
146        self.__username = None
147        self.__pwd = None
148
149        if propertiesFilePath is None:
150            raise AttributeError("No Configuration file was set")
151
152        self.readConfigFile(propertiesFilePath)
153
154    def __del__(self):
155        """Close database connection"""
156        self.close()
157
158    def readConfigFile(self, propertiesFilePath):
159        """Read the configuration for the database connection
160
161        @type propertiesFilePath: string
162        @param propertiesFilePath: file path to config file"""
163
164        if not isinstance(propertiesFilePath, basestring):
165            raise TypeError("Input Properties file path "
166                            "must be a valid string.")
167
168        cfg = SafeConfigParser()
169        cfg.read(propertiesFilePath)
170
171        self.__host = cfg.get(
172                        DatabaseUserIdentifiers.SECTION_NAME, 
173                        DatabaseUserIdentifiers.HOST_OPTION_NAME)
174        self.__dbName = cfg.get(
175                        DatabaseUserIdentifiers.SECTION_NAME, 
176                        DatabaseUserIdentifiers.DBNAME_OPTION_NAME)
177        self.__username = cfg.get(
178                        DatabaseUserIdentifiers.SECTION_NAME, 
179                        DatabaseUserIdentifiers.USERNAME_OPTION_NAME)
180        self.__pwd = cfg.get(
181                        DatabaseUserIdentifiers.SECTION_NAME, 
182                        DatabaseUserIdentifiers.PWD_OPTION_NAME)
183
184        self.__generateQuery = cfg.get(
185                        DatabaseUserIdentifiers.SECTION_NAME, 
186                        DatabaseUserIdentifiers.QUERY_GENERATE_OPTION_NAME)
187
188        self.__uniqIdentifierQuery = cfg.get(
189                    DatabaseUserIdentifiers.SECTION_NAME, 
190                    DatabaseUserIdentifiers.QUERY_UNIQ_IDENTIFIER_OPTION_NAME)
191
192    def connect(self,
193                username=None,
194                dbName=None,
195                host=None,
196                pwd=None,
197                prompt="Database password: "):
198        """Connect to database
199
200        Values for keywords omitted are derived from the config file.  If pwd
201        is not in the config file it will be prompted for from stdin
202
203        @type username: string
204        @keyword username: database account username
205        @type dbName: string
206        @keyword dbName: name of database
207        @type host: string
208        @keyword host: database host machine
209        @type pwd: string
210        @keyword pwd: password for database account.  If omitted and not in
211        the config file it will be prompted for from stdin
212        @type prompt: string
213        @keyword prompt: override default password prompt"""
214
215        if not host:
216            host = self.__host
217
218        if not dbName:
219            dbName = self.__dbName
220
221        if not username:
222            username = self.__username
223
224        if not pwd:
225            pwd = self.__pwd
226
227            if not pwd:
228                import getpass
229                pwd = getpass.getpass(prompt=prompt)
230
231        try:
232            self.__db = connect("host=%s dbname=%s user=%s password=%s" % 
233                                (host, dbName, username, pwd))
234            self.__cursor = self.__db.cursor()
235           
236        except NameError, e:
237            raise DatabaseUserIdentifiersConfigError(
238                    "Error accessing connect() function - check that the "
239                    "Postgres Python package psycopg2 is installed; error is: "
240                    "%s" % e)
241           
242        except Exception, e:
243            log.error("Error connecting to database \"%s\": %s" % (dbName, e))
244            raise 
245
246    def close(self):
247        """Close database connection"""
248        if self.__con:
249            self.__con.close()
250
251    @staticmethod
252    def makeUniqId(ids, idBase, newId=None, counter=1):
253        """Check newId is not already assigned in the ids list,
254        if it is, make a new id based on idBase and the counter.
255        Check this new id and again if it's already assigned try
256        again but incrementing the counter to give a new id.
257        Make recursive calls until a unique id is arrived at
258        """
259        if newId is None:
260            newId = idBase
261
262        if newId in ids:
263            newId = "%s%d" % (idBase, counter)
264            newId = DatabaseUserIdentifiers.makeUniqId(ids,
265                                                       idBase,
266                                                       newId=newId,
267                                                       counter=counter+1)
268        return newId
269
270    def generate(self):
271        """Generate a list of OpenID identifiers from the configured
272        query
273        """
274
275        try:
276            self.connect()
277
278            ids = []
279            accountIds = []
280            self.__cursor.execute(self.__generateQuery)
281            queryRes = self.__cursor.fetchall()
282
283            # Create OpenID URL snippets
284            for res in queryRes:
285                accountIds += [res[0]]
286                id = DatabaseUserIdentifiers.makeIdentifier(*res[1:3])
287                id = DatabaseUserIdentifiers.makeUniqId(ids, id)
288                ids += [id]
289        finally:
290            self.close()
291
292        # Associate the OpenID snippets with their account IDs
293        return zip(accountIds, ids)
294
295    def isUniqIdentifier(self, identifier):
296        """Check for the given OpenID User identifier URI snippet in the
297        database
298        """
299        queryTmpl = Template(self.__uniqIdentifierQuery)
300        identKw = {
301            DatabaseUserIdentifiers.UNIQ_IDENTIFIER_FIELD_NAME: identifier
302        }
303        query = queryTmpl.substitute(**identKw)
304
305        try:
306            self.connect()
307            self.__cursor.execute(query)
308            queryRes = self.__cursor.fetchall()
309        finally:
310            self.close()
311
312        return len(queryRes) == 0
313
314    def __getCursor(self):
315        """Return a database cursor instance"""
316        return self.__cursor
317
318    cursor = property(fget=__getCursor, doc="database cursor")
319
320
321import optparse
322import sys
323import os
324
325class Main(object): 
326    """Wrapper to DatabaseUserIdentifiers class to enable call from the
327    command line
328    """
329       
330    MIN_NARGS = 3
331    GENERATE_CMD_NARGS = ''
332    GENERATE_CMD_STR = 'generate'
333    QUERY_IDENTIFIER_CMD_STR = "query-identifier"
334    CMD_STRS = (GENERATE_CMD_STR, QUERY_IDENTIFIER_CMD_STR)
335    DEBUG_ENVVAR_NAME = 'OPENID_USER_IDENTIFIERS_DEBUG'
336   
337    @classmethod
338    def run(cls):
339        """Parse command line arguments and run the query specified"""
340       
341        if cls.DEBUG_ENVVAR_NAME in os.environ:
342            import pdb
343            pdb.set_trace()
344           
345        parser = optparse.OptionParser()
346        parser.add_option("-c",
347                          "--command",
348                          dest="command",
349                          help="Database query to execute [%s]" % 
350                               '|'.join(cls.CMD_STRS))
351   
352        parser.add_option("-f",
353                          "--config-file",
354                          dest="configFilePath",
355                          help="database configuration file path")
356   
357        parser.add_option("-u",
358                          "--user-identifier",
359                          dest="userIdentifier",
360                          help="user identifier for %s command only" % 
361                               cls.QUERY_IDENTIFIER_CMD_STR)
362   
363        (opt, args) = parser.parse_args()
364   
365        if not opt.command:
366            msg = "Error, no command set.\n\n" + parser.print_help()
367            raise SystemExit(msg)
368
369        ids = DatabaseUserIdentifiers(opt.configFilePath)
370
371        if opt.command == cls.GENERATE_CMD_STR:
372            for i in ids.generate():
373                print(i)
374   
375        elif opt.command == cls.QUERY_IDENTIFIER_CMD_STR:
376            if not opt.userIdentifier:
377                msg = "Error, no user identifier set for query command.\n\n" +\
378                    parser.print_help()
379                raise SystemExit(msg)
380           
381            isUniqIdent = ids.isUniqIdentifier(opt.userIdentifier)
382            print(isUniqIdent)
383   
384            # In line with exit status convention 0 status means it
385            # IS NOT currently allocated, 1 means it is
386            sys.exit(not isUniqIdent)
387        else:
388            msg = "Command %s not recognised.\n\n" % opt.command + \
389                    parser.print_help()
390            raise SystemExit(msg)
391       
392if __name__ == "__main__":
393    logging.basicConfig()
394    Main.run()
Note: See TracBrowser for help on using the repository browser.