OpenSecurity/install/web.py-0.37/build/lib/web/db.py
author om
Mon, 02 Dec 2013 14:02:05 +0100
changeset 3 65432e6c6042
permissions -rwxr-xr-x
initial deployment and project layout commit
om@3
     1
"""
om@3
     2
Database API
om@3
     3
(part of web.py)
om@3
     4
"""
om@3
     5
om@3
     6
__all__ = [
om@3
     7
  "UnknownParamstyle", "UnknownDB", "TransactionError", 
om@3
     8
  "sqllist", "sqlors", "reparam", "sqlquote",
om@3
     9
  "SQLQuery", "SQLParam", "sqlparam",
om@3
    10
  "SQLLiteral", "sqlliteral",
om@3
    11
  "database", 'DB',
om@3
    12
]
om@3
    13
om@3
    14
import time
om@3
    15
try:
om@3
    16
    import datetime
om@3
    17
except ImportError:
om@3
    18
    datetime = None
om@3
    19
om@3
    20
try: set
om@3
    21
except NameError:
om@3
    22
    from sets import Set as set
om@3
    23
    
om@3
    24
from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
om@3
    25
om@3
    26
try:
om@3
    27
    # db module can work independent of web.py
om@3
    28
    from webapi import debug, config
om@3
    29
except:
om@3
    30
    import sys
om@3
    31
    debug = sys.stderr
om@3
    32
    config = storage()
om@3
    33
om@3
    34
class UnknownDB(Exception):
    """Raised for an unsupported DBMS."""
om@3
    37
om@3
    38
class _ItplError(ValueError):
    """Reports an unfinished expression at position `pos` of `text`."""

    def __init__(self, text, pos):
        super(_ItplError, self).__init__()
        self.text, self.pos = text, pos

    def __str__(self):
        details = (repr(self.text), self.pos)
        return "unfinished expression in %s at char %d" % details
om@3
    46
om@3
    47
class TransactionError(Exception):
    """Exception type for transaction errors; adds nothing to Exception."""
om@3
    48
om@3
    49
class UnknownParamstyle(Exception):
    """
    Raised for an unsupported db paramstyle.

    Currently supported paramstyles: qmark, numeric, format, pyformat.
    """
om@3
    56
    
om@3
    57
class SQLParam(object):
    """
    Parameter in SQLQuery.

        >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
        >>> q
        <sql: "SELECT * FROM test WHERE name='joe'">
        >>> q.query()
        'SELECT * FROM test WHERE name=%s'
        >>> q.values()
        ['joe']
    """
    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def get_marker(self, paramstyle='pyformat'):
        """
        Returns the placeholder text for this parameter in `paramstyle`.

        '?' for qmark, ':1' for numeric, and '%s' for format, pyformat or
        None. Any other value raises `UnknownParamstyle`.
        """
        if paramstyle == 'qmark':
            return '?'
        elif paramstyle == 'numeric':
            return ':1'
        elif paramstyle is None or paramstyle in ['format', 'pyformat']:
            return '%s'
        # Call form of raise: valid on both Python 2 and Python 3.
        raise UnknownParamstyle(paramstyle)

    def sqlquery(self):
        """Wraps this parameter in a one-item SQLQuery."""
        return SQLQuery([self])

    def __add__(self, other):
        return self.sqlquery() + other

    def __radd__(self, other):
        return other + self.sqlquery()

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return '<param: %s>' % repr(self.value)

sqlparam = SQLParam
om@3
    99
om@3
   100
class SQLQuery(object):
    """
    A SQL fragment that can be passed as a clause to any db function.
    Alternatively, pass a dictionary as the keyword argument `vars`
    and the function will call reparam for you.

    Internally it holds `items`, a list of strings and SQLParams that
    are concatenated to produce the actual query.
    """
    __slots__ = ["items"]

    # tested in sqlquote's docstring
    def __init__(self, items=None):
        r"""Creates a new SQLQuery.

            >>> SQLQuery("x")
            <sql: 'x'>
            >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
            >>> q
            <sql: 'SELECT * FROM test WHERE x=1'>
            >>> q.query(), q.values()
            ('SELECT * FROM test WHERE x=%s', [1])
            >>> SQLQuery(SQLParam(1))
            <sql: '1'>
        """
        if items is None:
            items = []
        elif isinstance(items, SQLQuery):
            # copy, so the two queries do not share one list
            items = list(items.items)
        elif not isinstance(items, list):
            # a single string, SQLParam or any other object
            items = [items]
        self.items = items

        # A parameter wrapping an SQLLiteral is replaced by the literal text.
        for index, entry in enumerate(self.items):
            if not isinstance(entry, SQLParam):
                continue
            if isinstance(entry.value, SQLLiteral):
                self.items[index] = entry.value.v

    def append(self, value):
        self.items.append(value)

    def __add__(self, other):
        if isinstance(other, SQLQuery):
            tail = other.items
        elif isinstance(other, basestring):
            tail = [other]
        else:
            return NotImplemented
        return SQLQuery(self.items + tail)

    def __radd__(self, other):
        if not isinstance(other, basestring):
            return NotImplemented
        return SQLQuery([other] + self.items)

    def __iadd__(self, other):
        if isinstance(other, SQLQuery):
            self.items.extend(other.items)
        elif isinstance(other, (basestring, SQLParam)):
            self.items.append(other)
        else:
            return NotImplemented
        return self

    def __len__(self):
        return len(self.query())

    def query(self, paramstyle=None):
        """
        Returns the query part of the sql query.
            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
            >>> q.query()
            'SELECT * FROM test WHERE name=%s'
            >>> q.query(paramstyle='qmark')
            'SELECT * FROM test WHERE name=?'
        """
        escape_percent = paramstyle in ['format', 'pyformat']
        pieces = []
        for item in self.items:
            if isinstance(item, SQLParam):
                pieces.append(safestr(item.get_marker(paramstyle)))
                continue
            text = safestr(item)
            # Automatically escape % characters in the query. For backward
            # compatibility, skip escaping when the text looks already escaped.
            if escape_percent and '%' in text and '%%' not in text:
                text = text.replace('%', '%%')
            pieces.append(text)
        return "".join(pieces)

    def values(self):
        """
        Returns the values of the parameters used in the sql query.
            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
            >>> q.values()
            ['joe']
        """
        return [item.value for item in self.items if isinstance(item, SQLParam)]

    def join(items, sep=' ', prefix=None, suffix=None, target=None):
        """
        Joins multiple queries.

        >>> SQLQuery.join(['a', 'b'], ', ')
        <sql: 'a, b'>

        Optionally, prefix and suffix arguments can be provided.

        >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
        <sql: '(a, b)'>

        If target argument is provided, the items are appended to target instead of creating a new SQLQuery.
        """
        if target is None:
            target = SQLQuery()
        out = target.items

        if prefix:
            out.append(prefix)

        first = True
        for item in items:
            if first:
                first = False
            else:
                out.append(sep)
            if isinstance(item, SQLQuery):
                out.extend(item.items)
            else:
                out.append(item)

        if suffix:
            out.append(suffix)
        return target

    join = staticmethod(join)

    def _str(self):
        # Interpolate the sqlified values into the query; if the query cannot
        # be used as a format string, return it without the values.
        try:
            template = self.query()
            args = tuple([sqlify(value) for value in self.values()])
            return template % args
        except (ValueError, TypeError):
            return self.query()

    def __str__(self):
        return safestr(self._str())

    def __unicode__(self):
        return safeunicode(self._str())

    def __repr__(self):
        return '<sql: %s>' % repr(str(self))
