OpenSecurity/install/web.py-0.37/web/db.py
changeset 3 65432e6c6042
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/OpenSecurity/install/web.py-0.37/web/db.py	Mon Dec 02 14:02:05 2013 +0100
     1.3 @@ -0,0 +1,1237 @@
     1.4 +"""
     1.5 +Database API
     1.6 +(part of web.py)
     1.7 +"""
     1.8 +
     1.9 +__all__ = [
    1.10 +  "UnknownParamstyle", "UnknownDB", "TransactionError", 
    1.11 +  "sqllist", "sqlors", "reparam", "sqlquote",
    1.12 +  "SQLQuery", "SQLParam", "sqlparam",
    1.13 +  "SQLLiteral", "sqlliteral",
    1.14 +  "database", 'DB',
    1.15 +]
    1.16 +
    1.17 +import time
    1.18 +try:
    1.19 +    import datetime
    1.20 +except ImportError:
    1.21 +    datetime = None
    1.22 +
    1.23 +try: set
    1.24 +except NameError:
    1.25 +    from sets import Set as set
    1.26 +    
    1.27 +from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
    1.28 +
    1.29 +try:
    1.30 +    # db module can work independent of web.py
    1.31 +    from webapi import debug, config
    1.32 +except:
    1.33 +    import sys
    1.34 +    debug = sys.stderr
    1.35 +    config = storage()
    1.36 +
    1.37 +class UnknownDB(Exception):
    1.38 +    """raised for unsupported dbms"""
    1.39 +    pass
    1.40 +
    1.41 +class _ItplError(ValueError): 
    1.42 +    def __init__(self, text, pos):
    1.43 +        ValueError.__init__(self)
    1.44 +        self.text = text
    1.45 +        self.pos = pos
    1.46 +    def __str__(self):
    1.47 +        return "unfinished expression in %s at char %d" % (
    1.48 +            repr(self.text), self.pos)
    1.49 +
    1.50 +class TransactionError(Exception): pass
    1.51 +
    1.52 +class UnknownParamstyle(Exception): 
    1.53 +    """
    1.54 +    raised for unsupported db paramstyles
    1.55 +
    1.56 +    (currently supported: qmark, numeric, format, pyformat)
    1.57 +    """
    1.58 +    pass
    1.59 +    
    1.60 +class SQLParam(object):
    1.61 +    """
    1.62 +    Parameter in SQLQuery.
    1.63 +    
    1.64 +        >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
    1.65 +        >>> q
    1.66 +        <sql: "SELECT * FROM test WHERE name='joe'">
    1.67 +        >>> q.query()
    1.68 +        'SELECT * FROM test WHERE name=%s'
    1.69 +        >>> q.values()
    1.70 +        ['joe']
    1.71 +    """
    1.72 +    __slots__ = ["value"]
    1.73 +
    1.74 +    def __init__(self, value):
    1.75 +        self.value = value
    1.76 +        
    1.77 +    def get_marker(self, paramstyle='pyformat'):
    1.78 +        if paramstyle == 'qmark':
    1.79 +            return '?'
    1.80 +        elif paramstyle == 'numeric':
    1.81 +            return ':1'
    1.82 +        elif paramstyle is None or paramstyle in ['format', 'pyformat']:
    1.83 +            return '%s'
    1.84 +        raise UnknownParamstyle, paramstyle
    1.85 +        
    1.86 +    def sqlquery(self): 
    1.87 +        return SQLQuery([self])
    1.88 +        
    1.89 +    def __add__(self, other):
    1.90 +        return self.sqlquery() + other
    1.91 +        
    1.92 +    def __radd__(self, other):
    1.93 +        return other + self.sqlquery() 
    1.94 +            
    1.95 +    def __str__(self): 
    1.96 +        return str(self.value)
    1.97 +    
    1.98 +    def __repr__(self):
    1.99 +        return '<param: %s>' % repr(self.value)
   1.100 +
   1.101 +sqlparam =  SQLParam
   1.102 +
   1.103 +class SQLQuery(object):
   1.104 +    """
   1.105 +    You can pass this sort of thing as a clause in any db function.
   1.106 +    Otherwise, you can pass a dictionary to the keyword argument `vars`
   1.107 +    and the function will call reparam for you.
   1.108 +
   1.109 +    Internally, consists of `items`, which is a list of strings and
   1.110 +    SQLParams, which get concatenated to produce the actual query.
   1.111 +    """
   1.112 +    __slots__ = ["items"]
   1.113 +
   1.114 +    # tested in sqlquote's docstring
   1.115 +    def __init__(self, items=None):
   1.116 +        r"""Creates a new SQLQuery.
   1.117 +        
   1.118 +            >>> SQLQuery("x")
   1.119 +            <sql: 'x'>
   1.120 +            >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
   1.121 +            >>> q
   1.122 +            <sql: 'SELECT * FROM test WHERE x=1'>
   1.123 +            >>> q.query(), q.values()
   1.124 +            ('SELECT * FROM test WHERE x=%s', [1])
   1.125 +            >>> SQLQuery(SQLParam(1))
   1.126 +            <sql: '1'>
   1.127 +        """
   1.128 +        if items is None:
   1.129 +            self.items = []
   1.130 +        elif isinstance(items, list):
   1.131 +            self.items = items
   1.132 +        elif isinstance(items, SQLParam):
   1.133 +            self.items = [items]
   1.134 +        elif isinstance(items, SQLQuery):
   1.135 +            self.items = list(items.items)
   1.136 +        else:
   1.137 +            self.items = [items]
   1.138 +            
   1.139 +        # Take care of SQLLiterals
   1.140 +        for i, item in enumerate(self.items):
   1.141 +            if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral):
   1.142 +                self.items[i] = item.value.v
   1.143 +
   1.144 +    def append(self, value):
   1.145 +        self.items.append(value)
   1.146 +
   1.147 +    def __add__(self, other):
   1.148 +        if isinstance(other, basestring):
   1.149 +            items = [other]
   1.150 +        elif isinstance(other, SQLQuery):
   1.151 +            items = other.items
   1.152 +        else:
   1.153 +            return NotImplemented
   1.154 +        return SQLQuery(self.items + items)
   1.155 +
   1.156 +    def __radd__(self, other):
   1.157 +        if isinstance(other, basestring):
   1.158 +            items = [other]
   1.159 +        else:
   1.160 +            return NotImplemented
   1.161 +            
   1.162 +        return SQLQuery(items + self.items)
   1.163 +
   1.164 +    def __iadd__(self, other):
   1.165 +        if isinstance(other, (basestring, SQLParam)):
   1.166 +            self.items.append(other)
   1.167 +        elif isinstance(other, SQLQuery):
   1.168 +            self.items.extend(other.items)
   1.169 +        else:
   1.170 +            return NotImplemented
   1.171 +        return self
   1.172 +
   1.173 +    def __len__(self):
   1.174 +        return len(self.query())
   1.175 +        
   1.176 +    def query(self, paramstyle=None):
   1.177 +        """
   1.178 +        Returns the query part of the sql query.
   1.179 +            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
   1.180 +            >>> q.query()
   1.181 +            'SELECT * FROM test WHERE name=%s'
   1.182 +            >>> q.query(paramstyle='qmark')
   1.183 +            'SELECT * FROM test WHERE name=?'
   1.184 +        """
   1.185 +        s = []
   1.186 +        for x in self.items:
   1.187 +            if isinstance(x, SQLParam):
   1.188 +                x = x.get_marker(paramstyle)
   1.189 +                s.append(safestr(x))
   1.190 +            else:
   1.191 +                x = safestr(x)
   1.192 +                # automatically escape % characters in the query
   1.193 +                # For backward compatability, ignore escaping when the query looks already escaped
   1.194 +                if paramstyle in ['format', 'pyformat']:
   1.195 +                    if '%' in x and '%%' not in x:
   1.196 +                        x = x.replace('%', '%%')
   1.197 +                s.append(x)
   1.198 +        return "".join(s)
   1.199 +    
   1.200 +    def values(self):
   1.201 +        """
   1.202 +        Returns the values of the parameters used in the sql query.
   1.203 +            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
   1.204 +            >>> q.values()
   1.205 +            ['joe']
   1.206 +        """
   1.207 +        return [i.value for i in self.items if isinstance(i, SQLParam)]
   1.208 +        
   1.209 +    def join(items, sep=' ', prefix=None, suffix=None, target=None):
   1.210 +        """
   1.211 +        Joins multiple queries.
   1.212 +        
   1.213 +        >>> SQLQuery.join(['a', 'b'], ', ')
   1.214 +        <sql: 'a, b'>
   1.215 +
   1.216 +        Optinally, prefix and suffix arguments can be provided.
   1.217 +
   1.218 +        >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
   1.219 +        <sql: '(a, b)'>
   1.220 +
   1.221 +        If target argument is provided, the items are appended to target instead of creating a new SQLQuery.
   1.222 +        """
   1.223 +        if target is None:
   1.224 +            target = SQLQuery()
   1.225 +
   1.226 +        target_items = target.items
   1.227 +
   1.228 +        if prefix:
   1.229 +            target_items.append(prefix)
   1.230 +
   1.231 +        for i, item in enumerate(items):
   1.232 +            if i != 0:
   1.233 +                target_items.append(sep)
   1.234 +            if isinstance(item, SQLQuery):
   1.235 +                target_items.extend(item.items)
   1.236 +            else:
   1.237 +                target_items.append(item)
   1.238 +
   1.239 +        if suffix:
   1.240 +            target_items.append(suffix)
   1.241 +        return target
   1.242 +    
   1.243 +    join = staticmethod(join)
   1.244 +    
   1.245 +    def _str(self):
   1.246 +        try:
   1.247 +            return self.query() % tuple([sqlify(x) for x in self.values()])            
   1.248 +        except (ValueError, TypeError):
   1.249 +            return self.query()
   1.250 +        
   1.251 +    def __str__(self):
   1.252 +        return safestr(self._str())
   1.253 +        
   1.254 +    def __unicode__(self):
   1.255 +        return safeunicode(self._str())
   1.256 +
   1.257 +    def __repr__(self):
   1.258 +        return '<sql: %s>' % repr(str(self))
   1.259 +
   1.260 +class SQLLiteral: 
   1.261 +    """
   1.262 +    Protects a string from `sqlquote`.
   1.263 +
   1.264 +        >>> sqlquote('NOW()')
   1.265 +        <sql: "'NOW()'">
   1.266 +        >>> sqlquote(SQLLiteral('NOW()'))
   1.267 +        <sql: 'NOW()'>
   1.268 +    """
   1.269 +    def __init__(self, v): 
   1.270 +        self.v = v
   1.271 +
   1.272 +    def __repr__(self): 
   1.273 +        return self.v
   1.274 +
   1.275 +sqlliteral = SQLLiteral
   1.276 +
   1.277 +def _sqllist(values):
   1.278 +    """
   1.279 +        >>> _sqllist([1, 2, 3])
   1.280 +        <sql: '(1, 2, 3)'>
   1.281 +    """
   1.282 +    items = []
   1.283 +    items.append('(')
   1.284 +    for i, v in enumerate(values):
   1.285 +        if i != 0:
   1.286 +            items.append(', ')
   1.287 +        items.append(sqlparam(v))
   1.288 +    items.append(')')
   1.289 +    return SQLQuery(items)
   1.290 +
   1.291 +def reparam(string_, dictionary): 
   1.292 +    """
   1.293 +    Takes a string and a dictionary and interpolates the string
   1.294 +    using values from the dictionary. Returns an `SQLQuery` for the result.
   1.295 +
   1.296 +        >>> reparam("s = $s", dict(s=True))
   1.297 +        <sql: "s = 't'">
   1.298 +        >>> reparam("s IN $s", dict(s=[1, 2]))
   1.299 +        <sql: 's IN (1, 2)'>
   1.300 +    """
   1.301 +    dictionary = dictionary.copy() # eval mucks with it
   1.302 +    vals = []
   1.303 +    result = []
   1.304 +    for live, chunk in _interpolate(string_):
   1.305 +        if live:
   1.306 +            v = eval(chunk, dictionary)
   1.307 +            result.append(sqlquote(v))
   1.308 +        else: 
   1.309 +            result.append(chunk)
   1.310 +    return SQLQuery.join(result, '')
   1.311 +
   1.312 +def sqlify(obj): 
   1.313 +    """
   1.314 +    converts `obj` to its proper SQL version
   1.315 +
   1.316 +        >>> sqlify(None)
   1.317 +        'NULL'
   1.318 +        >>> sqlify(True)
   1.319 +        "'t'"
   1.320 +        >>> sqlify(3)
   1.321 +        '3'
   1.322 +    """
   1.323 +    # because `1 == True and hash(1) == hash(True)`
   1.324 +    # we have to do this the hard way...
   1.325 +
   1.326 +    if obj is None:
   1.327 +        return 'NULL'
   1.328 +    elif obj is True:
   1.329 +        return "'t'"
   1.330 +    elif obj is False:
   1.331 +        return "'f'"
   1.332 +    elif datetime and isinstance(obj, datetime.datetime):
   1.333 +        return repr(obj.isoformat())
   1.334 +    else:
   1.335 +        if isinstance(obj, unicode): obj = obj.encode('utf8')
   1.336 +        return repr(obj)
   1.337 +
   1.338 +def sqllist(lst): 
   1.339 +    """
   1.340 +    Converts the arguments for use in something like a WHERE clause.
   1.341 +    
   1.342 +        >>> sqllist(['a', 'b'])
   1.343 +        'a, b'
   1.344 +        >>> sqllist('a')
   1.345 +        'a'
   1.346 +        >>> sqllist(u'abc')
   1.347 +        u'abc'
   1.348 +    """
   1.349 +    if isinstance(lst, basestring): 
   1.350 +        return lst
   1.351 +    else:
   1.352 +        return ', '.join(lst)
   1.353 +
   1.354 +def sqlors(left, lst):
   1.355 +    """
   1.356 +    `left is a SQL clause like `tablename.arg = ` 
   1.357 +    and `lst` is a list of values. Returns a reparam-style
   1.358 +    pair featuring the SQL that ORs together the clause
   1.359 +    for each item in the lst.
   1.360 +
   1.361 +        >>> sqlors('foo = ', [])
   1.362 +        <sql: '1=2'>
   1.363 +        >>> sqlors('foo = ', [1])
   1.364 +        <sql: 'foo = 1'>
   1.365 +        >>> sqlors('foo = ', 1)
   1.366 +        <sql: 'foo = 1'>
   1.367 +        >>> sqlors('foo = ', [1,2,3])
   1.368 +        <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
   1.369 +    """
   1.370 +    if isinstance(lst, iters):
   1.371 +        lst = list(lst)
   1.372 +        ln = len(lst)
   1.373 +        if ln == 0:
   1.374 +            return SQLQuery("1=2")
   1.375 +        if ln == 1:
   1.376 +            lst = lst[0]
   1.377 +
   1.378 +    if isinstance(lst, iters):
   1.379 +        return SQLQuery(['('] + 
   1.380 +          sum([[left, sqlparam(x), ' OR '] for x in lst], []) +
   1.381 +          ['1=2)']
   1.382 +        )
   1.383 +    else:
   1.384 +        return left + sqlparam(lst)
   1.385 +        
   1.386 +def sqlwhere(dictionary, grouping=' AND '): 
   1.387 +    """
   1.388 +    Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.
   1.389 +    
   1.390 +        >>> sqlwhere({'cust_id': 2, 'order_id':3})
   1.391 +        <sql: 'order_id = 3 AND cust_id = 2'>
   1.392 +        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
   1.393 +        <sql: 'order_id = 3, cust_id = 2'>
   1.394 +        >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
   1.395 +        'a = %s AND b = %s'
   1.396 +    """
   1.397 +    return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping)
   1.398 +
   1.399 +def sqlquote(a): 
   1.400 +    """
   1.401 +    Ensures `a` is quoted properly for use in a SQL query.
   1.402 +
   1.403 +        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
   1.404 +        <sql: "WHERE x = 't' AND y = 3">
   1.405 +        >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
   1.406 +        <sql: "WHERE x = 't' AND y IN (2, 3)">
   1.407 +    """
   1.408 +    if isinstance(a, list):
   1.409 +        return _sqllist(a)
   1.410 +    else:
   1.411 +        return sqlparam(a).sqlquery()
   1.412 +
   1.413 +class Transaction:
   1.414 +    """Database transaction."""
   1.415 +    def __init__(self, ctx):
   1.416 +        self.ctx = ctx
   1.417 +        self.transaction_count = transaction_count = len(ctx.transactions)
   1.418 +
   1.419 +        class transaction_engine:
   1.420 +            """Transaction Engine used in top level transactions."""
   1.421 +            def do_transact(self):
   1.422 +                ctx.commit(unload=False)
   1.423 +
   1.424 +            def do_commit(self):
   1.425 +                ctx.commit()
   1.426 +
   1.427 +            def do_rollback(self):
   1.428 +                ctx.rollback()
   1.429 +
   1.430 +        class subtransaction_engine:
   1.431 +            """Transaction Engine used in sub transactions."""
   1.432 +            def query(self, q):
   1.433 +                db_cursor = ctx.db.cursor()
   1.434 +                ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))
   1.435 +
   1.436 +            def do_transact(self):
   1.437 +                self.query('SAVEPOINT webpy_sp_%s')
   1.438 +
   1.439 +            def do_commit(self):
   1.440 +                self.query('RELEASE SAVEPOINT webpy_sp_%s')
   1.441 +
   1.442 +            def do_rollback(self):
   1.443 +                self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')
   1.444 +
   1.445 +        class dummy_engine:
   1.446 +            """Transaction Engine used instead of subtransaction_engine 
   1.447 +            when sub transactions are not supported."""
   1.448 +            do_transact = do_commit = do_rollback = lambda self: None
   1.449 +
   1.450 +        if self.transaction_count:
   1.451 +            # nested transactions are not supported in some databases
   1.452 +            if self.ctx.get('ignore_nested_transactions'):
   1.453 +                self.engine = dummy_engine()
   1.454 +            else:
   1.455 +                self.engine = subtransaction_engine()
   1.456 +        else:
   1.457 +            self.engine = transaction_engine()
   1.458 +
   1.459 +        self.engine.do_transact()
   1.460 +        self.ctx.transactions.append(self)
   1.461 +
   1.462 +    def __enter__(self):
   1.463 +        return self
   1.464 +
   1.465 +    def __exit__(self, exctype, excvalue, traceback):
   1.466 +        if exctype is not None:
   1.467 +            self.rollback()
   1.468 +        else:
   1.469 +            self.commit()
   1.470 +
   1.471 +    def commit(self):
   1.472 +        if len(self.ctx.transactions) > self.transaction_count:
   1.473 +            self.engine.do_commit()
   1.474 +            self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
   1.475 +
   1.476 +    def rollback(self):
   1.477 +        if len(self.ctx.transactions) > self.transaction_count:
   1.478 +            self.engine.do_rollback()
   1.479 +            self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
   1.480 +
   1.481 +class DB: 
   1.482 +    """Database"""
   1.483 +    def __init__(self, db_module, keywords):
   1.484 +        """Creates a database.
   1.485 +        """
   1.486 +        # some DB implementaions take optional paramater `driver` to use a specific driver modue
   1.487 +        # but it should not be passed to connect
   1.488 +        keywords.pop('driver', None)
   1.489 +
   1.490 +        self.db_module = db_module
   1.491 +        self.keywords = keywords
   1.492 +
   1.493 +        self._ctx = threadeddict()
   1.494 +        # flag to enable/disable printing queries
   1.495 +        self.printing = config.get('debug_sql', config.get('debug', False))
   1.496 +        self.supports_multiple_insert = False
   1.497 +        
   1.498 +        try:
   1.499 +            import DBUtils
   1.500 +            # enable pooling if DBUtils module is available.
   1.501 +            self.has_pooling = True
   1.502 +        except ImportError:
   1.503 +            self.has_pooling = False
   1.504 +            
   1.505 +        # Pooling can be disabled by passing pooling=False in the keywords.
   1.506 +        self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling
   1.507 +            
   1.508 +    def _getctx(self): 
   1.509 +        if not self._ctx.get('db'):
   1.510 +            self._load_context(self._ctx)
   1.511 +        return self._ctx
   1.512 +    ctx = property(_getctx)
   1.513 +    
   1.514 +    def _load_context(self, ctx):
   1.515 +        ctx.dbq_count = 0
   1.516 +        ctx.transactions = [] # stack of transactions
   1.517 +        
   1.518 +        if self.has_pooling:
   1.519 +            ctx.db = self._connect_with_pooling(self.keywords)
   1.520 +        else:
   1.521 +            ctx.db = self._connect(self.keywords)
   1.522 +        ctx.db_execute = self._db_execute
   1.523 +        
   1.524 +        if not hasattr(ctx.db, 'commit'):
   1.525 +            ctx.db.commit = lambda: None
   1.526 +
   1.527 +        if not hasattr(ctx.db, 'rollback'):
   1.528 +            ctx.db.rollback = lambda: None
   1.529 +            
   1.530 +        def commit(unload=True):
   1.531 +            # do db commit and release the connection if pooling is enabled.            
   1.532 +            ctx.db.commit()
   1.533 +            if unload and self.has_pooling:
   1.534 +                self._unload_context(self._ctx)
   1.535 +                
   1.536 +        def rollback():
   1.537 +            # do db rollback and release the connection if pooling is enabled.
   1.538 +            ctx.db.rollback()
   1.539 +            if self.has_pooling:
   1.540 +                self._unload_context(self._ctx)
   1.541 +                
   1.542 +        ctx.commit = commit
   1.543 +        ctx.rollback = rollback
   1.544 +            
   1.545 +    def _unload_context(self, ctx):
   1.546 +        del ctx.db
   1.547 +            
   1.548 +    def _connect(self, keywords):
   1.549 +        return self.db_module.connect(**keywords)
   1.550 +        
   1.551 +    def _connect_with_pooling(self, keywords):
   1.552 +        def get_pooled_db():
   1.553 +            from DBUtils import PooledDB
   1.554 +
   1.555 +            # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
   1.556 +            # see Bug#122112
   1.557 +            
   1.558 +            if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
   1.559 +                return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
   1.560 +            else:
   1.561 +                return PooledDB.PooledDB(creator=self.db_module, **keywords)
   1.562 +        
   1.563 +        if getattr(self, '_pooleddb', None) is None:
   1.564 +            self._pooleddb = get_pooled_db()
   1.565 +        
   1.566 +        return self._pooleddb.connection()
   1.567 +        
   1.568 +    def _db_cursor(self):
   1.569 +        return self.ctx.db.cursor()
   1.570 +
   1.571 +    def _param_marker(self):
   1.572 +        """Returns parameter marker based on paramstyle attribute if this database."""
   1.573 +        style = getattr(self, 'paramstyle', 'pyformat')
   1.574 +
   1.575 +        if style == 'qmark':
   1.576 +            return '?'
   1.577 +        elif style == 'numeric':
   1.578 +            return ':1'
   1.579 +        elif style in ['format', 'pyformat']:
   1.580 +            return '%s'
   1.581 +        raise UnknownParamstyle, style
   1.582 +
   1.583 +    def _db_execute(self, cur, sql_query): 
   1.584 +        """executes an sql query"""
   1.585 +        self.ctx.dbq_count += 1
   1.586 +        
   1.587 +        try:
   1.588 +            a = time.time()
   1.589 +            query, params = self._process_query(sql_query)
   1.590 +            out = cur.execute(query, params)
   1.591 +            b = time.time()
   1.592 +        except:
   1.593 +            if self.printing:
   1.594 +                print >> debug, 'ERR:', str(sql_query)
   1.595 +            if self.ctx.transactions:
   1.596 +                self.ctx.transactions[-1].rollback()
   1.597 +            else:
   1.598 +                self.ctx.rollback()
   1.599 +            raise
   1.600 +
   1.601 +        if self.printing:
   1.602 +            print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
   1.603 +        return out
   1.604 +
   1.605 +    def _process_query(self, sql_query):
   1.606 +        """Takes the SQLQuery object and returns query string and parameters.
   1.607 +        """
   1.608 +        paramstyle = getattr(self, 'paramstyle', 'pyformat')
   1.609 +        query = sql_query.query(paramstyle)
   1.610 +        params = sql_query.values()
   1.611 +        return query, params
   1.612 +    
   1.613 +    def _where(self, where, vars): 
   1.614 +        if isinstance(where, (int, long)):
   1.615 +            where = "id = " + sqlparam(where)
   1.616 +        #@@@ for backward-compatibility
   1.617 +        elif isinstance(where, (list, tuple)) and len(where) == 2:
   1.618 +            where = SQLQuery(where[0], where[1])
   1.619 +        elif isinstance(where, SQLQuery):
   1.620 +            pass
   1.621 +        else:
   1.622 +            where = reparam(where, vars)        
   1.623 +        return where
   1.624 +    
   1.625 +    def query(self, sql_query, vars=None, processed=False, _test=False): 
   1.626 +        """
   1.627 +        Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
   1.628 +        If `processed=True`, `vars` is a `reparam`-style list to use 
   1.629 +        instead of interpolating.
   1.630 +        
   1.631 +            >>> db = DB(None, {})
   1.632 +            >>> db.query("SELECT * FROM foo", _test=True)
   1.633 +            <sql: 'SELECT * FROM foo'>
   1.634 +            >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
   1.635 +            <sql: "SELECT * FROM foo WHERE x = 'f'">
   1.636 +            >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
   1.637 +            <sql: "SELECT * FROM foo WHERE x = 'f'">
   1.638 +        """
   1.639 +        if vars is None: vars = {}
   1.640 +        
   1.641 +        if not processed and not isinstance(sql_query, SQLQuery):
   1.642 +            sql_query = reparam(sql_query, vars)
   1.643 +        
   1.644 +        if _test: return sql_query
   1.645 +        
   1.646 +        db_cursor = self._db_cursor()
   1.647 +        self._db_execute(db_cursor, sql_query)
   1.648 +        
   1.649 +        if db_cursor.description:
   1.650 +            names = [x[0] for x in db_cursor.description]
   1.651 +            def iterwrapper():
   1.652 +                row = db_cursor.fetchone()
   1.653 +                while row:
   1.654 +                    yield storage(dict(zip(names, row)))
   1.655 +                    row = db_cursor.fetchone()
   1.656 +            out = iterbetter(iterwrapper())
   1.657 +            out.__len__ = lambda: int(db_cursor.rowcount)
   1.658 +            out.list = lambda: [storage(dict(zip(names, x))) \
   1.659 +                               for x in db_cursor.fetchall()]
   1.660 +        else:
   1.661 +            out = db_cursor.rowcount
   1.662 +        
   1.663 +        if not self.ctx.transactions: 
   1.664 +            self.ctx.commit()
   1.665 +        return out
   1.666 +    
   1.667 +    def select(self, tables, vars=None, what='*', where=None, order=None, group=None, 
   1.668 +               limit=None, offset=None, _test=False): 
   1.669 +        """
   1.670 +        Selects `what` from `tables` with clauses `where`, `order`, 
   1.671 +        `group`, `limit`, and `offset`. Uses vars to interpolate. 
   1.672 +        Otherwise, each clause can be a SQLQuery.
   1.673 +        
   1.674 +            >>> db = DB(None, {})
   1.675 +            >>> db.select('foo', _test=True)
   1.676 +            <sql: 'SELECT * FROM foo'>
   1.677 +            >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
   1.678 +            <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
   1.679 +        """
   1.680 +        if vars is None: vars = {}
   1.681 +        sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset)
   1.682 +        clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None]
   1.683 +        qout = SQLQuery.join(clauses)
   1.684 +        if _test: return qout
   1.685 +        return self.query(qout, processed=True)
   1.686 +    
   1.687 +    def where(self, table, what='*', order=None, group=None, limit=None, 
   1.688 +              offset=None, _test=False, **kwargs):
   1.689 +        """
   1.690 +        Selects from `table` where keys are equal to values in `kwargs`.
   1.691 +        
   1.692 +            >>> db = DB(None, {})
   1.693 +            >>> db.where('foo', bar_id=3, _test=True)
   1.694 +            <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
   1.695 +            >>> db.where('foo', source=2, crust='dewey', _test=True)
   1.696 +            <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
   1.697 +            >>> db.where('foo', _test=True)
   1.698 +            <sql: 'SELECT * FROM foo'>
   1.699 +        """
   1.700 +        where_clauses = []
   1.701 +        for k, v in kwargs.iteritems():
   1.702 +            where_clauses.append(k + ' = ' + sqlquote(v))
   1.703 +            
   1.704 +        if where_clauses:
   1.705 +            where = SQLQuery.join(where_clauses, " AND ")
   1.706 +        else:
   1.707 +            where = None
   1.708 +            
   1.709 +        return self.select(table, what=what, order=order, 
   1.710 +               group=group, limit=limit, offset=offset, _test=_test, 
   1.711 +               where=where)
   1.712 +    
   1.713 +    def sql_clauses(self, what, tables, where, group, order, limit, offset): 
   1.714 +        return (
   1.715 +            ('SELECT', what),
   1.716 +            ('FROM', sqllist(tables)),
   1.717 +            ('WHERE', where),
   1.718 +            ('GROUP BY', group),
   1.719 +            ('ORDER BY', order),
   1.720 +            ('LIMIT', limit),
   1.721 +            ('OFFSET', offset))
   1.722 +    
   1.723 +    def gen_clause(self, sql, val, vars): 
   1.724 +        if isinstance(val, (int, long)):
   1.725 +            if sql == 'WHERE':
   1.726 +                nout = 'id = ' + sqlquote(val)
   1.727 +            else:
   1.728 +                nout = SQLQuery(val)
   1.729 +        #@@@
   1.730 +        elif isinstance(val, (list, tuple)) and len(val) == 2:
   1.731 +            nout = SQLQuery(val[0], val[1]) # backwards-compatibility
   1.732 +        elif isinstance(val, SQLQuery):
   1.733 +            nout = val
   1.734 +        else:
   1.735 +            nout = reparam(val, vars)
   1.736 +
   1.737 +        def xjoin(a, b):
   1.738 +            if a and b: return a + ' ' + b
   1.739 +            else: return a or b
   1.740 +
   1.741 +        return xjoin(sql, nout)
   1.742 +
   1.743 +    def insert(self, tablename, seqname=None, _test=False, **values): 
   1.744 +        """
   1.745 +        Inserts `values` into `tablename`. Returns current sequence ID.
   1.746 +        Set `seqname` to the ID if it's not the default, or to `False`
   1.747 +        if there isn't one.
   1.748 +        
   1.749 +            >>> db = DB(None, {})
   1.750 +            >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
   1.751 +            >>> q
   1.752 +            <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
   1.753 +            >>> q.query()
   1.754 +            'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
   1.755 +            >>> q.values()
   1.756 +            [2, 'bob']
   1.757 +        """
   1.758 +        def q(x): return "(" + x + ")"
   1.759 +        
   1.760 +        if values:
   1.761 +            _keys = SQLQuery.join(values.keys(), ', ')
   1.762 +            _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
   1.763 +            sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
   1.764 +        else:
   1.765 +            sql_query = SQLQuery(self._get_insert_default_values_query(tablename))
   1.766 +
   1.767 +        if _test: return sql_query
   1.768 +        
   1.769 +        db_cursor = self._db_cursor()
   1.770 +        if seqname is not False: 
   1.771 +            sql_query = self._process_insert_query(sql_query, tablename, seqname)
   1.772 +
   1.773 +        if isinstance(sql_query, tuple):
   1.774 +            # for some databases, a separate query has to be made to find 
   1.775 +            # the id of the inserted row.
   1.776 +            q1, q2 = sql_query
   1.777 +            self._db_execute(db_cursor, q1)
   1.778 +            self._db_execute(db_cursor, q2)
   1.779 +        else:
   1.780 +            self._db_execute(db_cursor, sql_query)
   1.781 +
   1.782 +        try: 
   1.783 +            out = db_cursor.fetchone()[0]
   1.784 +        except Exception: 
   1.785 +            out = None
   1.786 +        
   1.787 +        if not self.ctx.transactions: 
   1.788 +            self.ctx.commit()
   1.789 +        return out
   1.790 +        
   1.791 +    def _get_insert_default_values_query(self, table):
   1.792 +        return "INSERT INTO %s DEFAULT VALUES" % table
   1.793 +
   1.794 +    def multiple_insert(self, tablename, values, seqname=None, _test=False):
   1.795 +        """
   1.796 +        Inserts multiple rows into `tablename`. The `values` must be a list of dictioanries, 
   1.797 +        one for each row to be inserted, each with the same set of keys.
   1.798 +        Returns the list of ids of the inserted rows.        
   1.799 +        Set `seqname` to the ID if it's not the default, or to `False`
   1.800 +        if there isn't one.
   1.801 +        
   1.802 +            >>> db = DB(None, {})
   1.803 +            >>> db.supports_multiple_insert = True
   1.804 +            >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
   1.805 +            >>> db.multiple_insert('person', values=values, _test=True)
   1.806 +            <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
   1.807 +        """        
   1.808 +        if not values:
   1.809 +            return []
   1.810 +            
   1.811 +        if not self.supports_multiple_insert:
   1.812 +            out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
   1.813 +            if seqname is False:
   1.814 +                return None
   1.815 +            else:
   1.816 +                return out
   1.817 +                
   1.818 +        keys = values[0].keys()
   1.819 +        #@@ make sure all keys are valid
   1.820 +
   1.821 +        # make sure all rows have same keys.
   1.822 +        for v in values:
   1.823 +            if v.keys() != keys:
   1.824 +                raise ValueError, 'Bad data'
   1.825 +
   1.826 +        sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))
   1.827 +
   1.828 +        for i, row in enumerate(values):
   1.829 +            if i != 0:
   1.830 +                sql_query.append(", ")
   1.831 +            SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")
   1.832 +        
   1.833 +        if _test: return sql_query
   1.834 +
   1.835 +        db_cursor = self._db_cursor()
   1.836 +        if seqname is not False: 
   1.837 +            sql_query = self._process_insert_query(sql_query, tablename, seqname)
   1.838 +
   1.839 +        if isinstance(sql_query, tuple):
   1.840 +            # for some databases, a separate query has to be made to find 
   1.841 +            # the id of the inserted row.
   1.842 +            q1, q2 = sql_query
   1.843 +            self._db_execute(db_cursor, q1)
   1.844 +            self._db_execute(db_cursor, q2)
   1.845 +        else:
   1.846 +            self._db_execute(db_cursor, sql_query)
   1.847 +
   1.848 +        try: 
   1.849 +            out = db_cursor.fetchone()[0]
   1.850 +            out = range(out-len(values)+1, out+1)        
   1.851 +        except Exception: 
   1.852 +            out = None
   1.853 +
   1.854 +        if not self.ctx.transactions: 
   1.855 +            self.ctx.commit()
   1.856 +        return out
   1.857 +
   1.858 +    
   1.859 +    def update(self, tables, where, vars=None, _test=False, **values): 
   1.860 +        """
   1.861 +        Update `tables` with clause `where` (interpolated using `vars`)
   1.862 +        and setting `values`.
   1.863 +
   1.864 +            >>> db = DB(None, {})
   1.865 +            >>> name = 'Joseph'
   1.866 +            >>> q = db.update('foo', where='name = $name', name='bob', age=2,
   1.867 +            ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
   1.868 +            >>> q
   1.869 +            <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
   1.870 +            >>> q.query()
   1.871 +            'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
   1.872 +            >>> q.values()
   1.873 +            [2, 'bob', 'Joseph']
   1.874 +        """
   1.875 +        if vars is None: vars = {}
   1.876 +        where = self._where(where, vars)
   1.877 +
   1.878 +        query = (
   1.879 +          "UPDATE " + sqllist(tables) + 
   1.880 +          " SET " + sqlwhere(values, ', ') + 
   1.881 +          " WHERE " + where)
   1.882 +
   1.883 +        if _test: return query
   1.884 +        
   1.885 +        db_cursor = self._db_cursor()
   1.886 +        self._db_execute(db_cursor, query)
   1.887 +        if not self.ctx.transactions: 
   1.888 +            self.ctx.commit()
   1.889 +        return db_cursor.rowcount
   1.890 +    
   1.891 +    def delete(self, table, where, using=None, vars=None, _test=False): 
   1.892 +        """
   1.893 +        Deletes from `table` with clauses `where` and `using`.
   1.894 +
   1.895 +            >>> db = DB(None, {})
   1.896 +            >>> name = 'Joe'
   1.897 +            >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
   1.898 +            <sql: "DELETE FROM foo WHERE name = 'Joe'">
   1.899 +        """
   1.900 +        if vars is None: vars = {}
   1.901 +        where = self._where(where, vars)
   1.902 +
   1.903 +        q = 'DELETE FROM ' + table
   1.904 +        if using: q += ' USING ' + sqllist(using)
   1.905 +        if where: q += ' WHERE ' + where
   1.906 +
   1.907 +        if _test: return q
   1.908 +
   1.909 +        db_cursor = self._db_cursor()
   1.910 +        self._db_execute(db_cursor, q)
   1.911 +        if not self.ctx.transactions: 
   1.912 +            self.ctx.commit()
   1.913 +        return db_cursor.rowcount
   1.914 +
   1.915 +    def _process_insert_query(self, query, tablename, seqname):
   1.916 +        return query
   1.917 +
   1.918 +    def transaction(self): 
   1.919 +        """Start a transaction."""
   1.920 +        return Transaction(self.ctx)
   1.921 +    
   1.922 +class PostgresDB(DB): 
   1.923 +    """Postgres driver."""
   1.924 +    def __init__(self, **keywords):
   1.925 +        if 'pw' in keywords:
   1.926 +            keywords['password'] = keywords.pop('pw')
   1.927 +            
   1.928 +        db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None))
   1.929 +        if db_module.__name__ == "psycopg2":
   1.930 +            import psycopg2.extensions
   1.931 +            psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
   1.932 +
   1.933 +        # if db is not provided postgres driver will take it from PGDATABASE environment variable
   1.934 +        if 'db' in keywords:
   1.935 +            keywords['database'] = keywords.pop('db')
   1.936 +        
   1.937 +        self.dbname = "postgres"
   1.938 +        self.paramstyle = db_module.paramstyle
   1.939 +        DB.__init__(self, db_module, keywords)
   1.940 +        self.supports_multiple_insert = True
   1.941 +        self._sequences = None
   1.942 +        
   1.943 +    def _process_insert_query(self, query, tablename, seqname):
   1.944 +        if seqname is None:
   1.945 +            # when seqname is not provided guess the seqname and make sure it exists
   1.946 +            seqname = tablename + "_id_seq"
   1.947 +            if seqname not in self._get_all_sequences():
   1.948 +                seqname = None
   1.949 +        
   1.950 +        if seqname:
   1.951 +            query += "; SELECT currval('%s')" % seqname
   1.952 +            
   1.953 +        return query
   1.954 +    
   1.955 +    def _get_all_sequences(self):
   1.956 +        """Query postgres to find names of all sequences used in this database."""
   1.957 +        if self._sequences is None:
   1.958 +            q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
   1.959 +            self._sequences = set([c.relname for c in self.query(q)])
   1.960 +        return self._sequences
   1.961 +
   1.962 +    def _connect(self, keywords):
   1.963 +        conn = DB._connect(self, keywords)
   1.964 +        try:
   1.965 +            conn.set_client_encoding('UTF8')
   1.966 +        except AttributeError:
   1.967 +            # fallback for pgdb driver
   1.968 +            conn.cursor().execute("set client_encoding to 'UTF-8'")
   1.969 +        return conn
   1.970 +        
   1.971 +    def _connect_with_pooling(self, keywords):
   1.972 +        conn = DB._connect_with_pooling(self, keywords)
   1.973 +        conn._con._con.set_client_encoding('UTF8')
   1.974 +        return conn
   1.975 +
   1.976 +class MySQLDB(DB): 
   1.977 +    def __init__(self, **keywords):
   1.978 +        import MySQLdb as db
   1.979 +        if 'pw' in keywords:
   1.980 +            keywords['passwd'] = keywords['pw']
   1.981 +            del keywords['pw']
   1.982 +
   1.983 +        if 'charset' not in keywords:
   1.984 +            keywords['charset'] = 'utf8'
   1.985 +        elif keywords['charset'] is None:
   1.986 +            del keywords['charset']
   1.987 +
   1.988 +        self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg
   1.989 +        self.dbname = "mysql"
   1.990 +        DB.__init__(self, db, keywords)
   1.991 +        self.supports_multiple_insert = True
   1.992 +        
   1.993 +    def _process_insert_query(self, query, tablename, seqname):
   1.994 +        return query, SQLQuery('SELECT last_insert_id();')
   1.995 +        
   1.996 +    def _get_insert_default_values_query(self, table):
   1.997 +        return "INSERT INTO %s () VALUES()" % table
   1.998 +
   1.999 +def import_driver(drivers, preferred=None):
  1.1000 +    """Import the first available driver or preferred driver.
  1.1001 +    """
  1.1002 +    if preferred:
  1.1003 +        drivers = [preferred]
  1.1004 +
  1.1005 +    for d in drivers:
  1.1006 +        try:
  1.1007 +            return __import__(d, None, None, ['x'])
  1.1008 +        except ImportError:
  1.1009 +            pass
  1.1010 +    raise ImportError("Unable to import " + " or ".join(drivers))
  1.1011 +
  1.1012 +class SqliteDB(DB): 
  1.1013 +    def __init__(self, **keywords):
  1.1014 +        db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))
  1.1015 +
  1.1016 +        if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
  1.1017 +            db.paramstyle = 'qmark'
  1.1018 +            
  1.1019 +        # sqlite driver doesn't create datatime objects for timestamp columns unless `detect_types` option is passed.
  1.1020 +        # It seems to be supported in sqlite3 and pysqlite2 drivers, not surte about sqlite.
  1.1021 +        keywords.setdefault('detect_types', db.PARSE_DECLTYPES)
  1.1022 +
  1.1023 +        self.paramstyle = db.paramstyle
  1.1024 +        keywords['database'] = keywords.pop('db')
  1.1025 +        keywords['pooling'] = False # sqlite don't allows connections to be shared by threads
  1.1026 +        self.dbname = "sqlite"        
  1.1027 +        DB.__init__(self, db, keywords)
  1.1028 +
  1.1029 +    def _process_insert_query(self, query, tablename, seqname):
  1.1030 +        return query, SQLQuery('SELECT last_insert_rowid();')
  1.1031 +    
  1.1032 +    def query(self, *a, **kw):
  1.1033 +        out = DB.query(self, *a, **kw)
  1.1034 +        if isinstance(out, iterbetter):
  1.1035 +            del out.__len__
  1.1036 +        return out
  1.1037 +
  1.1038 +class FirebirdDB(DB):
  1.1039 +    """Firebird Database.
  1.1040 +    """
  1.1041 +    def __init__(self, **keywords):
  1.1042 +        try:
  1.1043 +            import kinterbasdb as db
  1.1044 +        except Exception:
  1.1045 +            db = None
  1.1046 +            pass
  1.1047 +        if 'pw' in keywords:
  1.1048 +            keywords['passwd'] = keywords['pw']
  1.1049 +            del keywords['pw']
  1.1050 +        keywords['database'] = keywords['db']
  1.1051 +        del keywords['db']
  1.1052 +        DB.__init__(self, db, keywords)
  1.1053 +        
  1.1054 +    def delete(self, table, where=None, using=None, vars=None, _test=False):
  1.1055 +        # firebird doesn't support using clause
  1.1056 +        using=None
  1.1057 +        return DB.delete(self, table, where, using, vars, _test)
  1.1058 +
  1.1059 +    def sql_clauses(self, what, tables, where, group, order, limit, offset):
  1.1060 +        return (
  1.1061 +            ('SELECT', ''),
  1.1062 +            ('FIRST', limit),
  1.1063 +            ('SKIP', offset),
  1.1064 +            ('', what),
  1.1065 +            ('FROM', sqllist(tables)),
  1.1066 +            ('WHERE', where),
  1.1067 +            ('GROUP BY', group),
  1.1068 +            ('ORDER BY', order)
  1.1069 +        )
  1.1070 +
  1.1071 +class MSSQLDB(DB):
  1.1072 +    def __init__(self, **keywords):
  1.1073 +        import pymssql as db    
  1.1074 +        if 'pw' in keywords:
  1.1075 +            keywords['password'] = keywords.pop('pw')
  1.1076 +        keywords['database'] = keywords.pop('db')
  1.1077 +        self.dbname = "mssql"
  1.1078 +        DB.__init__(self, db, keywords)
  1.1079 +
  1.1080 +    def _process_query(self, sql_query):
  1.1081 +        """Takes the SQLQuery object and returns query string and parameters.
  1.1082 +        """
  1.1083 +        # MSSQLDB expects params to be a tuple. 
  1.1084 +        # Overwriting the default implementation to convert params to tuple.
  1.1085 +        paramstyle = getattr(self, 'paramstyle', 'pyformat')
  1.1086 +        query = sql_query.query(paramstyle)
  1.1087 +        params = sql_query.values()
  1.1088 +        return query, tuple(params)
  1.1089 +
  1.1090 +    def sql_clauses(self, what, tables, where, group, order, limit, offset): 
  1.1091 +        return (
  1.1092 +            ('SELECT', what),
  1.1093 +            ('TOP', limit),
  1.1094 +            ('FROM', sqllist(tables)),
  1.1095 +            ('WHERE', where),
  1.1096 +            ('GROUP BY', group),
  1.1097 +            ('ORDER BY', order),
  1.1098 +            ('OFFSET', offset))
  1.1099 +            
  1.1100 +    def _test(self):
  1.1101 +        """Test LIMIT.
  1.1102 +
  1.1103 +            Fake presence of pymssql module for running tests.
  1.1104 +            >>> import sys
  1.1105 +            >>> sys.modules['pymssql'] = sys.modules['sys']
  1.1106 +            
  1.1107 +            MSSQL has TOP clause instead of LIMIT clause.
  1.1108 +            >>> db = MSSQLDB(db='test', user='joe', pw='secret')
  1.1109 +            >>> db.select('foo', limit=4, _test=True)
  1.1110 +            <sql: 'SELECT * TOP 4 FROM foo'>
  1.1111 +        """
  1.1112 +        pass
  1.1113 +
  1.1114 +class OracleDB(DB): 
  1.1115 +    def __init__(self, **keywords): 
  1.1116 +        import cx_Oracle as db 
  1.1117 +        if 'pw' in keywords: 
  1.1118 +            keywords['password'] = keywords.pop('pw') 
  1.1119 +
  1.1120 +        #@@ TODO: use db.makedsn if host, port is specified 
  1.1121 +        keywords['dsn'] = keywords.pop('db') 
  1.1122 +        self.dbname = 'oracle' 
  1.1123 +        db.paramstyle = 'numeric' 
  1.1124 +        self.paramstyle = db.paramstyle
  1.1125 +
  1.1126 +        # oracle doesn't support pooling 
  1.1127 +        keywords.pop('pooling', None) 
  1.1128 +        DB.__init__(self, db, keywords) 
  1.1129 +
  1.1130 +    def _process_insert_query(self, query, tablename, seqname): 
  1.1131 +        if seqname is None: 
  1.1132 +            # It is not possible to get seq name from table name in Oracle
  1.1133 +            return query
  1.1134 +        else:
  1.1135 +            return query + "; SELECT %s.currval FROM dual" % seqname 
  1.1136 +
  1.1137 +_databases = {}
  1.1138 +def database(dburl=None, **params):
  1.1139 +    """Creates appropriate database using params.
  1.1140 +    
  1.1141 +    Pooling will be enabled if DBUtils module is available. 
  1.1142 +    Pooling can be disabled by passing pooling=False in params.
  1.1143 +    """
  1.1144 +    dbn = params.pop('dbn')
  1.1145 +    if dbn in _databases:
  1.1146 +        return _databases[dbn](**params)
  1.1147 +    else:
  1.1148 +        raise UnknownDB, dbn
  1.1149 +
  1.1150 +def register_database(name, clazz):
  1.1151 +    """
  1.1152 +    Register a database.
  1.1153 +
  1.1154 +        >>> class LegacyDB(DB): 
  1.1155 +        ...     def __init__(self, **params): 
  1.1156 +        ...        pass 
  1.1157 +        ...
  1.1158 +        >>> register_database('legacy', LegacyDB)
  1.1159 +        >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret') 
  1.1160 +    """
  1.1161 +    _databases[name] = clazz
  1.1162 +
  1.1163 +register_database('mysql', MySQLDB)
  1.1164 +register_database('postgres', PostgresDB)
  1.1165 +register_database('sqlite', SqliteDB)
  1.1166 +register_database('firebird', FirebirdDB)
  1.1167 +register_database('mssql', MSSQLDB)
  1.1168 +register_database('oracle', OracleDB)
  1.1169 +
  1.1170 +def _interpolate(format): 
  1.1171 +    """
  1.1172 +    Takes a format string and returns a list of 2-tuples of the form
  1.1173 +    (boolean, string) where boolean says whether string should be evaled
  1.1174 +    or not.
  1.1175 +
  1.1176 +    from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
  1.1177 +    """
  1.1178 +    from tokenize import tokenprog
  1.1179 +
  1.1180 +    def matchorfail(text, pos):
  1.1181 +        match = tokenprog.match(text, pos)
  1.1182 +        if match is None:
  1.1183 +            raise _ItplError(text, pos)
  1.1184 +        return match, match.end()
  1.1185 +
  1.1186 +    namechars = "abcdefghijklmnopqrstuvwxyz" \
  1.1187 +        "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
  1.1188 +    chunks = []
  1.1189 +    pos = 0
  1.1190 +
  1.1191 +    while 1:
  1.1192 +        dollar = format.find("$", pos)
  1.1193 +        if dollar < 0: 
  1.1194 +            break
  1.1195 +        nextchar = format[dollar + 1]
  1.1196 +
  1.1197 +        if nextchar == "{":
  1.1198 +            chunks.append((0, format[pos:dollar]))
  1.1199 +            pos, level = dollar + 2, 1
  1.1200 +            while level:
  1.1201 +                match, pos = matchorfail(format, pos)
  1.1202 +                tstart, tend = match.regs[3]
  1.1203 +                token = format[tstart:tend]
  1.1204 +                if token == "{": 
  1.1205 +                    level = level + 1
  1.1206 +                elif token == "}":  
  1.1207 +                    level = level - 1
  1.1208 +            chunks.append((1, format[dollar + 2:pos - 1]))
  1.1209 +
  1.1210 +        elif nextchar in namechars:
  1.1211 +            chunks.append((0, format[pos:dollar]))
  1.1212 +            match, pos = matchorfail(format, dollar + 1)
  1.1213 +            while pos < len(format):
  1.1214 +                if format[pos] == "." and \
  1.1215 +                    pos + 1 < len(format) and format[pos + 1] in namechars:
  1.1216 +                    match, pos = matchorfail(format, pos + 1)
  1.1217 +                elif format[pos] in "([":
  1.1218 +                    pos, level = pos + 1, 1
  1.1219 +                    while level:
  1.1220 +                        match, pos = matchorfail(format, pos)
  1.1221 +                        tstart, tend = match.regs[3]
  1.1222 +                        token = format[tstart:tend]
  1.1223 +                        if token[0] in "([": 
  1.1224 +                            level = level + 1
  1.1225 +                        elif token[0] in ")]":  
  1.1226 +                            level = level - 1
  1.1227 +                else: 
  1.1228 +                    break
  1.1229 +            chunks.append((1, format[dollar + 1:pos]))
  1.1230 +        else:
  1.1231 +            chunks.append((0, format[pos:dollar + 1]))
  1.1232 +            pos = dollar + 1 + (nextchar == "$")
  1.1233 +
  1.1234 +    if pos < len(format): 
  1.1235 +        chunks.append((0, format[pos:]))
  1.1236 +    return chunks
  1.1237 +
  1.1238 +if __name__ == "__main__":
  1.1239 +    import doctest
  1.1240 +    doctest.testmod()