OpenSecurity/install/web.py-0.37/web/db.py
author om
Mon, 02 Dec 2013 14:02:05 +0100
changeset 3 65432e6c6042
permissions -rwxr-xr-x
initial deployment and project layout commit
"""
Database API
(part of web.py)

Builds SQL statements out of `SQLQuery`/`SQLParam` objects, so values are
sent to the driver as bound parameters, and runs them through the `DB`
wrapper around a driver module (`db_module.connect(**keywords)`).
"""

# Public names exported by `from db import *`.
__all__ = [
  "UnknownParamstyle", "UnknownDB", "TransactionError", 
  "sqllist", "sqlors", "reparam", "sqlquote",
  "SQLQuery", "SQLParam", "sqlparam",
  "SQLLiteral", "sqlliteral",
  "database", 'DB',
]
    13 
    14 import time
    15 try:
    16     import datetime
    17 except ImportError:
    18     datetime = None
    19 
    20 try: set
    21 except NameError:
    22     from sets import Set as set
    23     
    24 from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
    25 
    26 try:
    27     # db module can work independent of web.py
    28     from webapi import debug, config
    29 except:
    30     import sys
    31     debug = sys.stderr
    32     config = storage()
    33 
    34 class UnknownDB(Exception):
    35     """raised for unsupported dbms"""
    36     pass
    37 
    38 class _ItplError(ValueError): 
    39     def __init__(self, text, pos):
    40         ValueError.__init__(self)
    41         self.text = text
    42         self.pos = pos
    43     def __str__(self):
    44         return "unfinished expression in %s at char %d" % (
    45             repr(self.text), self.pos)
    46 
    47 class TransactionError(Exception): pass
    48 
    49 class UnknownParamstyle(Exception): 
    50     """
    51     raised for unsupported db paramstyles
    52 
    53     (currently supported: qmark, numeric, format, pyformat)
    54     """
    55     pass
    56     
    57 class SQLParam(object):
    58     """
    59     Parameter in SQLQuery.
    60     
    61         >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
    62         >>> q
    63         <sql: "SELECT * FROM test WHERE name='joe'">
    64         >>> q.query()
    65         'SELECT * FROM test WHERE name=%s'
    66         >>> q.values()
    67         ['joe']
    68     """
    69     __slots__ = ["value"]
    70 
    71     def __init__(self, value):
    72         self.value = value
    73         
    74     def get_marker(self, paramstyle='pyformat'):
    75         if paramstyle == 'qmark':
    76             return '?'
    77         elif paramstyle == 'numeric':
    78             return ':1'
    79         elif paramstyle is None or paramstyle in ['format', 'pyformat']:
    80             return '%s'
    81         raise UnknownParamstyle, paramstyle
    82         
    83     def sqlquery(self): 
    84         return SQLQuery([self])
    85         
    86     def __add__(self, other):
    87         return self.sqlquery() + other
    88         
    89     def __radd__(self, other):
    90         return other + self.sqlquery() 
    91             
    92     def __str__(self): 
    93         return str(self.value)
    94     
    95     def __repr__(self):
    96         return '<param: %s>' % repr(self.value)
    97 
    98 sqlparam =  SQLParam
    99 
   100 class SQLQuery(object):
   101     """
   102     You can pass this sort of thing as a clause in any db function.
   103     Otherwise, you can pass a dictionary to the keyword argument `vars`
   104     and the function will call reparam for you.
   105 
   106     Internally, consists of `items`, which is a list of strings and
   107     SQLParams, which get concatenated to produce the actual query.
   108     """
   109     __slots__ = ["items"]
   110 
   111     # tested in sqlquote's docstring
   112     def __init__(self, items=None):
   113         r"""Creates a new SQLQuery.
   114         
   115             >>> SQLQuery("x")
   116             <sql: 'x'>
   117             >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
   118             >>> q
   119             <sql: 'SELECT * FROM test WHERE x=1'>
   120             >>> q.query(), q.values()
   121             ('SELECT * FROM test WHERE x=%s', [1])
   122             >>> SQLQuery(SQLParam(1))
   123             <sql: '1'>
   124         """
   125         if items is None:
   126             self.items = []
   127         elif isinstance(items, list):
   128             self.items = items
   129         elif isinstance(items, SQLParam):
   130             self.items = [items]
   131         elif isinstance(items, SQLQuery):
   132             self.items = list(items.items)
   133         else:
   134             self.items = [items]
   135             
   136         # Take care of SQLLiterals
   137         for i, item in enumerate(self.items):
   138             if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral):
   139                 self.items[i] = item.value.v
   140 
   141     def append(self, value):
   142         self.items.append(value)
   143 
   144     def __add__(self, other):
   145         if isinstance(other, basestring):
   146             items = [other]
   147         elif isinstance(other, SQLQuery):
   148             items = other.items
   149         else:
   150             return NotImplemented
   151         return SQLQuery(self.items + items)
   152 
   153     def __radd__(self, other):
   154         if isinstance(other, basestring):
   155             items = [other]
   156         else:
   157             return NotImplemented
   158             
   159         return SQLQuery(items + self.items)
   160 
   161     def __iadd__(self, other):
   162         if isinstance(other, (basestring, SQLParam)):
   163             self.items.append(other)
   164         elif isinstance(other, SQLQuery):
   165             self.items.extend(other.items)
   166         else:
   167             return NotImplemented
   168         return self
   169 
   170     def __len__(self):
   171         return len(self.query())
   172         
   173     def query(self, paramstyle=None):
   174         """
   175         Returns the query part of the sql query.
   176             >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
   177             >>> q.query()
   178             'SELECT * FROM test WHERE name=%s'
   179             >>> q.query(paramstyle='qmark')
   180             'SELECT * FROM test WHERE name=?'
   181         """
   182         s = []
   183         for x in self.items:
   184             if isinstance(x, SQLParam):
   185                 x = x.get_marker(paramstyle)
   186                 s.append(safestr(x))
   187             else:
   188                 x = safestr(x)
   189                 # automatically escape % characters in the query
   190                 # For backward compatability, ignore escaping when the query looks already escaped
   191                 if paramstyle in ['format', 'pyformat']:
   192                     if '%' in x and '%%' not in x:
   193                         x = x.replace('%', '%%')
   194                 s.append(x)
   195         return "".join(s)
   196     
   197     def values(self):
   198         """
   199         Returns the values of the parameters used in the sql query.
   200             >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
   201             >>> q.values()
   202             ['joe']
   203         """
   204         return [i.value for i in self.items if isinstance(i, SQLParam)]
   205         
   206     def join(items, sep=' ', prefix=None, suffix=None, target=None):
   207         """
   208         Joins multiple queries.
   209         
   210         >>> SQLQuery.join(['a', 'b'], ', ')
   211         <sql: 'a, b'>
   212 
   213         Optinally, prefix and suffix arguments can be provided.
   214 
   215         >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
   216         <sql: '(a, b)'>
   217 
   218         If target argument is provided, the items are appended to target instead of creating a new SQLQuery.
   219         """
   220         if target is None:
   221             target = SQLQuery()
   222 
   223         target_items = target.items
   224 
   225         if prefix:
   226             target_items.append(prefix)
   227 
   228         for i, item in enumerate(items):
   229             if i != 0:
   230                 target_items.append(sep)
   231             if isinstance(item, SQLQuery):
   232                 target_items.extend(item.items)
   233             else:
   234                 target_items.append(item)
   235 
   236         if suffix:
   237             target_items.append(suffix)
   238         return target
   239     
   240     join = staticmethod(join)
   241     
   242     def _str(self):
   243         try:
   244             return self.query() % tuple([sqlify(x) for x in self.values()])            
   245         except (ValueError, TypeError):
   246             return self.query()
   247         
   248     def __str__(self):
   249         return safestr(self._str())
   250         
   251     def __unicode__(self):
   252         return safeunicode(self._str())
   253 
   254     def __repr__(self):
   255         return '<sql: %s>' % repr(str(self))
   256 
   257 class SQLLiteral: 
   258     """
   259     Protects a string from `sqlquote`.
   260 
   261         >>> sqlquote('NOW()')
   262         <sql: "'NOW()'">
   263         >>> sqlquote(SQLLiteral('NOW()'))
   264         <sql: 'NOW()'>
   265     """
   266     def __init__(self, v): 
   267         self.v = v
   268 
   269     def __repr__(self): 
   270         return self.v
   271 
   272 sqlliteral = SQLLiteral
   273 
   274 def _sqllist(values):
   275     """
   276         >>> _sqllist([1, 2, 3])
   277         <sql: '(1, 2, 3)'>
   278     """
   279     items = []
   280     items.append('(')
   281     for i, v in enumerate(values):
   282         if i != 0:
   283             items.append(', ')
   284         items.append(sqlparam(v))
   285     items.append(')')
   286     return SQLQuery(items)
   287 
   288 def reparam(string_, dictionary): 
   289     """
   290     Takes a string and a dictionary and interpolates the string
   291     using values from the dictionary. Returns an `SQLQuery` for the result.
   292 
   293         >>> reparam("s = $s", dict(s=True))
   294         <sql: "s = 't'">
   295         >>> reparam("s IN $s", dict(s=[1, 2]))
   296         <sql: 's IN (1, 2)'>
   297     """
   298     dictionary = dictionary.copy() # eval mucks with it
   299     vals = []
   300     result = []
   301     for live, chunk in _interpolate(string_):
   302         if live:
   303             v = eval(chunk, dictionary)
   304             result.append(sqlquote(v))
   305         else: 
   306             result.append(chunk)
   307     return SQLQuery.join(result, '')
   308 
   309 def sqlify(obj): 
   310     """
   311     converts `obj` to its proper SQL version
   312 
   313         >>> sqlify(None)
   314         'NULL'
   315         >>> sqlify(True)
   316         "'t'"
   317         >>> sqlify(3)
   318         '3'
   319     """
   320     # because `1 == True and hash(1) == hash(True)`
   321     # we have to do this the hard way...
   322 
   323     if obj is None:
   324         return 'NULL'
   325     elif obj is True:
   326         return "'t'"
   327     elif obj is False:
   328         return "'f'"
   329     elif datetime and isinstance(obj, datetime.datetime):
   330         return repr(obj.isoformat())
   331     else:
   332         if isinstance(obj, unicode): obj = obj.encode('utf8')
   333         return repr(obj)
   334 
   335 def sqllist(lst): 
   336     """
   337     Converts the arguments for use in something like a WHERE clause.
   338     
   339         >>> sqllist(['a', 'b'])
   340         'a, b'
   341         >>> sqllist('a')
   342         'a'
   343         >>> sqllist(u'abc')
   344         u'abc'
   345     """
   346     if isinstance(lst, basestring): 
   347         return lst
   348     else:
   349         return ', '.join(lst)
   350 
   351 def sqlors(left, lst):
   352     """
   353     `left is a SQL clause like `tablename.arg = ` 
   354     and `lst` is a list of values. Returns a reparam-style
   355     pair featuring the SQL that ORs together the clause
   356     for each item in the lst.
   357 
   358         >>> sqlors('foo = ', [])
   359         <sql: '1=2'>
   360         >>> sqlors('foo = ', [1])
   361         <sql: 'foo = 1'>
   362         >>> sqlors('foo = ', 1)
   363         <sql: 'foo = 1'>
   364         >>> sqlors('foo = ', [1,2,3])
   365         <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
   366     """
   367     if isinstance(lst, iters):
   368         lst = list(lst)
   369         ln = len(lst)
   370         if ln == 0:
   371             return SQLQuery("1=2")
   372         if ln == 1:
   373             lst = lst[0]
   374 
   375     if isinstance(lst, iters):
   376         return SQLQuery(['('] + 
   377           sum([[left, sqlparam(x), ' OR '] for x in lst], []) +
   378           ['1=2)']
   379         )
   380     else:
   381         return left + sqlparam(lst)
   382         
   383 def sqlwhere(dictionary, grouping=' AND '): 
   384     """
   385     Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.
   386     
   387         >>> sqlwhere({'cust_id': 2, 'order_id':3})
   388         <sql: 'order_id = 3 AND cust_id = 2'>
   389         >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
   390         <sql: 'order_id = 3, cust_id = 2'>
   391         >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
   392         'a = %s AND b = %s'
   393     """
   394     return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping)
   395 
   396 def sqlquote(a): 
   397     """
   398     Ensures `a` is quoted properly for use in a SQL query.
   399 
   400         >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
   401         <sql: "WHERE x = 't' AND y = 3">
   402         >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
   403         <sql: "WHERE x = 't' AND y IN (2, 3)">
   404     """
   405     if isinstance(a, list):
   406         return _sqllist(a)
   407     else:
   408         return sqlparam(a).sqlquery()
   409 