om@3
   256
om@3
   257
class SQLLiteral:
    """
    Protects a string from `sqlquote`.

        >>> sqlquote('NOW()')
        <sql: "'NOW()'">
        >>> sqlquote(SQLLiteral('NOW()'))
        <sql: 'NOW()'>
    """
    def __init__(self, v):
        # `v` is the raw text to place in the query unquoted
        self.v = v

    def __repr__(self):
        # the repr is the wrapped value itself, with no quoting added
        return self.v

sqlliteral = SQLLiteral
om@3
   273
om@3
   274
def _sqllist(values):
    """
    Builds a parenthesised, comma-separated SQLQuery with one parameter
    per element of `values`.

        >>> _sqllist([1, 2, 3])
        <sql: '(1, 2, 3)'>
    """
    parts = ['(']
    for index, value in enumerate(values):
        if index:
            parts.append(', ')
        parts.append(sqlparam(value))
    parts.append(')')
    return SQLQuery(parts)
om@3
   287
om@3
   288
def reparam(string_, dictionary):
    """
    Takes a string and a dictionary and interpolates the string
    using values from the dictionary. Returns an `SQLQuery` for the result.

        >>> reparam("s = $s", dict(s=True))
        <sql: "s = 't'">
        >>> reparam("s IN $s", dict(s=[1, 2]))
        <sql: 's IN (1, 2)'>
    """
    # eval() adds `__builtins__` to the globals dict it is given, so work on
    # a copy and leave the caller's dictionary untouched.
    namespace = dictionary.copy()
    result = []
    for live, chunk in _interpolate(string_):
        if live:
            # SECURITY NOTE(review): `chunk` is evaluated as a Python
            # expression. `string_` must never contain untrusted text;
            # user data belongs in `dictionary` only.
            value = eval(chunk, namespace)
            result.append(sqlquote(value))
        else:
            result.append(chunk)
    return SQLQuery.join(result, '')
om@3
   308
om@3
   309
def sqlify(obj):
    """
    Converts `obj` to its textual SQL form.

        >>> sqlify(None)
        'NULL'
        >>> sqlify(True)
        "'t'"
        >>> sqlify(3)
        '3'
    """
    # Identity checks are required here: `1 == True` and
    # `hash(1) == hash(True)`, so equality or a dict lookup would
    # confuse integers with booleans.
    if obj is None:
        return 'NULL'
    if obj is True:
        return "'t'"
    if obj is False:
        return "'f'"
    if datetime and isinstance(obj, datetime.datetime):
        return repr(obj.isoformat())
    if isinstance(obj, unicode):
        obj = obj.encode('utf8')
    return repr(obj)
om@3
   334
om@3
   335
def sqllist(lst):
    """
    Converts the arguments for use in something like a WHERE clause.
    A string is returned unchanged; any other iterable of strings is
    joined with ', '.

        >>> sqllist(['a', 'b'])
        'a, b'
        >>> sqllist('a')
        'a'
        >>> sqllist(u'abc')
        u'abc'
    """
    if not isinstance(lst, basestring):
        return ', '.join(lst)
    return lst
om@3
   350
om@3
   351
def sqlors(left, lst):
    """
    `left` is a SQL clause like `tablename.arg = `
    and `lst` is a list of values. Returns an `SQLQuery` that ORs
    together the clause for each item in `lst`.

        >>> sqlors('foo = ', [])
        <sql: '1=2'>
        >>> sqlors('foo = ', [1])
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', 1)
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', [1,2,3])
        <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
    """
    if isinstance(lst, iters):
        lst = list(lst)
        ln = len(lst)
        if ln == 0:
            # nothing to match: a condition that is always false
            return SQLQuery("1=2")
        if ln == 1:
            # a single value is handled like a scalar below
            lst = lst[0]

    if isinstance(lst, iters):
        # Grow one list in place; sum(list_of_lists, []) copies the
        # accumulated list on every step and is quadratic.
        items = ['(']
        for x in lst:
            items.extend([left, sqlparam(x), ' OR '])
        # the trailing always-false term closes the dangling ' OR '
        items.append('1=2)')
        return SQLQuery(items)
    else:
        return left + sqlparam(lst)
om@3
   382
        
om@3
   383
