7 "UnknownParamstyle", "UnknownDB", "TransactionError",
8 "sqllist", "sqlors", "reparam", "sqlquote",
9 "SQLQuery", "SQLParam", "sqlparam",
10 "SQLLiteral", "sqlliteral",
22 from sets import Set as set
24 from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
27 # db module can work independent of web.py
28 from webapi import debug, config
class UnknownDB(Exception):
    """Signals that `database()` was asked for a dbms it does not support.

    The offending `dbn` value is carried as the exception argument.
    """
38 class _ItplError(ValueError):
39 def __init__(self, text, pos):
40 ValueError.__init__(self)
44 return "unfinished expression in %s at char %d" % (
45 repr(self.text), self.pos)
class TransactionError(Exception):
    """Exception type for transaction-related database errors.

    Exported alongside `UnknownDB` and `UnknownParamstyle`.
    """
49 class UnknownParamstyle(Exception):
51 raised for unsupported db paramstyles
53 (currently supported: qmark, numeric, format, pyformat)
57 class SQLParam(object):
59 Parameter in SQLQuery.
61 >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
63 <sql: "SELECT * FROM test WHERE name='joe'">
65 'SELECT * FROM test WHERE name=%s'
71 def __init__(self, value):
74 def get_marker(self, paramstyle='pyformat'):
75 if paramstyle == 'qmark':
77 elif paramstyle == 'numeric':
79 elif paramstyle is None or paramstyle in ['format', 'pyformat']:
81 raise UnknownParamstyle, paramstyle
84 return SQLQuery([self])
86 def __add__(self, other):
87 return self.sqlquery() + other
89 def __radd__(self, other):
90 return other + self.sqlquery()
93 return str(self.value)
96 return '<param: %s>' % repr(self.value)
100 class SQLQuery(object):
102 You can pass this sort of thing as a clause in any db function.
103 Otherwise, you can pass a dictionary to the keyword argument `vars`
104 and the function will call reparam for you.
106 Internally, consists of `items`, which is a list of strings and
107 SQLParams, which get concatenated to produce the actual query.
109 __slots__ = ["items"]
111 # tested in sqlquote's docstring
112 def __init__(self, items=None):
113 r"""Creates a new SQLQuery.
117 >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
119 <sql: 'SELECT * FROM test WHERE x=1'>
120 >>> q.query(), q.values()
121 ('SELECT * FROM test WHERE x=%s', [1])
122 >>> SQLQuery(SQLParam(1))
127 elif isinstance(items, list):
129 elif isinstance(items, SQLParam):
131 elif isinstance(items, SQLQuery):
132 self.items = list(items.items)
136 # Take care of SQLLiterals
137 for i, item in enumerate(self.items):
138 if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral):
139 self.items[i] = item.value.v
141 def append(self, value):
142 self.items.append(value)
144 def __add__(self, other):
145 if isinstance(other, basestring):
147 elif isinstance(other, SQLQuery):
150 return NotImplemented
151 return SQLQuery(self.items + items)
153 def __radd__(self, other):
154 if isinstance(other, basestring):
157 return NotImplemented
159 return SQLQuery(items + self.items)
161 def __iadd__(self, other):
162 if isinstance(other, (basestring, SQLParam)):
163 self.items.append(other)
164 elif isinstance(other, SQLQuery):
165 self.items.extend(other.items)
167 return NotImplemented
171 return len(self.query())
173 def query(self, paramstyle=None):
175 Returns the query part of the sql query.
176 >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
178 'SELECT * FROM test WHERE name=%s'
179 >>> q.query(paramstyle='qmark')
180 'SELECT * FROM test WHERE name=?'
184 if isinstance(x, SQLParam):
185 x = x.get_marker(paramstyle)
189 # automatically escape % characters in the query
190 # For backward compatability, ignore escaping when the query looks already escaped
191 if paramstyle in ['format', 'pyformat']:
192 if '%' in x and '%%' not in x:
193 x = x.replace('%', '%%')
199 Returns the values of the parameters used in the sql query.
200 >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
204 return [i.value for i in self.items if isinstance(i, SQLParam)]
206 def join(items, sep=' ', prefix=None, suffix=None, target=None):
208 Joins multiple queries.
210 >>> SQLQuery.join(['a', 'b'], ', ')
Optionally, prefix and suffix arguments can be provided.
215 >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
218 If target argument is provided, the items are appended to target instead of creating a new SQLQuery.
223 target_items = target.items
226 target_items.append(prefix)
228 for i, item in enumerate(items):
230 target_items.append(sep)
231 if isinstance(item, SQLQuery):
232 target_items.extend(item.items)
234 target_items.append(item)
237 target_items.append(suffix)
240 join = staticmethod(join)
244 return self.query() % tuple([sqlify(x) for x in self.values()])
245 except (ValueError, TypeError):
249 return safestr(self._str())
251 def __unicode__(self):
252 return safeunicode(self._str())
255 return '<sql: %s>' % repr(str(self))
259 Protects a string from `sqlquote`.
261 >>> sqlquote('NOW()')
263 >>> sqlquote(SQLLiteral('NOW()'))
266 def __init__(self, v):
# Lowercase alias of SQLLiteral; both spellings are part of the exported names.
sqlliteral = SQLLiteral
274 def _sqllist(values):
276 >>> _sqllist([1, 2, 3])
281 for i, v in enumerate(values):
284 items.append(sqlparam(v))
286 return SQLQuery(items)
288 def reparam(string_, dictionary):
290 Takes a string and a dictionary and interpolates the string
291 using values from the dictionary. Returns an `SQLQuery` for the result.
293 >>> reparam("s = $s", dict(s=True))
295 >>> reparam("s IN $s", dict(s=[1, 2]))
298 dictionary = dictionary.copy() # eval mucks with it
301 for live, chunk in _interpolate(string_):
303 v = eval(chunk, dictionary)
304 result.append(sqlquote(v))
307 return SQLQuery.join(result, '')
311 converts `obj` to its proper SQL version
320 # because `1 == True and hash(1) == hash(True)`
321 # we have to do this the hard way...
329 elif datetime and isinstance(obj, datetime.datetime):
330 return repr(obj.isoformat())
332 if isinstance(obj, unicode): obj = obj.encode('utf8')
337 Converts the arguments for use in something like a WHERE clause.
339 >>> sqllist(['a', 'b'])
346 if isinstance(lst, basestring):
349 return ', '.join(lst)
351 def sqlors(left, lst):
`left` is a SQL clause like `tablename.arg = `
354 and `lst` is a list of values. Returns a reparam-style
355 pair featuring the SQL that ORs together the clause
356 for each item in the lst.
358 >>> sqlors('foo = ', [])
360 >>> sqlors('foo = ', [1])
362 >>> sqlors('foo = ', 1)
364 >>> sqlors('foo = ', [1,2,3])
365 <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
367 if isinstance(lst, iters):
371 return SQLQuery("1=2")
375 if isinstance(lst, iters):
376 return SQLQuery(['('] +
377 sum([[left, sqlparam(x), ' OR '] for x in lst], []) +
381 return left + sqlparam(lst)
383 def sqlwhere(dictionary, grouping=' AND '):
385 Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.
387 >>> sqlwhere({'cust_id': 2, 'order_id':3})
388 <sql: 'order_id = 3 AND cust_id = 2'>
389 >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
390 <sql: 'order_id = 3, cust_id = 2'>
391 >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
394 return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping)
398 Ensures `a` is quoted properly for use in a SQL query.
400 >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
401 <sql: "WHERE x = 't' AND y = 3">
402 >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
403 <sql: "WHERE x = 't' AND y IN (2, 3)">
405 if isinstance(a, list):
408 return sqlparam(a).sqlquery()
411 """Database transaction."""
412 def __init__(self, ctx):
414 self.transaction_count = transaction_count = len(ctx.transactions)
416 class transaction_engine:
417 """Transaction Engine used in top level transactions."""
418 def do_transact(self):
419 ctx.commit(unload=False)
424 def do_rollback(self):
427 class subtransaction_engine:
428 """Transaction Engine used in sub transactions."""
430 db_cursor = ctx.db.cursor()
431 ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))
433 def do_transact(self):
434 self.query('SAVEPOINT webpy_sp_%s')
437 self.query('RELEASE SAVEPOINT webpy_sp_%s')
439 def do_rollback(self):
440 self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')
443 """Transaction Engine used instead of subtransaction_engine
444 when sub transactions are not supported."""
445 do_transact = do_commit = do_rollback = lambda self: None
447 if self.transaction_count:
448 # nested transactions are not supported in some databases
449 if self.ctx.get('ignore_nested_transactions'):
450 self.engine = dummy_engine()
452 self.engine = subtransaction_engine()
454 self.engine = transaction_engine()
456 self.engine.do_transact()
457 self.ctx.transactions.append(self)
462 def __exit__(self, exctype, excvalue, traceback):
463 if exctype is not None:
469 if len(self.ctx.transactions) > self.transaction_count:
470 self.engine.do_commit()
471 self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
474 if len(self.ctx.transactions) > self.transaction_count:
475 self.engine.do_rollback()
476 self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
480 def __init__(self, db_module, keywords):
481 """Creates a database.
# some DB implementations take an optional parameter `driver` to select a specific driver module
484 # but it should not be passed to connect
485 keywords.pop('driver', None)
487 self.db_module = db_module
488 self.keywords = keywords
490 self._ctx = threadeddict()
491 # flag to enable/disable printing queries
492 self.printing = config.get('debug_sql', config.get('debug', False))
493 self.supports_multiple_insert = False
497 # enable pooling if DBUtils module is available.
498 self.has_pooling = True
500 self.has_pooling = False
502 # Pooling can be disabled by passing pooling=False in the keywords.
503 self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling
506 if not self._ctx.get('db'):
507 self._load_context(self._ctx)
509 ctx = property(_getctx)
511 def _load_context(self, ctx):
513 ctx.transactions = [] # stack of transactions
516 ctx.db = self._connect_with_pooling(self.keywords)
518 ctx.db = self._connect(self.keywords)
519 ctx.db_execute = self._db_execute
521 if not hasattr(ctx.db, 'commit'):
522 ctx.db.commit = lambda: None
524 if not hasattr(ctx.db, 'rollback'):
525 ctx.db.rollback = lambda: None
527 def commit(unload=True):
528 # do db commit and release the connection if pooling is enabled.
530 if unload and self.has_pooling:
531 self._unload_context(self._ctx)
534 # do db rollback and release the connection if pooling is enabled.
537 self._unload_context(self._ctx)
540 ctx.rollback = rollback
542 def _unload_context(self, ctx):
545 def _connect(self, keywords):
546 return self.db_module.connect(**keywords)
548 def _connect_with_pooling(self, keywords):
550 from DBUtils import PooledDB
552 # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
555 if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
556 return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
558 return PooledDB.PooledDB(creator=self.db_module, **keywords)
560 if getattr(self, '_pooleddb', None) is None:
561 self._pooleddb = get_pooled_db()
563 return self._pooleddb.connection()
565 def _db_cursor(self):
566 return self.ctx.db.cursor()
568 def _param_marker(self):
569 """Returns parameter marker based on paramstyle attribute if this database."""
570 style = getattr(self, 'paramstyle', 'pyformat')
574 elif style == 'numeric':
576 elif style in ['format', 'pyformat']:
578 raise UnknownParamstyle, style
580 def _db_execute(self, cur, sql_query):
581 """executes an sql query"""
582 self.ctx.dbq_count += 1
586 query, params = self._process_query(sql_query)
587 out = cur.execute(query, params)
591 print >> debug, 'ERR:', str(sql_query)
592 if self.ctx.transactions:
593 self.ctx.transactions[-1].rollback()
599 print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
602 def _process_query(self, sql_query):
603 """Takes the SQLQuery object and returns query string and parameters.
605 paramstyle = getattr(self, 'paramstyle', 'pyformat')
606 query = sql_query.query(paramstyle)
607 params = sql_query.values()
610 def _where(self, where, vars):
611 if isinstance(where, (int, long)):
612 where = "id = " + sqlparam(where)
613 #@@@ for backward-compatibility
614 elif isinstance(where, (list, tuple)) and len(where) == 2:
615 where = SQLQuery(where[0], where[1])
616 elif isinstance(where, SQLQuery):
619 where = reparam(where, vars)
622 def query(self, sql_query, vars=None, processed=False, _test=False):
624 Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
625 If `processed=True`, `vars` is a `reparam`-style list to use
626 instead of interpolating.
628 >>> db = DB(None, {})
629 >>> db.query("SELECT * FROM foo", _test=True)
630 <sql: 'SELECT * FROM foo'>
631 >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
632 <sql: "SELECT * FROM foo WHERE x = 'f'">
633 >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
634 <sql: "SELECT * FROM foo WHERE x = 'f'">
636 if vars is None: vars = {}
638 if not processed and not isinstance(sql_query, SQLQuery):
639 sql_query = reparam(sql_query, vars)
641 if _test: return sql_query
643 db_cursor = self._db_cursor()
644 self._db_execute(db_cursor, sql_query)
646 if db_cursor.description:
647 names = [x[0] for x in db_cursor.description]
649 row = db_cursor.fetchone()
651 yield storage(dict(zip(names, row)))
652 row = db_cursor.fetchone()
653 out = iterbetter(iterwrapper())
654 out.__len__ = lambda: int(db_cursor.rowcount)
655 out.list = lambda: [storage(dict(zip(names, x))) \
656 for x in db_cursor.fetchall()]
658 out = db_cursor.rowcount
660 if not self.ctx.transactions:
664 def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
665 limit=None, offset=None, _test=False):
667 Selects `what` from `tables` with clauses `where`, `order`,
668 `group`, `limit`, and `offset`. Uses vars to interpolate.
669 Otherwise, each clause can be a SQLQuery.
671 >>> db = DB(None, {})
672 >>> db.select('foo', _test=True)
673 <sql: 'SELECT * FROM foo'>
674 >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
675 <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
677 if vars is None: vars = {}
678 sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset)
679 clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None]
680 qout = SQLQuery.join(clauses)
681 if _test: return qout
682 return self.query(qout, processed=True)
684 def where(self, table, what='*', order=None, group=None, limit=None,
685 offset=None, _test=False, **kwargs):
687 Selects from `table` where keys are equal to values in `kwargs`.
689 >>> db = DB(None, {})
690 >>> db.where('foo', bar_id=3, _test=True)
691 <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
692 >>> db.where('foo', source=2, crust='dewey', _test=True)
693 <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
694 >>> db.where('foo', _test=True)
695 <sql: 'SELECT * FROM foo'>
698 for k, v in kwargs.iteritems():
699 where_clauses.append(k + ' = ' + sqlquote(v))
702 where = SQLQuery.join(where_clauses, " AND ")
706 return self.select(table, what=what, order=order,
707 group=group, limit=limit, offset=offset, _test=_test,
710 def sql_clauses(self, what, tables, where, group, order, limit, offset):
713 ('FROM', sqllist(tables)),
720 def gen_clause(self, sql, val, vars):
721 if isinstance(val, (int, long)):
723 nout = 'id = ' + sqlquote(val)
727 elif isinstance(val, (list, tuple)) and len(val) == 2:
728 nout = SQLQuery(val[0], val[1]) # backwards-compatibility
729 elif isinstance(val, SQLQuery):
732 nout = reparam(val, vars)
735 if a and b: return a + ' ' + b
738 return xjoin(sql, nout)
740 def insert(self, tablename, seqname=None, _test=False, **values):
742 Inserts `values` into `tablename`. Returns current sequence ID.
743 Set `seqname` to the ID if it's not the default, or to `False`
746 >>> db = DB(None, {})
747 >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
749 <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
751 'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
755 def q(x): return "(" + x + ")"
758 _keys = SQLQuery.join(values.keys(), ', ')
759 _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
760 sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
762 sql_query = SQLQuery(self._get_insert_default_values_query(tablename))
764 if _test: return sql_query
766 db_cursor = self._db_cursor()
767 if seqname is not False:
768 sql_query = self._process_insert_query(sql_query, tablename, seqname)
770 if isinstance(sql_query, tuple):
771 # for some databases, a separate query has to be made to find
772 # the id of the inserted row.
774 self._db_execute(db_cursor, q1)
775 self._db_execute(db_cursor, q2)
777 self._db_execute(db_cursor, sql_query)
780 out = db_cursor.fetchone()[0]
784 if not self.ctx.transactions:
788 def _get_insert_default_values_query(self, table):
789 return "INSERT INTO %s DEFAULT VALUES" % table
791 def multiple_insert(self, tablename, values, seqname=None, _test=False):
Inserts multiple rows into `tablename`. The `values` must be a list of dictionaries,
794 one for each row to be inserted, each with the same set of keys.
795 Returns the list of ids of the inserted rows.
796 Set `seqname` to the ID if it's not the default, or to `False`
799 >>> db = DB(None, {})
800 >>> db.supports_multiple_insert = True
801 >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
802 >>> db.multiple_insert('person', values=values, _test=True)
803 <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
808 if not self.supports_multiple_insert:
809 out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
815 keys = values[0].keys()
816 #@@ make sure all keys are valid
818 # make sure all rows have same keys.
821 raise ValueError, 'Bad data'
823 sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))
825 for i, row in enumerate(values):
827 sql_query.append(", ")
828 SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")
830 if _test: return sql_query
832 db_cursor = self._db_cursor()
833 if seqname is not False:
834 sql_query = self._process_insert_query(sql_query, tablename, seqname)
836 if isinstance(sql_query, tuple):
837 # for some databases, a separate query has to be made to find
838 # the id of the inserted row.
840 self._db_execute(db_cursor, q1)
841 self._db_execute(db_cursor, q2)
843 self._db_execute(db_cursor, sql_query)
846 out = db_cursor.fetchone()[0]
847 out = range(out-len(values)+1, out+1)
851 if not self.ctx.transactions:
856 def update(self, tables, where, vars=None, _test=False, **values):
858 Update `tables` with clause `where` (interpolated using `vars`)
859 and setting `values`.
861 >>> db = DB(None, {})
863 >>> q = db.update('foo', where='name = $name', name='bob', age=2,
864 ... created=SQLLiteral('NOW()'), vars=locals(), _test=True)
866 <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
868 'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
872 if vars is None: vars = {}
873 where = self._where(where, vars)
876 "UPDATE " + sqllist(tables) +
877 " SET " + sqlwhere(values, ', ') +
880 if _test: return query
882 db_cursor = self._db_cursor()
883 self._db_execute(db_cursor, query)
884 if not self.ctx.transactions:
886 return db_cursor.rowcount
888 def delete(self, table, where, using=None, vars=None, _test=False):
890 Deletes from `table` with clauses `where` and `using`.
892 >>> db = DB(None, {})
894 >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
895 <sql: "DELETE FROM foo WHERE name = 'Joe'">
897 if vars is None: vars = {}
898 where = self._where(where, vars)
900 q = 'DELETE FROM ' + table
901 if using: q += ' USING ' + sqllist(using)
902 if where: q += ' WHERE ' + where
906 db_cursor = self._db_cursor()
907 self._db_execute(db_cursor, q)
908 if not self.ctx.transactions:
910 return db_cursor.rowcount
912 def _process_insert_query(self, query, tablename, seqname):
915 def transaction(self):
916 """Start a transaction."""
917 return Transaction(self.ctx)
919 class PostgresDB(DB):
920 """Postgres driver."""
921 def __init__(self, **keywords):
923 keywords['password'] = keywords.pop('pw')
925 db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None))
926 if db_module.__name__ == "psycopg2":
927 import psycopg2.extensions
928 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
930 # if db is not provided postgres driver will take it from PGDATABASE environment variable
932 keywords['database'] = keywords.pop('db')
934 self.dbname = "postgres"
935 self.paramstyle = db_module.paramstyle
936 DB.__init__(self, db_module, keywords)
937 self.supports_multiple_insert = True
938 self._sequences = None
940 def _process_insert_query(self, query, tablename, seqname):
942 # when seqname is not provided guess the seqname and make sure it exists
943 seqname = tablename + "_id_seq"
944 if seqname not in self._get_all_sequences():
948 query += "; SELECT currval('%s')" % seqname
952 def _get_all_sequences(self):
953 """Query postgres to find names of all sequences used in this database."""
954 if self._sequences is None:
955 q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
956 self._sequences = set([c.relname for c in self.query(q)])
957 return self._sequences
959 def _connect(self, keywords):
960 conn = DB._connect(self, keywords)
962 conn.set_client_encoding('UTF8')
963 except AttributeError:
964 # fallback for pgdb driver
965 conn.cursor().execute("set client_encoding to 'UTF-8'")
968 def _connect_with_pooling(self, keywords):
969 conn = DB._connect_with_pooling(self, keywords)
970 conn._con._con.set_client_encoding('UTF8')
974 def __init__(self, **keywords):
977 keywords['passwd'] = keywords['pw']
980 if 'charset' not in keywords:
981 keywords['charset'] = 'utf8'
982 elif keywords['charset'] is None:
983 del keywords['charset']
985 self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg
986 self.dbname = "mysql"
987 DB.__init__(self, db, keywords)
988 self.supports_multiple_insert = True
990 def _process_insert_query(self, query, tablename, seqname):
991 return query, SQLQuery('SELECT last_insert_id();')
993 def _get_insert_default_values_query(self, table):
994 return "INSERT INTO %s () VALUES()" % table
996 def import_driver(drivers, preferred=None):
997 """Import the first available driver or preferred driver.
1000 drivers = [preferred]
1004 return __import__(d, None, None, ['x'])
1007 raise ImportError("Unable to import " + " or ".join(drivers))
1010 def __init__(self, **keywords):
1011 db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))
1013 if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
1014 db.paramstyle = 'qmark'
# The sqlite driver doesn't create datetime objects for timestamp columns unless the `detect_types` option is passed.
# It seems to be supported by the sqlite3 and pysqlite2 drivers; not sure about sqlite.
1018 keywords.setdefault('detect_types', db.PARSE_DECLTYPES)
1020 self.paramstyle = db.paramstyle
1021 keywords['database'] = keywords.pop('db')
1022 keywords['pooling'] = False # sqlite don't allows connections to be shared by threads
1023 self.dbname = "sqlite"
1024 DB.__init__(self, db, keywords)
1026 def _process_insert_query(self, query, tablename, seqname):
1027 return query, SQLQuery('SELECT last_insert_rowid();')
1029 def query(self, *a, **kw):
1030 out = DB.query(self, *a, **kw)
1031 if isinstance(out, iterbetter):
1035 class FirebirdDB(DB):
1036 """Firebird Database.
1038 def __init__(self, **keywords):
1040 import kinterbasdb as db
1044 if 'pw' in keywords:
1045 keywords['passwd'] = keywords['pw']
1047 keywords['database'] = keywords['db']
1049 DB.__init__(self, db, keywords)
1051 def delete(self, table, where=None, using=None, vars=None, _test=False):
1052 # firebird doesn't support using clause
1054 return DB.delete(self, table, where, using, vars, _test)
1056 def sql_clauses(self, what, tables, where, group, order, limit, offset):
1062 ('FROM', sqllist(tables)),
1064 ('GROUP BY', group),
1069 def __init__(self, **keywords):
1070 import pymssql as db
1071 if 'pw' in keywords:
1072 keywords['password'] = keywords.pop('pw')
1073 keywords['database'] = keywords.pop('db')
1074 self.dbname = "mssql"
1075 DB.__init__(self, db, keywords)
1077 def _process_query(self, sql_query):
1078 """Takes the SQLQuery object and returns query string and parameters.
1080 # MSSQLDB expects params to be a tuple.
1081 # Overwriting the default implementation to convert params to tuple.
1082 paramstyle = getattr(self, 'paramstyle', 'pyformat')
1083 query = sql_query.query(paramstyle)
1084 params = sql_query.values()
1085 return query, tuple(params)
1087 def sql_clauses(self, what, tables, where, group, order, limit, offset):
1091 ('FROM', sqllist(tables)),
1093 ('GROUP BY', group),
1094 ('ORDER BY', order),
1100 Fake presence of pymssql module for running tests.
1102 >>> sys.modules['pymssql'] = sys.modules['sys']
1104 MSSQL has TOP clause instead of LIMIT clause.
1105 >>> db = MSSQLDB(db='test', user='joe', pw='secret')
1106 >>> db.select('foo', limit=4, _test=True)
1107 <sql: 'SELECT * TOP 4 FROM foo'>
1112 def __init__(self, **keywords):
1113 import cx_Oracle as db
1114 if 'pw' in keywords:
1115 keywords['password'] = keywords.pop('pw')
1117 #@@ TODO: use db.makedsn if host, port is specified
1118 keywords['dsn'] = keywords.pop('db')
1119 self.dbname = 'oracle'
1120 db.paramstyle = 'numeric'
1121 self.paramstyle = db.paramstyle
1123 # oracle doesn't support pooling
1124 keywords.pop('pooling', None)
1125 DB.__init__(self, db, keywords)
1127 def _process_insert_query(self, query, tablename, seqname):
1129 # It is not possible to get seq name from table name in Oracle
1132 return query + "; SELECT %s.currval FROM dual" % seqname
1135 def database(dburl=None, **params):
1136 """Creates appropriate database using params.
1138 Pooling will be enabled if DBUtils module is available.
1139 Pooling can be disabled by passing pooling=False in params.
1141 dbn = params.pop('dbn')
1142 if dbn in _databases:
1143 return _databases[dbn](**params)
1145 raise UnknownDB, dbn
1147 def register_database(name, clazz):
1149 Register a database.
1151 >>> class LegacyDB(DB):
1152 ... def __init__(self, **params):
1155 >>> register_database('legacy', LegacyDB)
1156 >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret')
1158 _databases[name] = clazz
# Register the built-in database classes under the `dbn` names that
# `database()` looks up in `_databases`.  Order matches the original calls.
for _dbn, _db_class in (
    ('mysql', MySQLDB),
    ('postgres', PostgresDB),
    ('sqlite', SqliteDB),
    ('firebird', FirebirdDB),
    ('mssql', MSSQLDB),
    ('oracle', OracleDB),
):
    register_database(_dbn, _db_class)
# Don't leak the loop variables as module-level names.
del _dbn, _db_class
1167 def _interpolate(format):
1169 Takes a format string and returns a list of 2-tuples of the form
1170 (boolean, string) where boolean says whether string should be evaled
1173 from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
1175 from tokenize import tokenprog
1177 def matchorfail(text, pos):
1178 match = tokenprog.match(text, pos)
1180 raise _ItplError(text, pos)
1181 return match, match.end()
1183 namechars = "abcdefghijklmnopqrstuvwxyz" \
1184 "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
1189 dollar = format.find("$", pos)
1192 nextchar = format[dollar + 1]
1195 chunks.append((0, format[pos:dollar]))
1196 pos, level = dollar + 2, 1
1198 match, pos = matchorfail(format, pos)
1199 tstart, tend = match.regs[3]
1200 token = format[tstart:tend]
1205 chunks.append((1, format[dollar + 2:pos - 1]))
1207 elif nextchar in namechars:
1208 chunks.append((0, format[pos:dollar]))
1209 match, pos = matchorfail(format, dollar + 1)
1210 while pos < len(format):
1211 if format[pos] == "." and \
1212 pos + 1 < len(format) and format[pos + 1] in namechars:
1213 match, pos = matchorfail(format, pos + 1)
1214 elif format[pos] in "([":
1215 pos, level = pos + 1, 1
1217 match, pos = matchorfail(format, pos)
1218 tstart, tend = match.regs[3]
1219 token = format[tstart:tend]
1220 if token[0] in "([":
1222 elif token[0] in ")]":
1226 chunks.append((1, format[dollar + 1:pos]))
1228 chunks.append((0, format[pos:dollar + 1]))
1229 pos = dollar + 1 + (nextchar == "$")
1231 if pos < len(format):
1232 chunks.append((0, format[pos:]))
1235 if __name__ == "__main__":