class Transaction:
    """Database transaction.

    Creating a Transaction starts it immediately and pushes it onto
    `ctx.transactions`. The outermost transaction works on the connection
    itself; nested ones use SAVEPOINTs named ``webpy_sp_<depth>``, or do
    nothing when ``ctx.get('ignore_nested_transactions')`` is true.
    As a context manager it commits on normal exit and rolls back when the
    block raises (the exception still propagates).
    """
    def __init__(self, ctx):
        self.ctx = ctx
        # number of transactions already open = this transaction's depth;
        # also captured by the engine closures below.
        self.transaction_count = transaction_count = len(ctx.transactions)

        class transaction_engine:
            """Transaction Engine used in top level transactions."""
            def do_transact(self):
                # commit anything pending but keep the connection loaded
                ctx.commit(unload=False)

            def do_commit(self):
                ctx.commit()

            def do_rollback(self):
                ctx.rollback()

        class subtransaction_engine:
            """Transaction Engine used in sub transactions."""
            def query(self, q):
                # q contains '%s', filled with this transaction's depth
                db_cursor = ctx.db.cursor()
                ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))

            def do_transact(self):
                self.query('SAVEPOINT webpy_sp_%s')

            def do_commit(self):
                self.query('RELEASE SAVEPOINT webpy_sp_%s')

            def do_rollback(self):
                self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')

        class dummy_engine:
            """Transaction Engine used instead of subtransaction_engine 
            when sub transactions are not supported."""
            do_transact = do_commit = do_rollback = lambda self: None

        if self.transaction_count:
            # nested transactions are not supported in some databases
            if self.ctx.get('ignore_nested_transactions'):
                self.engine = dummy_engine()
            else:
                self.engine = subtransaction_engine()
        else:
            self.engine = transaction_engine()

        # start first, then register on the stack
        self.engine.do_transact()
        self.ctx.transactions.append(self)

    def __enter__(self):
        return self

    def __exit__(self, exctype, excvalue, traceback):
        # returns None, so any exception raised in the block propagates
        if exctype is not None:
            self.rollback()
        else:
            self.commit()

    def commit(self):
        """Commits this transaction if it is still open.

        It is a no-op if this transaction was already closed, for example
        by an enclosing transaction's rollback. Otherwise the stack is
        trimmed back to this transaction's depth, discarding any nested
        transactions that are still open.
        """
        if len(self.ctx.transactions) > self.transaction_count:
            self.engine.do_commit()
            self.ctx.transactions = self.ctx.transactions[:self.transaction_count]

    def rollback(self):
        """Rolls back this transaction if it is still open.

        Like `commit`, it is a no-op when already closed, and it trims the
        transaction stack back to this transaction's depth.
        """
        if len(self.ctx.transactions) > self.transaction_count:
            self.engine.do_rollback()
            self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
   477 
   478 class DB: 
   479     """Database"""
   480     def __init__(self, db_module, keywords):
   481         """Creates a database.
   482         """
   483         # some DB implementaions take optional paramater `driver` to use a specific driver modue
   484         # but it should not be passed to connect
   485         keywords.pop('driver', None)
   486 
   487         self.db_module = db_module
   488         self.keywords = keywords
   489 
   490         self._ctx = threadeddict()
   491         # flag to enable/disable printing queries
   492         self.printing = config.get('debug_sql', config.get('debug', False))
   493         self.supports_multiple_insert = False
   494         
   495         try:
   496             import DBUtils
   497             # enable pooling if DBUtils module is available.
   498             self.has_pooling = True
   499         except ImportError:
   500             self.has_pooling = False
   501             
   502         # Pooling can be disabled by passing pooling=False in the keywords.
   503         self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling
   504             
   505     def _getctx(self): 
   506         if not self._ctx.get('db'):
   507             self._load_context(self._ctx)
   508         return self._ctx
   509     ctx = property(_getctx)
   510     
   511     def _load_context(self, ctx):
   512         ctx.dbq_count = 0
   513         ctx.transactions = [] # stack of transactions
   514         
   515         if self.has_pooling:
   516             ctx.db = self._connect_with_pooling(self.keywords)
   517         else:
   518             ctx.db = self._connect(self.keywords)
   519         ctx.db_execute = self._db_execute
   520         
   521         if not hasattr(ctx.db, 'commit'):
   522             ctx.db.commit = lambda: None
   523 
   524         if not hasattr(ctx.db, 'rollback'):
   525             ctx.db.rollback = lambda: None
   526             
   527         def commit(unload=True):
   528             # do db commit and release the connection if pooling is enabled.            
   529             ctx.db.commit()
   530             if unload and self.has_pooling:
   531                 self._unload_context(self._ctx)
   532                 
   533         def rollback():
   534             # do db rollback and release the connection if pooling is enabled.
   535             ctx.db.rollback()
   536             if self.has_pooling:
   537                 self._unload_context(self._ctx)
   538                 
   539         ctx.commit = commit
   540         ctx.rollback = rollback
   541             
   542     def _unload_context(self, ctx):
   543         del ctx.db
   544             
   545     def _connect(self, keywords):
   546         return self.db_module.connect(**keywords)
   547         
   548     def _connect_with_pooling(self, keywords):
   549         def get_pooled_db():
   550             from DBUtils import PooledDB
   551 
   552             # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
   553             # see Bug#122112
   554             
   555             if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
   556                 return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
   557             else:
   558                 return PooledDB.PooledDB(creator=self.db_module, **keywords)
   559         
   560         if getattr(self, '_pooleddb', None) is None:
   561             self._pooleddb = get_pooled_db()
   562         
   563         return self._pooleddb.connection()
   564         
   565     def _db_cursor(self):
   566         return self.ctx.db.cursor()
   567 
   568     def _param_marker(self):
   569         """Returns parameter marker based on paramstyle attribute if this database."""
   570         style = getattr(self, 'paramstyle', 'pyformat')
   571 
   572         if style == 'qmark':
   573             return '?'
   574         elif style == 'numeric':
   575             return ':1'
   576         elif style in ['format', 'pyformat']:
   577             return '%s'
   578         raise UnknownParamstyle, style
   579 
    def _db_execute(self, cur, sql_query): 
        """executes an sql query

        Increments `ctx.dbq_count`, converts `sql_query` with
        `_process_query` and runs it on cursor `cur`, returning whatever
        `cur.execute` returns. On any exception (bare except, so this
        includes KeyboardInterrupt) it rolls back the innermost open
        transaction, or the connection when none is open, and re-raises.
        When `self.printing` is set, the query (with its elapsed time on
        success) is printed to `debug`.
        """
        # `self.ctx` is a property that reconnects when no connection is
        # loaded, so it is re-read on every use below.
        self.ctx.dbq_count += 1
        
        try:
            a = time.time()
            query, params = self._process_query(sql_query)
            out = cur.execute(query, params)
            b = time.time()
        except:
            if self.printing:
                print >> debug, 'ERR:', str(sql_query)
            # NOTE(review): if the rollback itself raises, that error replaces
            # the original exception.
            if self.ctx.transactions:
                self.ctx.transactions[-1].rollback()
            else:
                self.ctx.rollback()
            raise

        if self.printing:
            print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
        return out
   601 
   602     def _process_query(self, sql_query):
   603         """Takes the SQLQuery object and returns query string and parameters.
   604         """
   605         paramstyle = getattr(self, 'paramstyle', 'pyformat')
   606         query = sql_query.query(paramstyle)
   607         params = sql_query.values()
   608         return query, params
   609     
   610     def _where(self, where, vars): 
   611         if isinstance(where, (int, long)):
   612             where = "id = " + sqlparam(where)
   613         #@@@ for backward-compatibility
   614         elif isinstance(where, (list, tuple)) and len(where) == 2:
   615             where = SQLQuery(where[0], where[1])
   616         elif isinstance(where, SQLQuery):
   617             pass
   618         else:
   619             where = reparam(where, vars)        
   620         return where
   621     
    def query(self, sql_query, vars=None, processed=False, _test=False): 
        """
        Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
        If `processed=True`, or `sql_query` is already an `SQLQuery`, it is
        used as-is and `vars` is ignored. With `_test=True` the (interpolated)
        query is returned without being executed.

        When the statement produces a result set (the cursor has a
        `description`), returns a lazy iterator of `storage` rows keyed by
        column name. Its `__len__` reports `cursor.rowcount` and its `.list()`
        fetches the remaining rows with `fetchall()`. Otherwise it returns
        `cursor.rowcount`. Commits immediately unless a transaction is open.
        
            >>> db = DB(None, {})
            >>> db.query("SELECT * FROM foo", _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
            >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
        """
        if vars is None: vars = {}
        
        if not processed and not isinstance(sql_query, SQLQuery):
            sql_query = reparam(sql_query, vars)
        
        if _test: return sql_query
        
        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, sql_query)
        
        if db_cursor.description:
            names = [x[0] for x in db_cursor.description]
            # rows are fetched lazily, one per iteration step
            def iterwrapper():
                row = db_cursor.fetchone()
                while row:
                    yield storage(dict(zip(names, row)))
                    row = db_cursor.fetchone()
            out = iterbetter(iterwrapper())
            out.__len__ = lambda: int(db_cursor.rowcount)
            out.list = lambda: [storage(dict(zip(names, x))) \
                               for x in db_cursor.fetchall()]
        else:
            out = db_cursor.rowcount
        
        # NOTE(review): this commit happens before the rows above are
        # iterated; with pooling, ctx.commit() also unloads the connection
        # while the cursor is still in use -- confirm this is safe.
        if not self.ctx.transactions: 
            self.ctx.commit()
        return out
   663     
   664     def select(self, tables, vars=None, what='*', where=None, order=None, group=None, 
   665                limit=None, offset=None, _test=False): 
   666         """
   667         Selects `what` from `tables` with clauses `where`, `order`, 
   668         `group`, `limit`, and `offset`. Uses vars to interpolate. 
   669         Otherwise, each clause can be a SQLQuery.
   670         
   671             >>> db = DB(None, {})
   672             >>> db.select('foo', _test=True)
   673             <sql: 'SELECT * FROM foo'>
   674             >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
   675             <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
   676         """
   677         if vars is None: vars = {}
   678         sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset)
   679         clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None]
   680         qout = SQLQuery.join(clauses)
   681         if _test: return qout
   682         return self.query(qout, processed=True)
   683     
   684     def where(self, table, what='*', order=None, group=None, limit=None, 
   685               offset=None, _test=False, **kwargs):
   686         """
   687         Selects from `table` where keys are equal to values in `kwargs`.
   688         
   689             >>> db = DB(None, {})
   690             >>> db.where('foo', bar_id=3, _test=True)
   691             <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
   692             >>> db.where('foo', source=2, crust='dewey', _test=True)
   693             <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
   694             >>> db.where('foo', _test=True)
   695             <sql: 'SELECT * FROM foo'>
   696         """
   697         where_clauses = []
   698         for k, v in kwargs.iteritems():
   699             where_clauses.append(k + ' = ' + sqlquote(v))
   700             
   701         if where_clauses:
   702             where = SQLQuery.join(where_clauses, " AND ")
   703         else:
   704             where = None
   705             
   706         return self.select(table, what=what, order=order, 
   707                group=group, limit=limit, offset=offset, _test=_test, 
   708                where=where)
   709     
   710     def sql_clauses(self, what, tables, where, group, order, limit, offset): 
   711         return (
   712             ('SELECT', what),
   713             ('FROM', sqllist(tables)),
   714             ('WHERE', where),
   715             ('GROUP BY', group),
   716             ('ORDER BY', order),
   717             ('LIMIT', limit),
   718             ('OFFSET', offset))
   719     
   720     def gen_clause(self, sql, val, vars): 
   721         if isinstance(val, (int, long)):
   722             if sql == 'WHERE':
   723                 nout = 'id = ' + sqlquote(val)
   724             else:
   725                 nout = SQLQuery(val)
   726         #@@@
   727         elif isinstance(val, (list, tuple)) and len(val) == 2:
   728             nout = SQLQuery(val[0], val[1]) # backwards-compatibility
   729         elif isinstance(val, SQLQuery):
   730             nout = val
   731         else:
   732             nout = reparam(val, vars)
   733 
   734         def xjoin(a, b):
   735             if a and b: return a + ' ' + b
   736             else: return a or b
   737 
   738         return xjoin(sql, nout)
   739 
    def insert(self, tablename, seqname=None, _test=False, **values): 
        """
        Inserts `values` into `tablename`. Returns current sequence ID.
        Set `seqname` to the ID if it's not the default, or to `False`
        if there isn't one.

        `tablename` and the column names are written into the SQL text
        as-is; only the column values become bound parameters. With no
        `values`, the statement from `_get_insert_default_values_query` is
        used. `_test=True` returns the query without executing it.

        Unless `seqname` is False, `_process_insert_query` may rewrite the
        statement or return an (insert_query, id_query) pair, which are
        executed in order. The id is read with `fetchone()[0]`, and None is
        returned if that fails. Commits immediately unless a transaction
        is open.
        
            >>> db = DB(None, {})
            >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
            >>> q
            <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
            >>> q.query()
            'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
            >>> q.values()
            [2, 'bob']
        """
        def q(x): return "(" + x + ")"
        
        if values:
            # keys() and values() of the same unmodified dict line up
            _keys = SQLQuery.join(values.keys(), ', ')
            _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
            sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
        else:
            sql_query = SQLQuery(self._get_insert_default_values_query(tablename))

        if _test: return sql_query
        
        db_cursor = self._db_cursor()
        if seqname is not False: 
            sql_query = self._process_insert_query(sql_query, tablename, seqname)

        if isinstance(sql_query, tuple):
            # for some databases, a separate query has to be made to find 
            # the id of the inserted row.
            q1, q2 = sql_query
            self._db_execute(db_cursor, q1)
            self._db_execute(db_cursor, q2)
        else:
            self._db_execute(db_cursor, sql_query)

        # best-effort: not every statement/driver yields an id row
        try: 
            out = db_cursor.fetchone()[0]
        except Exception: 
            out = None
        
        if not self.ctx.transactions: 
            self.ctx.commit()
        return out
   787         
   788     def _get_insert_default_values_query(self, table):
   789         return "INSERT INTO %s DEFAULT VALUES" % table
   790 
   791     def multiple_insert(self, tablename, values, seqname=None, _test=False):
   792         """
   793         Inserts multiple rows into `tablename`. The `values` must be a list of dictioanries, 
   794         one for each row to be inserted, each with the same set of keys.
   795         Returns the list of ids of the inserted rows.        
   796         Set `seqname` to the ID if it's not the default, or to `False`
   797         if there isn't one.
   798         
   799             >>> db = DB(None, {})
   800             >>> db.supports_multiple_insert = True
   801             >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
   802             >>> db.multiple_insert('person', values=values, _test=True)
   803             <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
   804         """        
   805         if not values:
   806             return []
   807             
   808         if not self.supports_multiple_insert:
   809             out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
   810             if seqname is False:
   811                 return None
   812             else:
   813                 return out
   814                 
   815         keys = values[0].keys()
   816         #@@ make sure all keys are valid
   817 
   818         # make sure all rows have same keys.
   819         for v in values:
   820             if v.keys() != keys:
   821                 raise ValueError, 'Bad data'
   822 
   823         sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))
   824 
   825         for i, row in enumerate(values):
   826             if i != 0:
   827                 sql_query.append(", ")
   828             SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")
   829         
   830         if _test: return sql_query
   831 
   832         db_cursor = self._db_cursor()
   833         if seqname is not False: 
   834             sql_query = self._process_insert_query(sql_query, tablename, seqname)
   835 
   836         if isinstance(sql_query, tuple):
   837             # for some databases, a separate query has to be made to find 
   838             # the id of the inserted row.
   839             q1, q2 = sql_query
   840             self._db_execute(db_cursor, q1)
   841             self._db_execute(db_cursor, q2)
   842         else:
   843             self._db_execute(db_cursor, sql_query)
   844 
   845         try: 
   846             out = db_cursor.fetchone()[0]
   847             out = range(out-len(values)+1, out+1)        
   848         except Exception: 
   849             out = None
   850 
   851         if not self.ctx.transactions: 
   852             self.ctx.commit()
   853         return out
   854 
   855     
   856     def update(self, tables, where, vars=None, _test=False, **values): 
   857         """
   858         Update `tables` with clause `where` (interpolated using `vars`)
   859         and setting `values`.
   860 
   861             >>> db = DB(None, {})
   862             >>> name = 'Joseph'
   863             >>> q = db.update('foo', where='name = $name', name='bob', age=2,
   864             ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
   865             >>> q
   866             <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
   867             >>> q.query()
   868             'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
   869             >>> q.values()
   870             [2, 'bob', 'Joseph']
   871         """
   872         if vars is None: vars = {}
   873         where = self._where(where, vars)
   874 
   875         query = (
   876           "UPDATE " + sqllist(tables) + 
   877           " SET " + sqlwhere(values, ', ') + 
   878           " WHERE " + where)
   879 
   880         if _test: return query
   881         
   882         db_cursor = self._db_cursor()
   883         self._db_execute(db_cursor, query)
   884         if not self.ctx.transactions: 
   885             self.ctx.commit()
   886         return db_cursor.rowcount
   887     
   888     def delete(self, table, where, using=None, vars=None, _test=False): 
   889         """
   890         Deletes from `table` with clauses `where` and `using`.
   891 
   892             >>> db = DB(None, {})
   893             >>> name = 'Joe'
   894             >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
   895             <sql: "DELETE FROM foo WHERE name = 'Joe'">
   896         """
   897         if vars is None: vars = {}
   898         where = self._where(where, vars)
   899 
   900         q = 'DELETE FROM ' + table
   901         if using: q += ' USING ' + sqllist(using)
   902         if where: q += ' WHERE ' + where
   903 
   904         if _test: return q
   905 
   906         db_cursor = self._db_cursor()
   907         self._db_execute(db_cursor, q)
   908         if not self.ctx.transactions: 
   909             self.ctx.commit()
   910         return db_cursor.rowcount
   911 
   912     def _process_insert_query(self, query, tablename, seqname):
   913         return query
   914 
   915     def transaction(self): 
   916         """Start a transaction."""
   917         return Transaction(self.ctx)
   918     
   919 class PostgresDB(DB): 
   920     """Postgres driver."""
   921     def __init__(self, **keywords):
   922         if 'pw' in keywords:
   923             keywords['password'] = keywords.pop('pw')
   924             
   925         db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None))
   926         if db_module.__name__ == "psycopg2":
   927             import psycopg2.extensions
   928             psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
   929 
   930         # if db is not provided postgres driver will take it from PGDATABASE environment variable
   931         if 'db' in keywords:
   932             keywords['database'] = keywords.pop('db')
   933         
   934         self.dbname = "postgres"
   935         self.paramstyle = db_module.paramstyle
   936         DB.__init__(self, db_module, keywords)
   937         self.supports_multiple_insert = True
   938         self._sequences = None
   939         
   940     def _process_insert_query(self, query, tablename, seqname):
   941         if seqname is None:
   942             # when seqname is not provided guess the seqname and make sure it exists
   943             seqname = tablename + "_id_seq"
   944             if seqname not in self._get_all_sequences():
   945                 seqname = None
   946         
   947         if seqname:
   948             query += "; SELECT currval('%s')" % seqname
   949             
   950         return query
   951     
   952     def _get_all_sequences(self):
   953         """Query postgres to find names of all sequences used in this database."""
   954         if self._sequences is None:
   955             q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
   956             self._sequences = set([c.relname for c in self.query(q)])
   957         return self._sequences
   958 
   959     def _connect(self, keywords):
   960         conn = DB._connect(self, keywords)
   961         try:
   962             conn.set_client_encoding('UTF8')
   963         except AttributeError:
   964             # fallback for pgdb driver
   965             conn.cursor().execute("set client_encoding to 'UTF-8'")
   966         return conn
   967         
   968     def _connect_with_pooling(self, keywords):
   969         conn = DB._connect_with_pooling(self, keywords)
   970         conn._con._con.set_client_encoding('UTF8')
   971         return conn
   972 
   973 class MySQLDB(DB): 
   974     def __init__(self, **keywords):
   975         import MySQLdb as db
   976         if 'pw' in keywords:
   977             keywords['passwd'] = keywords['pw']
   978             del keywords['pw']
   979 
   980         if 'charset' not in keywords:
   981             keywords['charset'] = 'utf8'
   982         elif keywords['charset'] is None:
   983             del keywords['charset']
   984 
   985         self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg
   986         self.dbname = "mysql"
   987         DB.__init__(self, db, keywords)
   988         self.supports_multiple_insert = True
   989         
   990     def _process_insert_query(self, query, tablename, seqname):
   991         return query, SQLQuery('SELECT last_insert_id();')
   992         
   993     def _get_insert_default_values_query(self, table):
   994         return "INSERT INTO %s () VALUES()" % table
   995 
   996 def import_driver(drivers, preferred=None):
   997     """Import the first available driver or preferred driver.
   998     """
   999     if preferred:
  1000         drivers = [preferred]
  1001 
  1002     for d in drivers:
  1003         try:
  1004             return __import__(d, None, None, ['x'])
  1005         except ImportError:
  1006             pass
  1007     raise ImportError("Unable to import " + " or ".join(drivers))
  1008 
  1009 class SqliteDB(DB): 
  1010     def __init__(self, **keywords):
  1011         db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))
  1012 
  1013         if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
  1014             db.paramstyle = 'qmark'
  1015             
  1016         # sqlite driver doesn't create datatime objects for timestamp columns unless `detect_types` option is passed.
  1017         # It seems to be supported in sqlite3 and pysqlite2 drivers, not surte about sqlite.
  1018         keywords.setdefault('detect_types', db.PARSE_DECLTYPES)
  1019 
  1020         self.paramstyle = db.paramstyle
  1021         keywords['database'] = keywords.pop('db')
  1022         keywords['pooling'] = False # sqlite don't allows connections to be shared by threads
  1023         self.dbname = "sqlite"        
  1024         DB.__init__(self, db, keywords)
  1025 
  1026     def _process_insert_query(self, query, tablename, seqname):
  1027         return query, SQLQuery('SELECT last_insert_rowid();')
  1028     
  1029     def query(self, *a, **kw):
  1030         out = DB.query(self, *a, **kw)
  1031         if isinstance(out, iterbetter):
  1032             del out.__len__
  1033         return out
  1034 
  1035 class FirebirdDB(DB):
  1036     """Firebird Database.
  1037     """
  1038     def __init__(self, **keywords):
  1039         try:
  1040             import kinterbasdb as db
  1041         except Exception:
  1042             db = None
  1043             pass
  1044         if 'pw' in keywords:
  1045             keywords['passwd'] = keywords['pw']
  1046             del keywords['pw']
  1047         keywords['database'] = keywords['db']
  1048         del keywords['db']
  1049         DB.__init__(self, db, keywords)
  1050         
  1051     def delete(self, table, where=None, using=None, vars=None, _test=False):
  1052         # firebird doesn't support using clause
  1053         using=None
  1054         return DB.delete(self, table, where, using, vars, _test)
  1055 
  1056     def sql_clauses(self, what, tables, where, group, order, limit, offset):
  1057         return (
  1058             ('SELECT', ''),
  1059             ('FIRST', limit),
  1060             ('SKIP', offset),
  1061             ('', what),
  1062             ('FROM', sqllist(tables)),
  1063             ('WHERE', where),
  1064             ('GROUP BY', group),
  1065             ('ORDER BY', order)
  1066         )
  1067 
  1068 class MSSQLDB(DB):
  1069     def __init__(self, **keywords):
  1070         import pymssql as db    
  1071         if 'pw' in keywords:
  1072             keywords['password'] = keywords.pop('pw')
  1073         keywords['database'] = keywords.pop('db')
  1074         self.dbname = "mssql"
  1075         DB.__init__(self, db, keywords)
  1076 
  1077     def _process_query(self, sql_query):
  1078         """Takes the SQLQuery object and returns query string and parameters.
  1079         """
  1080         # MSSQLDB expects params to be a tuple. 
  1081         # Overwriting the default implementation to convert params to tuple.
  1082         paramstyle = getattr(self, 'paramstyle', 'pyformat')
  1083         query = sql_query.query(paramstyle)
  1084         params = sql_query.values()
  1085         return query, tuple(params)
  1086 
  1087     def sql_clauses(self, what, tables, where, group, order, limit, offset): 
  1088         return (
  1089             ('SELECT', what),
  1090             ('TOP', limit),
  1091             ('FROM', sqllist(tables)),
  1092             ('WHERE', where),
  1093             ('GROUP BY', group),
  1094             ('ORDER BY', order),
  1095             ('OFFSET', offset))
  1096             
  1097     def _test(self):
  1098         """Test LIMIT.
  1099 
  1100             Fake presence of pymssql module for running tests.
  1101             >>> import sys
  1102             >>> sys.modules['pymssql'] = sys.modules['sys']
  1103             
  1104             MSSQL has TOP clause instead of LIMIT clause.
  1105             >>> db = MSSQLDB(db='test', user='joe', pw='secret')
  1106             >>> db.select('foo', limit=4, _test=True)
  1107             <sql: 'SELECT * TOP 4 FROM foo'>
  1108         """
  1109         pass
  1110 
  1111 class OracleDB(DB): 
  1112     def __init__(self, **keywords): 
  1113         import cx_Oracle as db 
  1114         if 'pw' in keywords: 
  1115             keywords['password'] = keywords.pop('pw') 
  1116 
  1117         #@@ TODO: use db.makedsn if host, port is specified 
  1118         keywords['dsn'] = keywords.pop('db') 
  1119         self.dbname = 'oracle' 
  1120         db.paramstyle = 'numeric' 
  1121         self.paramstyle = db.paramstyle
  1122 
  1123         # oracle doesn't support pooling 
  1124         keywords.pop('pooling', None) 
  1125         DB.__init__(self, db, keywords) 
  1126 
  1127     def _process_insert_query(self, query, tablename, seqname): 
  1128         if seqname is None: 
  1129             # It is not possible to get seq name from table name in Oracle
  1130             return query
  1131         else:
  1132             return query + "; SELECT %s.currval FROM dual" % seqname 
  1133 
  1134 _databases = {}
  1135 def database(dburl=None, **params):
  1136     """Creates appropriate database using params.
  1137     
  1138     Pooling will be enabled if DBUtils module is available. 
  1139     Pooling can be disabled by passing pooling=False in params.
  1140     """
  1141     dbn = params.pop('dbn')
  1142     if dbn in _databases:
  1143         return _databases[dbn](**params)
  1144     else:
  1145         raise UnknownDB, dbn
  1146 
  1147 def register_database(name, clazz):
  1148     """
  1149     Register a database.
  1150 
  1151         >>> class LegacyDB(DB): 
  1152         ...     def __init__(self, **params): 
  1153         ...        pass 
  1154         ...
  1155         >>> register_database('legacy', LegacyDB)
  1156         >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret') 
  1157     """
  1158     _databases[name] = clazz
  1159 
  1160 register_database('mysql', MySQLDB)
  1161 register_database('postgres', PostgresDB)
  1162 register_database('sqlite', SqliteDB)
  1163 register_database('firebird', FirebirdDB)
  1164 register_database('mssql', MSSQLDB)
  1165 register_database('oracle', OracleDB)
  1166 
  1167 def _interpolate(format): 
  1168     """
  1169     Takes a format string and returns a list of 2-tuples of the form
  1170     (boolean, string) where boolean says whether string should be evaled
  1171     or not.
  1172 
  1173     from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
  1174     """
  1175     from tokenize import tokenprog
  1176 
  1177     def matchorfail(text, pos):
  1178         match = tokenprog.match(text, pos)
  1179         if match is None:
  1180             raise _ItplError(text, pos)
  1181         return match, match.end()
  1182 
  1183     namechars = "abcdefghijklmnopqrstuvwxyz" \
  1184         "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
  1185     chunks = []
  1186     pos = 0
  1187 
  1188     while 1:
  1189         dollar = format.find("$", pos)
  1190         if dollar < 0: 
  1191             break
  1192         nextchar = format[dollar + 1]
  1193 
  1194         if nextchar == "{":
  1195             chunks.append((0, format[pos:dollar]))
  1196             pos, level = dollar + 2, 1
  1197             while level:
  1198                 match, pos = matchorfail(format, pos)
  1199                 tstart, tend = match.regs[3]
  1200                 token = format[tstart:tend]
  1201                 if token == "{": 
  1202                     level = level + 1
  1203                 elif token == "}":  
  1204                     level = level - 1
  1205             chunks.append((1, format[dollar + 2:pos - 1]))
  1206 
  1207         elif nextchar in namechars:
  1208             chunks.append((0, format[pos:dollar]))
  1209             match, pos = matchorfail(format, dollar + 1)
  1210             while pos < len(format):
  1211                 if format[pos] == "." and \
  1212                     pos + 1 < len(format) and format[pos + 1] in namechars:
  1213                     match, pos = matchorfail(format, pos + 1)
  1214                 elif format[pos] in "([":
  1215                     pos, level = pos + 1, 1
  1216                     while level:
  1217                         match, pos = matchorfail(format, pos)
  1218                         tstart, tend = match.regs[3]
  1219                         token = format[tstart:tend]
  1220                         if token[0] in "([": 
  1221                             level = level + 1
  1222                         elif token[0] in ")]":  
  1223                             level = level - 1
  1224                 else: 
  1225                     break
  1226             chunks.append((1, format[dollar + 1:pos]))
  1227         else:
  1228             chunks.append((0, format[pos:dollar + 1]))
  1229             pos = dollar + 1 + (nextchar == "$")
  1230 
  1231     if pos < len(format): 
  1232         chunks.append((0, format[pos:]))
  1233     return chunks
  1234 
  1235 if __name__ == "__main__":
  1236     import doctest
  1237     doctest.testmod()