def sqlwhere(dictionary, grouping=' AND '):
    """
    Turns `dictionary` into an `SQLQuery` of `key = value` conditions
    joined by `grouping`, suitable for a WHERE clause.

        >>> sqlwhere({'cust_id': 2, 'order_id':3})
        <sql: 'order_id = 3 AND cust_id = 2'>
        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
        <sql: 'order_id = 3, cust_id = 2'>
        >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
        'a = %s AND b = %s'
    """
    conditions = []
    for column, value in dictionary.items():
        conditions.append(column + ' = ' + sqlparam(value))
    return SQLQuery.join(conditions, grouping)
om@3
   395
om@3
   396
def sqlquote(a):
    """
    Ensures `a` is quoted properly for use in a SQL query.
    A list becomes a parenthesised list of parameters; anything else
    becomes a single parameter.

        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
        <sql: "WHERE x = 't' AND y = 3">
        >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
        <sql: "WHERE x = 't' AND y IN (2, 3)">
    """
    if not isinstance(a, list):
        return sqlparam(a).sqlquery()
    return _sqllist(a)
om@3
   409
om@3
   410
class Transaction:
    """Database transaction.

    Creating an instance starts the transaction and pushes it onto
    `ctx.transactions`. It can be used as a context manager, which
    commits on normal exit and rolls back when an exception occurred.
    """
    def __init__(self, ctx):
        # number of transactions already open on this context
        depth = len(ctx.transactions)
        self.ctx = ctx
        self.transaction_count = depth

        class transaction_engine:
            """Transaction Engine used in top level transactions."""
            def do_transact(self):
                ctx.commit(unload=False)

            def do_commit(self):
                ctx.commit()

            def do_rollback(self):
                ctx.rollback()

        class subtransaction_engine:
            """Transaction Engine used in sub transactions."""
            def query(self, q):
                # `q` holds a %s slot that receives the nesting depth
                cursor = ctx.db.cursor()
                ctx.db_execute(cursor, SQLQuery(q % depth))

            def do_transact(self):
                self.query('SAVEPOINT webpy_sp_%s')

            def do_commit(self):
                self.query('RELEASE SAVEPOINT webpy_sp_%s')

            def do_rollback(self):
                self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')

        class dummy_engine:
            """Transaction Engine used instead of subtransaction_engine 
            when sub transactions are not supported."""
            do_transact = do_commit = do_rollback = lambda self: None

        if not depth:
            engine_class = transaction_engine
        elif self.ctx.get('ignore_nested_transactions'):
            # nested transactions are not supported in some databases
            engine_class = dummy_engine
        else:
            engine_class = subtransaction_engine
        self.engine = engine_class()

        self.engine.do_transact()
        self.ctx.transactions.append(self)

    def __enter__(self):
        return self

    def __exit__(self, exctype, excvalue, traceback):
        if exctype is None:
            self.commit()
        else:
            self.rollback()

    def commit(self):
        # do nothing when this transaction is no longer on the stack
        if len(self.ctx.transactions) <= self.transaction_count:
            return
        self.engine.do_commit()
        self.ctx.transactions = self.ctx.transactions[:self.transaction_count]

    def rollback(self):
        # do nothing when this transaction is no longer on the stack
        if len(self.ctx.transactions) <= self.transaction_count:
            return
        self.engine.do_rollback()
        self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
om@3
   477
om@3
   478
class DB: 
om@3
   479
    """Database"""
om@3
   480
    def __init__(self, db_module, keywords):
        """Creates a database.

        `db_module` is the driver module whose `connect` is called with
        `keywords`. The `driver` and `pooling` entries are removed from
        `keywords` here, so they never reach `connect`.
        """
        # Some DB implementations take an optional parameter `driver` to use
        # a specific driver module, but it should not be passed to connect.
        keywords.pop('driver', None)

        self.db_module = db_module
        self.keywords = keywords

        self._ctx = threadeddict()
        # flag to enable/disable printing queries
        self.printing = config.get('debug_sql', config.get('debug', False))
        self.supports_multiple_insert = False

        # Pooling is available only when the DBUtils module can be imported.
        try:
            import DBUtils
        except ImportError:
            pooling_available = False
        else:
            pooling_available = True

        # Pooling can be disabled by passing pooling=False in the keywords.
        pooling_requested = self.keywords.pop('pooling', True)
        self.has_pooling = pooling_requested and pooling_available
om@3
   504
            
om@3
   505
    def _getctx(self): 
om@3
   506
        if not self._ctx.get('db'):
om@3
   507
            self._load_context(self._ctx)
om@3
   508
        return self._ctx
om@3
   509
    ctx = property(_getctx)
om@3
   510
    
om@3
   511
    def _load_context(self, ctx):
om@3
   512
        ctx.dbq_count = 0
om@3
   513
        ctx.transactions = [] # stack of transactions
om@3
   514
        
om@3
   515
        if self.has_pooling:
om@3
   516
            ctx.db = self._connect_with_pooling(self.keywords)
om@3
   517
        else:
om@3
   518
            ctx.db = self._connect(self.keywords)
om@3
   519
        ctx.db_execute = self._db_execute
om@3
   520
        
om@3
   521
        if not hasattr(ctx.db, 'commit'):
om@3
   522
            ctx.db.commit = lambda: None
om@3
   523
om@3
   524
        if not hasattr(ctx.db, 'rollback'):
om@3
   525
            ctx.db.rollback = lambda: None
om@3
   526
            
om@3
   527
        def commit(unload=True):
om@3
   528
            # do db commit and release the connection if pooling is enabled.            
om@3
   529
            ctx.db.commit()
om@3
   530
            if unload and self.has_pooling:
om@3
   531
                self._unload_context(self._ctx)
om@3
   532
                
om@3
   533
        def rollback():
om@3
   534
            # do db rollback and release the connection if pooling is enabled.
om@3
   535
            ctx.db.rollback()
om@3
   536
            if self.has_pooling:
om@3
   537
                self._unload_context(self._ctx)
om@3
   538
                
om@3
   539
        ctx.commit = commit
om@3
   540
        ctx.rollback = rollback
om@3
   541
            
om@3
   542
    def _unload_context(self, ctx):
om@3
   543
        del ctx.db
om@3
   544
            
om@3
   545
    def _connect(self, keywords):
om@3
   546
        return self.db_module.connect(**keywords)
om@3
   547
        
om@3
   548
    def _connect_with_pooling(self, keywords):
        """Returns a connection from a DBUtils `PooledDB` pool.

        The pool is created on first use with `keywords` and cached on the
        instance as `_pooleddb`; later calls reuse it.
        """
        def version_key(version):
            # Compare version components as numbers: compared as strings,
            # '10' sorts before '9', so "0.10" would look older than "0.9.3".
            key = []
            for part in version.split('.'):
                digits = ''
                for ch in part:
                    if not ch.isdigit():
                        break
                    digits += ch
                # a component without leading digits counts as 0
                key.append(int(digits or 0))
            return key

        def get_pooled_db():
            from DBUtils import PooledDB

            # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
            # see Bug#122112
            if version_key(PooledDB.__version__) < version_key('0.9.3'):
                return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
            else:
                return PooledDB.PooledDB(creator=self.db_module, **keywords)

        if getattr(self, '_pooleddb', None) is None:
            self._pooleddb = get_pooled_db()

        return self._pooleddb.connection()
om@3
   564
        
om@3
   565
    def _db_cursor(self):
om@3
   566
        return self.ctx.db.cursor()
om@3
   567
om@3
   568
    def _param_marker(self):
om@3
   569
        """Returns parameter marker based on paramstyle attribute if this database."""
om@3
   570
        style = getattr(self, 'paramstyle', 'pyformat')
om@3
   571
om@3
   572
        if style == 'qmark':
om@3
   573
            return '?'
om@3
   574
        elif style == 'numeric':
om@3
   575
            return ':1'
om@3
   576
        elif style in ['format', 'pyformat']:
om@3
   577
            return '%s'
om@3
   578
        raise UnknownParamstyle, style
om@3
   579
om@3
   580
    def _db_execute(self, cur, sql_query):
        """Executes `sql_query` on cursor `cur` and returns what `cur.execute` returns.

        On any error the innermost open transaction (or, if there is none,
        the context itself) is rolled back and the error is re-raised.
        """
        self.ctx.dbq_count += 1

        try:
            started = time.time()
            query, params = self._process_query(sql_query)
            out = cur.execute(query, params)
            finished = time.time()
        except:
            # deliberately bare: roll back on any failure, then re-raise it
            if self.printing:
                print >> debug, 'ERR:', str(sql_query)
            if self.ctx.transactions:
                self.ctx.transactions[-1].rollback()
            else:
                self.ctx.rollback()
            raise

        if self.printing:
            elapsed = round(finished - started, 2)
            print >> debug, '%s (%s): %s' % (elapsed, self.ctx.dbq_count, str(sql_query))
        return out
om@3
   601
om@3
   602
    def _process_query(self, sql_query):
om@3
   603
        """Takes the SQLQuery object and returns query string and parameters.
om@3
   604
        """
om@3
   605
        paramstyle = getattr(self, 'paramstyle', 'pyformat')
om@3
   606
        query = sql_query.query(paramstyle)
om@3
   607
        params = sql_query.values()
om@3
   608
        return query, params
om@3
   609
    
om@3
   610
    def _where(self, where, vars):
        """Normalises `where` into a clause.

        An integer becomes `id = <value>`, an SQLQuery is returned as is,
        and anything else is interpolated with `vars` through reparam.
        """
        if isinstance(where, SQLQuery):
            return where
        if isinstance(where, (int, long)):
            return "id = " + sqlparam(where)
        #@@@ for backward-compatibility
        # NOTE(review): SQLQuery.__init__ as defined in this file accepts a
        # single `items` argument, so this two-argument call looks like it
        # raises TypeError -- confirm the intended handling of 2-item pairs.
        if isinstance(where, (list, tuple)) and len(where) == 2:
            return SQLQuery(where[0], where[1])
        return reparam(where, vars)
om@3
   621
    
om@3
   622
    def query(self, sql_query, vars=None, processed=False, _test=False):
        """
        Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
        If `processed=True`, `sql_query` is used as given, without
        interpolation.

        Returns an iterator of storage rows when the cursor has a
        description, and the cursor's rowcount otherwise. With `_test=True`
        the query is returned instead of being executed.

            >>> db = DB(None, {})
            >>> db.query("SELECT * FROM foo", _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
            >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
        """
        if vars is None:
            vars = {}

        needs_reparam = not processed and not isinstance(sql_query, SQLQuery)
        if needs_reparam:
            sql_query = reparam(sql_query, vars)

        if _test:
            return sql_query

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, sql_query)

        if db_cursor.description:
            names = [column[0] for column in db_cursor.description]

            def as_storage(row):
                # map the column names onto the values of one row
                return storage(dict(zip(names, row)))

            def iterwrapper():
                row = db_cursor.fetchone()
                while row:
                    yield as_storage(row)
                    row = db_cursor.fetchone()

            out = iterbetter(iterwrapper())
            out.__len__ = lambda: int(db_cursor.rowcount)
            out.list = lambda: [as_storage(x) for x in db_cursor.fetchall()]
        else:
            out = db_cursor.rowcount

        # commit right away when no transaction is open
        if not self.ctx.transactions:
            self.ctx.commit()
        return out
om@3
   663
    
om@3
   664
    def select(self, tables, vars=None, what='*', where=None, order=None, group=None, 
               limit=None, offset=None, _test=False): 
        """
        Selects `what` from `tables` with clauses `where`, `order`, 
        `group`, `limit`, and `offset`. Uses vars to interpolate. 
        Otherwise, each clause can be a SQLQuery.
        Clauses whose value is None are left out.

            >>> db = DB(None, {})
            >>> db.select('foo', _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
            <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
        """
        if vars is None:
            vars = {}

        clauses = []
        for keyword, value in self.sql_clauses(what, tables, where, group, order, limit, offset):
            if value is not None:
                clauses.append(self.gen_clause(keyword, value, vars))

        qout = SQLQuery.join(clauses)
        if _test:
            return qout
        return self.query(qout, processed=True)
om@3
   683
    
om@3
   684
    def where(self, table, what='*', order=None, group=None, limit=None, 
              offset=None, _test=False, **kwargs):
        """
        Selects from `table` where keys are equal to values in `kwargs`.
        The conditions are combined with AND; with no `kwargs` there is no
        WHERE clause.

            >>> db = DB(None, {})
            >>> db.where('foo', bar_id=3, _test=True)
            <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
            >>> db.where('foo', source=2, crust='dewey', _test=True)
            <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
            >>> db.where('foo', _test=True)
            <sql: 'SELECT * FROM foo'>
        """
        conditions = [k + ' = ' + sqlquote(v) for k, v in kwargs.items()]

        where = None
        if conditions:
            where = SQLQuery.join(conditions, " AND ")

        return self.select(table, what=what, where=where, order=order,
                           group=group, limit=limit, offset=offset,
                           _test=_test)
om@3
   709
    
om@3
   710
    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        """Returns the (keyword, value) pairs of a SELECT statement, in the
        order the clauses appear in the statement."""
        from_clause = sqllist(tables)
        return (
            ('SELECT', what),
            ('FROM', from_clause),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('LIMIT', limit),
            ('OFFSET', offset))
om@3
   719
    
om@3
   720
    def gen_clause(self, sql, val, vars):
        """Builds one clause by putting the keyword `sql` before `val`.

        An integer `val` becomes `id = <val>` for WHERE and a plain value
        otherwise; an SQLQuery is used as is; anything else is interpolated
        with `vars` through reparam.
        """
        if isinstance(val, (int, long)):
            if sql == 'WHERE':
                nout = 'id = ' + sqlquote(val)
            else:
                nout = SQLQuery(val)
        #@@@
        # NOTE(review): SQLQuery.__init__ as defined in this file accepts a
        # single `items` argument, so this two-argument call looks like it
        # raises TypeError -- confirm the intended handling of 2-item pairs.
        elif isinstance(val, (list, tuple)) and len(val) == 2:
            nout = SQLQuery(val[0], val[1]) # backwards-compatibility
        elif isinstance(val, SQLQuery):
            nout = val
        else:
            nout = reparam(val, vars)

        # join with a space only when both parts are non-empty
        if sql and nout:
            return sql + ' ' + nout
        return sql or nout
om@3
   739
om@3
   740
    def insert(self, tablename, seqname=None, _test=False, **values):
        """
        Inserts `values` into `tablename` and returns the id read back for
        the new row, or None when no id could be read.

        `seqname` is handed to `_process_insert_query` so the driver can
        look the id up; pass `False` to skip that step when the table has
        no id sequence. With `_test=True` the query is returned instead of
        being executed.

            >>> db = DB(None, {})
            >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
            >>> q
            <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
            >>> q.query()
            'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
            >>> q.values()
            [2, 'bob']
        """
        def parenthesize(fragment):
            return "(" + fragment + ")"

        if not values:
            statement = SQLQuery(self._get_insert_default_values_query(tablename))
        else:
            columns = SQLQuery.join(values.keys(), ', ')
            params = SQLQuery.join([sqlparam(item) for item in values.values()], ', ')
            statement = ("INSERT INTO %s " % tablename
                         + parenthesize(columns) + ' VALUES ' + parenthesize(params))

        if _test:
            return statement

        cursor = self._db_cursor()
        if seqname is not False:
            statement = self._process_insert_query(statement, tablename, seqname)

        if not isinstance(statement, tuple):
            self._db_execute(cursor, statement)
        else:
            # for some databases, a separate query has to be made to find
            # the id of the inserted row.
            insert_stmt, id_stmt = statement
            self._db_execute(cursor, insert_stmt)
            self._db_execute(cursor, id_stmt)

        # Best effort: any failure to read a row back simply means "no id".
        try:
            new_id = cursor.fetchone()[0]
        except Exception:
            new_id = None

        # Commit right away unless the context has open transactions.
        if not self.ctx.transactions:
            self.ctx.commit()
        return new_id
om@3
   787
        
om@3
   788
    def _get_insert_default_values_query(self, table):
om@3
   789
        return "INSERT INTO %s DEFAULT VALUES" % table
om@3
   790
om@3
   791
    def multiple_insert(self, tablename, values, seqname=None, _test=False):
        """
        Inserts multiple rows into `tablename`. The `values` must be a list of dictionaries, 
        one for each row to be inserted, each with the same set of keys.
        Returns the list of ids of the inserted rows, or None when the ids
        could not be read back. An empty `values` returns an empty list.
        Set `seqname` to the ID if it's not the default, or to `False`
        if there isn't one.

        Raises ValueError if the rows do not all have the same set of keys.
        
            >>> db = DB(None, {})
            >>> db.supports_multiple_insert = True
            >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
            >>> db.multiple_insert('person', values=values, _test=True)
            <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
        """        
        if not values:
            return []
            
        if not self.supports_multiple_insert:
            # Fall back to one INSERT per row.
            out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
            if seqname is False:
                return None
            else:
                return out
                
        keys = values[0].keys()
        #@@ make sure all keys are valid

        # make sure all rows have same keys. The comparison is done on sets:
        # the order in which a dict lists its keys is not part of its
        # contents, so rows with identical keys must not be rejected merely
        # because they enumerate them differently.
        expected_keys = set(keys)
        for v in values:
            if set(v.keys()) != expected_keys:
                raise ValueError('Bad data')

        sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))

        # Values are always emitted in the column order of the first row.
        for i, row in enumerate(values):
            if i != 0:
                sql_query.append(", ")
            SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")
        
        if _test: return sql_query

        db_cursor = self._db_cursor()
        if seqname is not False: 
            sql_query = self._process_insert_query(sql_query, tablename, seqname)

        if isinstance(sql_query, tuple):
            # for some databases, a separate query has to be made to find 
            # the id of the inserted row.
            q1, q2 = sql_query
            self._db_execute(db_cursor, q1)
            self._db_execute(db_cursor, q2)
        else:
            self._db_execute(db_cursor, sql_query)

        try: 
            out = db_cursor.fetchone()[0]
            # Treats the id read back as that of the LAST inserted row and
            # assumes the rows received consecutive ids.
            # NOTE(review): MySQL documents LAST_INSERT_ID() as the id of the
            # first row of a multi-row INSERT -- confirm the MySQL path.
            out = range(out-len(values)+1, out+1)        
        except Exception: 
            out = None

        if not self.ctx.transactions: 
            self.ctx.commit()
        return out
om@3
   854
om@3
   855
    
om@3
   856
    def update(self, tables, where, vars=None, _test=False, **values):
        """
        Updates `tables`, setting the columns given in `values` on the
        rows selected by `where` (interpolated using `vars`). Returns the
        cursor's `rowcount`, or the query itself when `_test` is true.

            >>> db = DB(None, {})
            >>> name = 'Joseph'
            >>> q = db.update('foo', where='name = $name', name='bob', age=2,
            ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
            >>> q
            <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
            >>> q.query()
            'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
            >>> q.values()
            [2, 'bob', 'Joseph']
        """
        if vars is None:
            vars = {}
        condition = self._where(where, vars)

        target = sqllist(tables)
        assignments = sqlwhere(values, ', ')
        statement = "UPDATE " + target + " SET " + assignments + " WHERE " + condition

        if _test:
            return statement

        cursor = self._db_cursor()
        self._db_execute(cursor, statement)
        # Commit right away unless the context has open transactions.
        if not self.ctx.transactions:
            self.ctx.commit()
        return cursor.rowcount
om@3
   887
    
om@3
   888
    def delete(self, table, where, using=None, vars=None, _test=False):
        """
        Deletes rows from `table`. `where` (interpolated using `vars`)
        selects the rows and `using`, when given, adds a USING clause.
        Returns the cursor's `rowcount`, or the query itself when `_test`
        is true.

            >>> db = DB(None, {})
            >>> name = 'Joe'
            >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
            <sql: "DELETE FROM foo WHERE name = 'Joe'">
        """
        if vars is None:
            vars = {}
        condition = self._where(where, vars)

        statement = 'DELETE FROM ' + table
        if using:
            statement += ' USING ' + sqllist(using)
        if condition:
            statement += ' WHERE ' + condition

        if _test:
            return statement

        cursor = self._db_cursor()
        self._db_execute(cursor, statement)
        # Commit right away unless the context has open transactions.
        if not self.ctx.transactions:
            self.ctx.commit()
        return cursor.rowcount
om@3
   911
om@3
   912
    def _process_insert_query(self, query, tablename, seqname):
om@3
   913
        return query
om@3
   914
om@3
   915
    def transaction(self):
        """Returns a new `Transaction` created from this database's context."""
        context = self.ctx
        return Transaction(context)
om@3
   918
    
om@3
   919
class PostgresDB(DB):
    """Postgres driver."""
    def __init__(self, **keywords):
        """
        Accepts `pw` as an alias for `password`, `db` as an alias for
        `database`, and an optional `driver` naming the module to use
        instead of the first importable one of psycopg2, psycopg, pgdb.
        """
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')

        preferred_driver = keywords.pop('driver', None)
        driver = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=preferred_driver)
        if driver.__name__ == "psycopg2":
            import psycopg2.extensions as pg_extensions
            pg_extensions.register_type(pg_extensions.UNICODE)

        # if db is not provided postgres driver will take it from PGDATABASE environment variable
        if 'db' in keywords:
            keywords['database'] = keywords.pop('db')

        self.dbname = "postgres"
        self.paramstyle = driver.paramstyle
        DB.__init__(self, driver, keywords)
        self.supports_multiple_insert = True
        # Cache of sequence names, filled on first use by _get_all_sequences.
        self._sequences = None

    def _process_insert_query(self, query, tablename, seqname):
        """
        Appends a `SELECT currval(...)` to `query` when a sequence name is
        given, or when the guessed `<tablename>_id_seq` sequence exists.
        """
        if seqname is None:
            # when seqname is not provided guess the seqname and make sure it exists
            guessed = tablename + "_id_seq"
            if guessed in self._get_all_sequences():
                seqname = guessed

        if seqname:
            query += "; SELECT currval('%s')" % seqname

        return query

    def _get_all_sequences(self):
        """Query postgres to find names of all sequences used in this database."""
        if self._sequences is not None:
            return self._sequences

        rows = self.query("SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'")
        self._sequences = set([row.relname for row in rows])
        return self._sequences

    def _connect(self, keywords):
        """Opens a connection and switches its client encoding to UTF-8."""
        connection = DB._connect(self, keywords)
        try:
            connection.set_client_encoding('UTF8')
        except AttributeError:
            # fallback for pgdb driver
            connection.cursor().execute("set client_encoding to 'UTF-8'")
        return connection

    def _connect_with_pooling(self, keywords):
        """Opens a pooled connection and switches its client encoding to UTF-8."""
        pooled = DB._connect_with_pooling(self, keywords)
        # The encoding is set on the object found two `_con` levels down.
        pooled._con._con.set_client_encoding('UTF8')
        return pooled
om@3
   972
om@3
   973
class MySQLDB(DB):
    """MySQL driver (uses the `MySQLdb` module)."""
    def __init__(self, **keywords):
        """
        Accepts `pw` as an alias for `passwd`. `charset` defaults to
        'utf8'; passing `charset=None` removes the option altogether.
        """
        import MySQLdb as db
        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')

        charset = keywords.setdefault('charset', 'utf8')
        if charset is None:
            del keywords['charset']

        # it's both, like psycopg
        self.paramstyle = 'pyformat'
        db.paramstyle = 'pyformat'
        self.dbname = "mysql"
        DB.__init__(self, db, keywords)
        self.supports_multiple_insert = True

    def _process_insert_query(self, query, tablename, seqname):
        """Pairs `query` with a second query that reads `last_insert_id()`."""
        id_query = SQLQuery('SELECT last_insert_id();')
        return query, id_query

    def _get_insert_default_values_query(self, table):
        """Returns the statement used to insert a row without column values."""
        template = "INSERT INTO %s () VALUES()"
        return template % table
om@3
   995
om@3
   996
def import_driver(drivers, preferred=None):
    """Import the first available driver or preferred driver.

    The names in `drivers` are tried in order and the first module that
    imports is returned. When `preferred` is given, only that module is
    tried. Raises ImportError naming every attempted module if none of
    them can be imported.
    """
    candidates = drivers
    if preferred:
        candidates = [preferred]

    for module_name in candidates:
        try:
            # A non-empty fromlist makes __import__ return the named
            # (sub)module itself rather than its top-level package.
            module = __import__(module_name, None, None, ['x'])
        except ImportError:
            continue
        return module
    raise ImportError("Unable to import " + " or ".join(candidates))
om@3
  1008
om@3
  1009
class SqliteDB(DB):
    """SQLite driver."""
    def __init__(self, **keywords):
        """
        Requires `db` (passed on as `database`); accepts an optional
        `driver` naming the module to use instead of the first importable
        one of sqlite3, pysqlite2.dbapi2, sqlite.
        """
        preferred_driver = keywords.pop('driver', None)
        driver = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=preferred_driver)

        if driver.__name__ in ("sqlite3", "pysqlite2.dbapi2"):
            driver.paramstyle = 'qmark'

        # sqlite driver doesn't create datetime objects for timestamp columns unless `detect_types` option is passed.
        # It seems to be supported in sqlite3 and pysqlite2 drivers, not sure about sqlite.
        keywords.setdefault('detect_types', driver.PARSE_DECLTYPES)

        self.paramstyle = driver.paramstyle
        keywords['database'] = keywords.pop('db')
        # sqlite doesn't allow connections to be shared by threads
        keywords['pooling'] = False
        self.dbname = "sqlite"
        DB.__init__(self, driver, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        """Pairs `query` with a second query that reads `last_insert_rowid()`."""
        id_query = SQLQuery('SELECT last_insert_rowid();')
        return query, id_query

    def query(self, *a, **kw):
        """
        Runs `DB.query` and, when the result is an `iterbetter`, removes
        its `__len__` attribute before returning it.
        """
        result = DB.query(self, *a, **kw)
        if isinstance(result, iterbetter):
            # presumably the length reported for sqlite results is not
            # usable -- TODO confirm
            del result.__len__
        return result
om@3
  1034
om@3
  1035
class FirebirdDB(DB):
    """Firebird Database.
    """
    def __init__(self, **keywords):
        """
        Requires `db` (passed on as `database`) and accepts `pw` as an
        alias for `passwd`. If `kinterbasdb` cannot be imported, None is
        handed to `DB.__init__` as the driver module.
        """
        try:
            import kinterbasdb as db
        except Exception:
            db = None
        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')
        keywords['database'] = keywords.pop('db')
        DB.__init__(self, db, keywords)

    def delete(self, table, where=None, using=None, vars=None, _test=False):
        """Same as `DB.delete`, except that `using` is always ignored."""
        # firebird doesn't support using clause
        return DB.delete(self, table, where, None, vars, _test)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        """
        Returns the SELECT clause pairs with FIRST/SKIP placed between the
        SELECT keyword and the column list instead of LIMIT/OFFSET at the
        end.
        """
        from_value = sqllist(tables)
        select_parts = [
            ('SELECT', ''),
            ('FIRST', limit),
            ('SKIP', offset),
            ('', what),
            ('FROM', from_value),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
        ]
        return tuple(select_parts)
om@3
  1067
om@3
  1068
class MSSQLDB(DB):
    """MSSQL driver (uses the `pymssql` module)."""
    def __init__(self, **keywords):
        """
        Requires `db` (passed on as `database`) and accepts `pw` as an
        alias for `password`.
        """
        import pymssql as db    
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')
        keywords['database'] = keywords.pop('db')
        self.dbname = "mssql"
        DB.__init__(self, db, keywords)

    def _process_query(self, sql_query):
        """Takes the SQLQuery object and returns query string and parameters.
        """
        # MSSQLDB expects params to be a tuple. 
        # Overwriting the default implementation to convert params to tuple.
        paramstyle = getattr(self, 'paramstyle', 'pyformat')
        query = sql_query.query(paramstyle)
        params = sql_query.values()
        return query, tuple(params)

    def sql_clauses(self, what, tables, where, group, order, limit, offset): 
        """
        Returns the SELECT clause pairs with TOP in place of LIMIT.

        T-SQL requires TOP to come before the select list
        (`SELECT TOP 4 * FROM foo`), so the SELECT keyword is emitted with
        an empty value and the column list follows TOP under an empty
        keyword, the same arrangement FirebirdDB uses for FIRST/SKIP.
        """
        return (
            ('SELECT', ''),
            ('TOP', limit),
            ('', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('OFFSET', offset))
            
    def _test(self):
        """Test LIMIT.

            Fake presence of pymssql module for running tests.
            >>> import sys
            >>> sys.modules['pymssql'] = sys.modules['sys']
            
            MSSQL has TOP clause instead of LIMIT clause.
            >>> db = MSSQLDB(db='test', user='joe', pw='secret')
            >>> db.select('foo', limit=4, _test=True)
            <sql: 'SELECT TOP 4 * FROM foo'>
        """
        pass
om@3
  1110
om@3
  1111
class OracleDB(DB):
    """Oracle driver (uses the `cx_Oracle` module)."""
    def __init__(self, **keywords):
        """
        Requires `db` (passed on as `dsn`) and accepts `pw` as an alias
        for `password`. Any `pooling` option is discarded.
        """
        import cx_Oracle as driver
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')

        #@@ TODO: use db.makedsn if host, port is specified
        keywords['dsn'] = keywords.pop('db')
        self.dbname = 'oracle'
        driver.paramstyle = 'numeric'
        self.paramstyle = driver.paramstyle

        # oracle doesn't support pooling
        keywords.pop('pooling', None)
        DB.__init__(self, driver, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        """
        Appends a `SELECT <seqname>.currval FROM dual` to `query` when a
        sequence name is given; otherwise returns `query` unchanged.
        """
        # It is not possible to get seq name from table name in Oracle
        if seqname is None:
            return query
        currval_suffix = "; SELECT %s.currval FROM dual" % seqname
        return query + currval_suffix
om@3
  1133
om@3
  1134
_databases = {}
om@3
  1135
def database(dburl=None, **params):
om@3
  1136
    """Creates appropriate database using params.
om@3
  1137
    
om@3
  1138
    Pooling will be enabled if DBUtils module is available. 
om@3
  1139
    Pooling can be disabled by passing pooling=False in params.
om@3
  1140
    """
om@3
  1141
    dbn = params.pop('dbn')
om@3
  1142
    if dbn in _databases:
om@3
  1143
        return _databases[dbn](**params)
om@3
  1144
    else:
om@3
  1145
        raise UnknownDB, dbn
om@3
  1146
om@3
  1147
def register_database(name, clazz):
    """
    Makes `clazz` the class that `database(dbn=name, ...)` instantiates,
    replacing any class previously registered under `name`.

        >>> class LegacyDB(DB): 
        ...     def __init__(self, **params): 
        ...        pass 
        ...
        >>> register_database('legacy', LegacyDB)
        >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret') 
    """
    _databases[name] = clazz
om@3
  1159
om@3
  1160
# Register the drivers defined in this module under the names accepted by
# `database(dbn=...)`.
register_database('mysql', MySQLDB)
register_database('postgres', PostgresDB)
register_database('sqlite', SqliteDB)
register_database('firebird', FirebirdDB)
register_database('mssql', MSSQLDB)
register_database('oracle', OracleDB)
om@3
  1166
om@3
  1167
def _interpolate(format): 
om@3
  1168
    """
om@3
  1169
    Takes a format string and returns a list of 2-tuples of the form
om@3
  1170
    (boolean, string) where boolean says whether string should be evaled
om@3
  1171
    or not.
om@3
  1172
om@3
  1173
    from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
om@3
  1174
    """
om@3
  1175
    from tokenize import tokenprog
om@3
  1176
om@3
  1177
    def matchorfail(text, pos):
om@3
  1178
        match = tokenprog.match(text, pos)
om@3
  1179
        if match is None:
om@3
  1180
            raise _ItplError(text, pos)
om@3
  1181
        return match, match.end()
om@3
  1182
om@3
  1183
    namechars = "abcdefghijklmnopqrstuvwxyz" \
om@3
  1184
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
om@3
  1185
    chunks = []
om@3
  1186
    pos = 0
om@3
  1187
om@3
  1188
    while 1:
om@3
  1189
        dollar = format.find("$", pos)
om@3
  1190
        if dollar < 0: 
om@3
  1191
            break
om@3
  1192
        nextchar = format[dollar + 1]
om@3
  1193
om@3
  1194
        if nextchar == "{":
om@3
  1195
            chunks.append((0, format[pos:dollar]))
om@3
  1196
            pos, level = dollar + 2, 1
om@3
  1197
            while level:
om@3
  1198
                match, pos = matchorfail(format, pos)
om@3
  1199
                tstart, tend = match.regs[3]
om@3
  1200
                token = format[tstart:tend]
om@3
  1201
                if token == "{": 
om@3
  1202
                    level = level + 1
om@3
  1203
                elif token == "}":  
om@3
  1204
                    level = level - 1
om@3
  1205
            chunks.append((1, format[dollar + 2:pos - 1]))
om@3
  1206
om@3
  1207
        elif nextchar in namechars:
om@3
  1208
            chunks.append((0, format[pos:dollar]))
om@3
  1209
            match, pos = matchorfail(format, dollar + 1)
om@3
  1210
            while pos < len(format):
om@3
  1211
                if format[pos] == "." and \
om@3
  1212
                    pos + 1 < len(format) and format[pos + 1] in namechars:
om@3
  1213
                    match, pos = matchorfail(format, pos + 1)
om@3
  1214
                elif format[pos] in "([":
om@3
  1215
                    pos, level = pos + 1, 1
om@3
  1216
                    while level:
om@3
  1217
                        match, pos = matchorfail(format, pos)
om@3
  1218
                        tstart, tend = match.regs[3]
om@3
  1219
                        token = format[tstart:tend]
om@3
  1220
                        if token[0] in "([": 
om@3
  1221
                            level = level + 1
om@3
  1222
                        elif token[0] in ")]":  
om@3
  1223
                            level = level - 1
om@3
  1224
                else: 
om@3
  1225
                    break
om@3
  1226
            chunks.append((1, format[dollar + 1:pos]))
om@3
  1227
        else:
om@3
  1228
            chunks.append((0, format[pos:dollar + 1]))
om@3
  1229
            pos = dollar + 1 + (nextchar == "$")
om@3
  1230
om@3
  1231
    if pos < len(format): 
om@3
  1232
        chunks.append((0, format[pos:]))
om@3
  1233
    return chunks
om@3
  1234
om@3
  1235
if __name__ == "__main__":
    # Run the doctests embedded in this module's docstrings.
    import doctest
    doctest.testmod()