om@3
|
1 |
"""
|
om@3
|
2 |
Database API
|
om@3
|
3 |
(part of web.py)
|
om@3
|
4 |
"""
|
om@3
|
5 |
|
om@3
|
6 |
__all__ = [
|
om@3
|
7 |
"UnknownParamstyle", "UnknownDB", "TransactionError",
|
om@3
|
8 |
"sqllist", "sqlors", "reparam", "sqlquote",
|
om@3
|
9 |
"SQLQuery", "SQLParam", "sqlparam",
|
om@3
|
10 |
"SQLLiteral", "sqlliteral",
|
om@3
|
11 |
"database", 'DB',
|
om@3
|
12 |
]
|
om@3
|
13 |
|
om@3
|
14 |
import time
|
om@3
|
15 |
try:
|
om@3
|
16 |
import datetime
|
om@3
|
17 |
except ImportError:
|
om@3
|
18 |
datetime = None
|
om@3
|
19 |
|
om@3
|
20 |
try: set
|
om@3
|
21 |
except NameError:
|
om@3
|
22 |
from sets import Set as set
|
om@3
|
23 |
|
om@3
|
24 |
from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
|
om@3
|
25 |
|
om@3
|
26 |
try:
|
om@3
|
27 |
# db module can work independent of web.py
|
om@3
|
28 |
from webapi import debug, config
|
om@3
|
29 |
except:
|
om@3
|
30 |
import sys
|
om@3
|
31 |
debug = sys.stderr
|
om@3
|
32 |
config = storage()
|
om@3
|
33 |
|
om@3
|
34 |
class UnknownDB(Exception):
    """Raised when the requested database backend (dbms) is not supported."""
|
om@3
|
37 |
|
om@3
|
38 |
class _ItplError(ValueError):
    """Signals an unfinished expression in `text` at character offset `pos`."""

    def __init__(self, text, pos):
        # Deliberately initialised with no args; the message comes from __str__.
        ValueError.__init__(self)
        self.text, self.pos = text, pos

    def __str__(self):
        quoted = repr(self.text)
        return "unfinished expression in %s at char %d" % (quoted, self.pos)
|
om@3
|
46 |
|
om@3
|
47 |
class TransactionError(Exception):
    """Transaction-related error type, exported through `__all__`."""
|
om@3
|
48 |
|
om@3
|
49 |
class UnknownParamstyle(Exception):
    """
    Raised when a db paramstyle is not one this module can render.

    Supported styles: qmark, numeric, format, pyformat.
    """
|
om@3
|
56 |
|
om@3
|
57 |
class SQLParam(object):
    """
    A single value passed to the driver as a query parameter of a SQLQuery.

    >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
    >>> q
    <sql: "SELECT * FROM test WHERE name='joe'">
    >>> q.query()
    'SELECT * FROM test WHERE name=%s'
    >>> q.values()
    ['joe']
    """
    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def get_marker(self, paramstyle='pyformat'):
        """Return the placeholder text for this parameter under `paramstyle`.

        None is treated like 'pyformat'.  Raises UnknownParamstyle for any
        style other than qmark, numeric, format or pyformat.
        """
        if paramstyle == 'qmark':
            return '?'
        if paramstyle == 'numeric':
            # NOTE(review): every parameter gets ':1'; presumably relies on the
            # driver binding numeric markers positionally -- confirm.
            return ':1'
        if paramstyle is None or paramstyle in ('format', 'pyformat'):
            return '%s'
        raise UnknownParamstyle(paramstyle)

    def sqlquery(self):
        """Wrap this parameter in a one-item SQLQuery."""
        return SQLQuery([self])

    def __add__(self, other):
        as_query = self.sqlquery()
        return as_query + other

    def __radd__(self, other):
        as_query = self.sqlquery()
        return other + as_query

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return '<param: %r>' % (self.value,)

sqlparam = SQLParam
|
om@3
|
99 |
|
om@3
|
100 |
class SQLQuery(object):
    """
    A SQL query built from text fragments and `SQLParam` values.

    Any db function accepts one of these as a clause.  Alternatively, pass a
    dictionary through the `vars` keyword argument and the function will run
    `reparam` on the query string for you.

    `items` holds plain strings and SQLParams; joining them, with each
    SQLParam replaced by its parameter marker, yields the query text.
    """
    __slots__ = ["items"]

    # tested in sqlquote's docstring
    def __init__(self, items=None):
        r"""Creates a new SQLQuery.

        A list is used as-is (not copied); an SQLQuery's items are copied;
        any other single object becomes a one-item query.

        >>> SQLQuery("x")
        <sql: 'x'>
        >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
        >>> q
        <sql: 'SELECT * FROM test WHERE x=1'>
        >>> q.query(), q.values()
        ('SELECT * FROM test WHERE x=%s', [1])
        >>> SQLQuery(SQLParam(1))
        <sql: '1'>
        """
        if items is None:
            items = []
        elif isinstance(items, SQLQuery):
            items = list(items.items)
        elif not isinstance(items, list):
            items = [items]
        self.items = items

        # SQLParams wrapping an SQLLiteral are inlined as raw SQL text.  The
        # list is rewritten in place, matching the aliasing of a passed-in list.
        self.items[:] = [
            entry.value.v
            if isinstance(entry, SQLParam) and isinstance(entry.value, SQLLiteral)
            else entry
            for entry in self.items
        ]

    def append(self, value):
        """Append one fragment (string or SQLParam) to this query in place."""
        self.items.append(value)

    def __add__(self, other):
        if isinstance(other, SQLQuery):
            return SQLQuery(self.items + other.items)
        if isinstance(other, basestring):
            return SQLQuery(self.items + [other])
        return NotImplemented

    def __radd__(self, other):
        if not isinstance(other, basestring):
            return NotImplemented
        return SQLQuery([other] + self.items)

    def __iadd__(self, other):
        if isinstance(other, SQLQuery):
            self.items.extend(other.items)
        elif isinstance(other, (basestring, SQLParam)):
            self.items.append(other)
        else:
            return NotImplemented
        return self

    def __len__(self):
        # Length of the rendered query text, not the number of items.
        return len(self.query())

    def query(self, paramstyle=None):
        """
        Returns the query part of the sql query.
        >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
        >>> q.query()
        'SELECT * FROM test WHERE name=%s'
        >>> q.query(paramstyle='qmark')
        'SELECT * FROM test WHERE name=?'
        """
        escape_percent = paramstyle in ['format', 'pyformat']
        chunks = []
        for item in self.items:
            if isinstance(item, SQLParam):
                chunks.append(safestr(item.get_marker(paramstyle)))
                continue
            text = safestr(item)
            # Format-style drivers need literal '%' doubled.  For backward
            # compatibility, text that already contains '%%' is assumed to be
            # escaped and is left alone.
            if escape_percent and '%' in text and '%%' not in text:
                text = text.replace('%', '%%')
            chunks.append(text)
        return "".join(chunks)

    def values(self):
        """
        Returns the values of the parameters used in the sql query.
        >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
        >>> q.values()
        ['joe']
        """
        return [item.value for item in self.items if isinstance(item, SQLParam)]

    @staticmethod
    def join(items, sep=' ', prefix=None, suffix=None, target=None):
        """
        Joins multiple queries.

        >>> SQLQuery.join(['a', 'b'], ', ')
        <sql: 'a, b'>

        Optionally, prefix and suffix arguments can be provided.

        >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
        <sql: '(a, b)'>

        If target argument is provided, the items are appended to target instead of creating a new SQLQuery.
        """
        if target is None:
            target = SQLQuery()
        out = target.items

        if prefix:
            out.append(prefix)

        first = True
        for item in items:
            if not first:
                out.append(sep)
            first = False
            if isinstance(item, SQLQuery):
                out.extend(item.items)
            else:
                out.append(item)

        if suffix:
            out.append(suffix)
        return target

    def _str(self):
        # Best-effort human-readable rendering with values inlined; falls back
        # to the bare query text when '%'-substitution fails.
        try:
            template = self.query()
            return template % tuple([sqlify(v) for v in self.values()])
        except (ValueError, TypeError):
            return self.query()

    def __str__(self):
        return safestr(self._str())

    def __unicode__(self):
        return safeunicode(self._str())

    def __repr__(self):
        return '<sql: %s>' % repr(str(self))
|
om@3
|
256 |
|
om@3
|
257 |
class SQLLiteral:
    """
    Marks raw SQL text so that `sqlquote` inserts it unquoted.

    >>> sqlquote('NOW()')
    <sql: "'NOW()'">
    >>> sqlquote(SQLLiteral('NOW()'))
    <sql: 'NOW()'>
    """
    def __init__(self, v):
        # the raw SQL text
        self.v = v

    def __repr__(self):
        # repr is the raw text itself
        return self.v

sqlliteral = SQLLiteral
|
om@3
|
273 |
|
om@3
|
274 |
def _sqllist(values):
    """
    Build a parenthesised, comma-separated SQLQuery of parameters.

    >>> _sqllist([1, 2, 3])
    <sql: '(1, 2, 3)'>
    """
    body = []
    for value in values:
        if body:
            body.append(', ')
        body.append(sqlparam(value))
    return SQLQuery(['('] + body + [')'])
|
om@3
|
287 |
|
om@3
|
288 |
def reparam(string_, dictionary):
    """
    Interpolate `string_` using values from `dictionary` and return the
    result as an `SQLQuery`; each interpolated value is passed through
    `sqlquote`.

    >>> reparam("s = $s", dict(s=True))
    <sql: "s = 't'">
    >>> reparam("s IN $s", dict(s=[1, 2]))
    <sql: 's IN (1, 2)'>
    """
    # eval() may add entries (e.g. __builtins__) to the namespace it is
    # given, so work on a copy and leave the caller's dict untouched.
    namespace = dictionary.copy()
    pieces = []
    for is_expression, chunk in _interpolate(string_):
        if not is_expression:
            pieces.append(chunk)
            continue
        # SECURITY NOTE: `chunk` is evaluated as Python code.  Only pass
        # trusted query templates here, never user-controlled text.
        pieces.append(sqlquote(eval(chunk, namespace)))
    return SQLQuery.join(pieces, '')
|
om@3
|
308 |
|
om@3
|
309 |
def sqlify(obj):
    """
    Render `obj` as a SQL literal string.

    >>> sqlify(None)
    'NULL'
    >>> sqlify(True)
    "'t'"
    >>> sqlify(3)
    '3'
    """
    # Identity checks are required here: 1 == True and hash(1) == hash(True),
    # so equality- or dict-based dispatch would confuse ints and bools.
    if obj is None:
        return 'NULL'
    if obj is True:
        return "'t'"
    if obj is False:
        return "'f'"
    if datetime and isinstance(obj, datetime.datetime):
        return repr(obj.isoformat())
    if isinstance(obj, unicode):
        obj = obj.encode('utf8')
    return repr(obj)
|
om@3
|
334 |
|
om@3
|
335 |
def sqllist(lst):
    """
    Join names for use in something like a FROM or WHERE clause.

    >>> sqllist(['a', 'b'])
    'a, b'
    >>> sqllist('a')
    'a'
    >>> sqllist(u'abc')
    u'abc'
    """
    # A single string is returned untouched; anything else is comma-joined.
    return lst if isinstance(lst, basestring) else ', '.join(lst)
|
om@3
|
350 |
|
om@3
|
351 |
def sqlors(left, lst):
    """
    `left` is a SQL clause like `tablename.arg = `
    and `lst` is a list of values.  Returns an `SQLQuery` that ORs together
    the clause for each item in `lst`.

    An empty list yields the always-false `1=2`.  A one-element list, or a
    non-iterable scalar, yields a single comparison without parentheses.

    >>> sqlors('foo = ', [])
    <sql: '1=2'>
    >>> sqlors('foo = ', [1])
    <sql: 'foo = 1'>
    >>> sqlors('foo = ', 1)
    <sql: 'foo = 1'>
    >>> sqlors('foo = ', [1,2,3])
    <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
    """
    if isinstance(lst, iters):
        lst = list(lst)
        if not lst:
            return SQLQuery("1=2")
        if len(lst) == 1:
            lst = lst[0]

    if isinstance(lst, iters):
        # Build the items in one pass.  The previous sum(list_of_lists, [])
        # re-copied the accumulator for every value (quadratic time).
        items = ['(']
        for x in lst:
            items.extend([left, sqlparam(x), ' OR '])
        # the trailing always-false term absorbs the final ' OR '
        items.append('1=2)')
        return SQLQuery(items)
    return left + sqlparam(lst)
|
om@3
|
382 |
|
om@3
|
383 |
def sqlwhere(dictionary, grouping=' AND '):
    """
    Turn a `dictionary` of column -> value into a WHERE-clause `SQLQuery`,
    with each `key = value` term separated by `grouping`.

    >>> sqlwhere({'cust_id': 2, 'order_id':3})
    <sql: 'order_id = 3 AND cust_id = 2'>
    >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
    <sql: 'order_id = 3, cust_id = 2'>
    >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
    'a = %s AND b = %s'
    """
    conditions = []
    for column, value in dictionary.items():
        conditions.append(column + ' = ' + sqlparam(value))
    return SQLQuery.join(conditions, grouping)
|
om@3
|
395 |
|
om@3
|
396 |
def sqlquote(a):
    """
    Wrap `a` as a parameterised SQLQuery; a list becomes `(v1, v2, ...)`.

    >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
    <sql: "WHERE x = 't' AND y = 3">
    >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
    <sql: "WHERE x = 't' AND y IN (2, 3)">
    """
    if not isinstance(a, list):
        return sqlparam(a).sqlquery()
    return _sqllist(a)
|
om@3
|
409 |
|
om@3
|
410 |
class Transaction:
    """Database transaction, possibly nested inside another on the same ctx.

    Usable as a context manager: leaving the `with` block normally commits,
    leaving it through an exception rolls back.  Nested levels use SAVEPOINTs
    unless the context sets `ignore_nested_transactions`, in which case they
    do nothing at the database level.
    """
    def __init__(self, ctx):
        self.ctx = ctx
        depth = len(ctx.transactions)
        self.transaction_count = depth

        def run_savepoint(template):
            # Savepoint names carry the nesting depth so each level is distinct.
            cursor = ctx.db.cursor()
            ctx.db_execute(cursor, SQLQuery(template % depth))

        class _Engine:
            """Holds the begin/commit/rollback actions for one nesting level."""
            def __init__(self, transact, commit, rollback):
                self.do_transact = transact
                self.do_commit = commit
                self.do_rollback = rollback

        if not depth:
            # Top level: commit pending work without unloading, then use the
            # context's own commit/rollback.
            self.engine = _Engine(
                lambda: ctx.commit(unload=False),
                lambda: ctx.commit(),
                lambda: ctx.rollback())
        elif self.ctx.get('ignore_nested_transactions'):
            # nested transactions are not supported in some databases
            noop = lambda: None
            self.engine = _Engine(noop, noop, noop)
        else:
            self.engine = _Engine(
                lambda: run_savepoint('SAVEPOINT webpy_sp_%s'),
                lambda: run_savepoint('RELEASE SAVEPOINT webpy_sp_%s'),
                lambda: run_savepoint('ROLLBACK TO SAVEPOINT webpy_sp_%s'))

        self.engine.do_transact()
        self.ctx.transactions.append(self)

    def __enter__(self):
        return self

    def __exit__(self, exctype, excvalue, traceback):
        if exctype is None:
            self.commit()
        else:
            self.rollback()

    def _close(self, action):
        """Run `action` if this level is still open, then pop it and all deeper levels."""
        depth = self.transaction_count
        if depth < len(self.ctx.transactions):
            action()
        self.ctx.transactions = self.ctx.transactions[:depth]

    def commit(self):
        self._close(self.engine.do_commit)

    def rollback(self):
        self._close(self.engine.do_rollback)
|
om@3
|
477 |
|
om@3
|
478 |
class DB:
|
om@3
|
479 |
"""Database"""
|
om@3
|
480 |
def __init__(self, db_module, keywords):
    """Creates a database.

    `db_module` is the driver module whose `connect` opens connections, and
    `keywords` are the connection arguments.  `keywords` is modified in
    place: its 'driver' and 'pooling' entries are popped.
    """
    # Some DB implementations accept an optional `driver` keyword to pick a
    # specific driver module; it must not be forwarded to connect().
    keywords.pop('driver', None)

    self.db_module = db_module
    self.keywords = keywords

    self._ctx = threadeddict()
    # whether executed queries are echoed to `debug`
    self.printing = config.get('debug_sql', config.get('debug', False))
    self.supports_multiple_insert = False

    try:
        import DBUtils
    except ImportError:
        dbutils_available = False
    else:
        # pooling is possible only when DBUtils is importable
        dbutils_available = True

    # Pooling can be disabled by passing pooling=False in the keywords.
    self.has_pooling = self.keywords.pop('pooling', True) and dbutils_available
|
om@3
|
504 |
|
om@3
|
505 |
def _getctx(self):
    """Return `self._ctx`, first loading a connection into it when its
    `db` entry is missing or falsy."""
    if self._ctx.get('db'):
        return self._ctx
    self._load_context(self._ctx)
    return self._ctx
ctx = property(_getctx)
|
om@3
|
510 |
|
om@3
|
511 |
def _load_context(self, ctx):
    """Populate `ctx` with a fresh connection, a query counter, an empty
    transaction stack, and commit/rollback helpers."""
    ctx.dbq_count = 0
    ctx.transactions = []  # stack of transactions

    connect = self._connect_with_pooling if self.has_pooling else self._connect
    ctx.db = connect(self.keywords)
    ctx.db_execute = self._db_execute

    # Connections lacking commit/rollback get no-op replacements.
    for method_name in ('commit', 'rollback'):
        if not hasattr(ctx.db, method_name):
            setattr(ctx.db, method_name, lambda: None)

    def finish(action, unload):
        # Run the db commit/rollback; with pooling enabled, also drop the
        # connection from the context (release it) when `unload` is set.
        action()
        if unload and self.has_pooling:
            self._unload_context(self._ctx)

    ctx.commit = lambda unload=True: finish(ctx.db.commit, unload)
    ctx.rollback = lambda: finish(ctx.db.rollback, True)
|
om@3
|
541 |
|
om@3
|
542 |
def _unload_context(self, ctx):
    """Remove the `db` connection attribute from `ctx`."""
    delattr(ctx, 'db')
|
om@3
|
544 |
|
om@3
|
545 |
def _connect(self, keywords):
    """Open an unpooled connection by calling the driver module's `connect`."""
    connect = self.db_module.connect
    return connect(**keywords)
|
om@3
|
547 |
|
om@3
|
548 |
def _connect_with_pooling(self, keywords):
    """Return a connection from this DB's DBUtils pool.

    The pool (`self._pooleddb`) is created on first use from `self.db_module`
    and `keywords`, then reused for later calls.
    """
    import re

    def version_key(version):
        # Compare versions numerically.  The previous list-of-strings
        # comparison ordered '0.9.10' before '0.9.3'.  Non-numeric suffixes
        # such as 'rc1' are ignored, and a component with no digits counts as 0.
        key = []
        for part in str(version).split('.'):
            match = re.match(r'\d+', part)
            key.append(int(match.group()) if match else 0)
        return tuple(key)

    def get_pooled_db():
        from DBUtils import PooledDB

        # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
        # see Bug#122112
        if version_key(PooledDB.__version__) < (0, 9, 3):
            return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
        return PooledDB.PooledDB(creator=self.db_module, **keywords)

    if getattr(self, '_pooleddb', None) is None:
        self._pooleddb = get_pooled_db()

    return self._pooleddb.connection()
|
om@3
|
564 |
|
om@3
|
565 |
def _db_cursor(self):
    """Open a new cursor on the current context's connection."""
    connection = self.ctx.db
    return connection.cursor()
|
om@3
|
567 |
|
om@3
|
568 |
def _param_marker(self):
    """Returns the parameter marker for this database's `paramstyle`
    attribute (default 'pyformat').  Raises UnknownParamstyle otherwise."""
    style = getattr(self, 'paramstyle', 'pyformat')
    if style in ('format', 'pyformat'):
        return '%s'
    if style == 'qmark':
        return '?'
    if style == 'numeric':
        return ':1'
    raise UnknownParamstyle(style)
|
om@3
|
579 |
|
om@3
|
580 |
def _db_execute(self, cur, sql_query):
    """executes an sql query

    Renders `sql_query` via `_process_query`, runs it on cursor `cur`, and
    returns whatever `cur.execute` returns.  `ctx.dbq_count` is incremented
    first.  On any failure the innermost open transaction is rolled back
    (or, with none open, the whole context) and the exception is re-raised.
    When `self.printing` is set, the query is written to `debug`, together
    with its elapsed time on success.
    """
    self.ctx.dbq_count += 1

    try:
        a = time.time()
        query, params = self._process_query(sql_query)
        out = cur.execute(query, params)
        b = time.time()
    except:
        # Bare except is deliberate: roll back on *any* failure, then re-raise.
        # NOTE(review): if the rollback itself raises, that error propagates
        # in place of the original one.
        if self.printing:
            print >> debug, 'ERR:', str(sql_query)
        if self.ctx.transactions:
            self.ctx.transactions[-1].rollback()
        else:
            self.ctx.rollback()
        raise

    if self.printing:
        # elapsed seconds (rounded to 2 places), running query count, query
        print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
    return out
|
om@3
|
601 |
|
om@3
|
602 |
def _process_query(self, sql_query):
    """Takes the SQLQuery object and returns query string and parameters.

    The query text is rendered for this database's `paramstyle`
    (default 'pyformat').
    """
    style = getattr(self, 'paramstyle', 'pyformat')
    return sql_query.query(style), sql_query.values()
|
om@3
|
609 |
|
om@3
|
610 |
def _where(self, where, vars):
    """Normalise a WHERE clause.

    - an int/long becomes `id = <where>` (as a query parameter);
    - a 2-item list/tuple is the legacy `(query_text, values)` form, where
      `query_text` holds one `self._param_marker()` per value;
    - an SQLQuery is returned unchanged;
    - anything else is interpolated with `reparam(where, vars)`.

    Raises ValueError when a legacy pair's marker and value counts differ.
    """
    if isinstance(where, (int, long)):
        return "id = " + sqlparam(where)
    #@@@ for backward-compatibility
    if isinstance(where, (list, tuple)) and len(where) == 2:
        # SQLQuery() takes a single argument, so the old
        # `SQLQuery(where[0], where[1])` always raised TypeError.  Rebuild the
        # query by pairing each marker in the text with its value instead.
        text, values = where
        values = list(values)
        chunks = text.split(self._param_marker())
        if len(chunks) - 1 != len(values):
            raise ValueError("%r has %d parameter markers but %d values were given"
                             % (text, len(chunks) - 1, len(values)))
        items = [chunks[0]]
        for value, chunk in zip(values, chunks[1:]):
            items.append(sqlparam(value))
            items.append(chunk)
        return SQLQuery(items)
    if isinstance(where, SQLQuery):
        return where
    return reparam(where, vars)
|
om@3
|
621 |
|
om@3
|
622 |
def query(self, sql_query, vars=None, processed=False, _test=False):
    """
    Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
    If `processed=True`, or `sql_query` is already an SQLQuery, it is used
    as given (no interpolation) and `vars` is not used.

    Returns, when the cursor reports a result set (`description` is set), an
    `iterbetter` over `storage` rows keyed by column name, whose `len()` is
    the cursor's `rowcount` and whose `.list()` fetches all remaining rows.
    Otherwise returns the cursor's `rowcount`.  Commits immediately when no
    transaction is open.  With `_test=True`, returns the query unexecuted.

    >>> db = DB(None, {})
    >>> db.query("SELECT * FROM foo", _test=True)
    <sql: 'SELECT * FROM foo'>
    >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
    <sql: "SELECT * FROM foo WHERE x = 'f'">
    >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
    <sql: "SELECT * FROM foo WHERE x = 'f'">
    """
    if vars is None: vars = {}

    if not processed and not isinstance(sql_query, SQLQuery):
        sql_query = reparam(sql_query, vars)

    if _test: return sql_query

    db_cursor = self._db_cursor()
    self._db_execute(db_cursor, sql_query)

    if db_cursor.description:
        # column names, in cursor order, used as keys for each row
        names = [x[0] for x in db_cursor.description]
        def iterwrapper():
            # lazily fetch one row at a time until the cursor returns a falsy row
            row = db_cursor.fetchone()
            while row:
                yield storage(dict(zip(names, row)))
                row = db_cursor.fetchone()
        out = iterbetter(iterwrapper())
        # NOTE(review): instance-level __len__ presumably relies on iterbetter
        # being a classic (Python 2) class -- confirm.
        out.__len__ = lambda: int(db_cursor.rowcount)
        out.list = lambda: [storage(dict(zip(names, x))) \
                           for x in db_cursor.fetchall()]
    else:
        out = db_cursor.rowcount

    # outside a transaction every query is committed immediately
    if not self.ctx.transactions:
        self.ctx.commit()
    return out
|
om@3
|
663 |
|
om@3
|
664 |
def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
           limit=None, offset=None, _test=False):
    """
    Selects `what` from `tables` with clauses `where`, `order`,
    `group`, `limit`, and `offset`. Uses vars to interpolate.
    Otherwise, each clause can be a SQLQuery.  Clauses left as None are
    omitted.

    >>> db = DB(None, {})
    >>> db.select('foo', _test=True)
    <sql: 'SELECT * FROM foo'>
    >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
    <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
    """
    if vars is None:
        vars = {}
    clauses = []
    for keyword, value in self.sql_clauses(what, tables, where, group, order, limit, offset):
        if value is None:
            continue
        clauses.append(self.gen_clause(keyword, value, vars))
    qout = SQLQuery.join(clauses)
    if _test:
        return qout
    return self.query(qout, processed=True)
|
om@3
|
683 |
|
om@3
|
684 |
def where(self, table, what='*', order=None, group=None, limit=None,
          offset=None, _test=False, **kwargs):
    """
    Selects from `table` where keys are equal to values in `kwargs`,
    with the conditions joined by AND.

    >>> db = DB(None, {})
    >>> db.where('foo', bar_id=3, _test=True)
    <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
    >>> db.where('foo', source=2, crust='dewey', _test=True)
    <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
    >>> db.where('foo', _test=True)
    <sql: 'SELECT * FROM foo'>
    """
    conditions = [column + ' = ' + sqlquote(value) for column, value in kwargs.items()]
    where = SQLQuery.join(conditions, " AND ") if conditions else None
    return self.select(table, what=what, order=order,
                       group=group, limit=limit, offset=offset, _test=_test,
                       where=where)
|
om@3
|
709 |
|
om@3
|
710 |
def sql_clauses(self, what, tables, where, group, order, limit, offset):
    """Return (keyword, value) pairs for each SELECT clause in SQL order.

    `tables` is passed through `sqllist`; every other value is returned
    unchanged (None when the caller did not supply that clause).
    """
    keywords = ('SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'LIMIT', 'OFFSET')
    values = (what, sqllist(tables), where, group, order, limit, offset)
    return tuple(zip(keywords, values))
|
om@3
|
719 |
|
om@3
|
720 |
def gen_clause(self, sql, val, vars):
    """Render one SELECT clause as `<sql> <value>`.

    `val` may be:
    - an int/long: `id = <val>` for WHERE, otherwise the number itself;
    - a legacy `(query_text, values)` pair, where `query_text` holds one
      `self._param_marker()` per value;
    - an SQLQuery, used unchanged;
    - anything else, interpolated with `reparam(val, vars)`.

    Raises ValueError when a legacy pair's marker and value counts differ.
    """
    if isinstance(val, (int, long)):
        if sql == 'WHERE':
            nout = 'id = ' + sqlquote(val)
        else:
            nout = SQLQuery(val)
    #@@@
    elif isinstance(val, (list, tuple)) and len(val) == 2:
        # Backwards-compatibility.  SQLQuery() takes a single argument, so
        # the old `SQLQuery(val[0], val[1])` always raised TypeError.  Rebuild
        # the query by pairing each marker in the text with its value instead.
        text, values = val
        values = list(values)
        chunks = text.split(self._param_marker())
        if len(chunks) - 1 != len(values):
            raise ValueError("%r has %d parameter markers but %d values were given"
                             % (text, len(chunks) - 1, len(values)))
        items = [chunks[0]]
        for value, chunk in zip(values, chunks[1:]):
            items.append(sqlparam(value))
            items.append(chunk)
        nout = SQLQuery(items)
    elif isinstance(val, SQLQuery):
        nout = val
    else:
        nout = reparam(val, vars)

    def xjoin(a, b):
        # join with a space, or return whichever side is non-empty
        if a and b: return a + ' ' + b
        else: return a or b

    return xjoin(sql, nout)
|
om@3
|
739 |
|
om@3
|
740 |
def insert(self, tablename, seqname=None, _test=False, **values):
    """
    Inserts `values` into `tablename`. Returns current sequence ID.
    Set `seqname` to the ID if it's not the default, or to `False`
    if there isn't one.

    The returned value is the first column of the row fetched from the
    cursor after execution, or None if no row can be fetched.  Unless
    `seqname` is False, the query first goes through
    `self._process_insert_query` (not defined in this section); if that
    returns a tuple, both queries run on the same cursor.  Commits
    immediately when no transaction is open.  With no `values`, a
    default-values insert is issued.

    >>> db = DB(None, {})
    >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
    >>> q
    <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
    >>> q.query()
    'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
    >>> q.values()
    [2, 'bob']
    """
    # wrap a fragment in parentheses
    def q(x): return "(" + x + ")"

    if values:
        _keys = SQLQuery.join(values.keys(), ', ')
        _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
        sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
    else:
        sql_query = SQLQuery(self._get_insert_default_values_query(tablename))

    if _test: return sql_query

    db_cursor = self._db_cursor()
    if seqname is not False:
        sql_query = self._process_insert_query(sql_query, tablename, seqname)

    if isinstance(sql_query, tuple):
        # for some databases, a separate query has to be made to find
        # the id of the inserted row.
        q1, q2 = sql_query
        self._db_execute(db_cursor, q1)
        self._db_execute(db_cursor, q2)
    else:
        self._db_execute(db_cursor, sql_query)

    try:
        out = db_cursor.fetchone()[0]
    except Exception:
        # no row / no result set: there is no id to report
        out = None

    if not self.ctx.transactions:
        self.ctx.commit()
    return out
|
om@3
|
787 |
|
om@3
|
788 |
def _get_insert_default_values_query(self, table):
    """SQL that inserts one all-default row into `table`.

    Presumably a hook for backends with different syntax -- confirm."""
    template = "INSERT INTO %s DEFAULT VALUES"
    return template % table
|
om@3
|
790 |
|
om@3
|
791 |
def multiple_insert(self, tablename, values, seqname=None, _test=False):
    """
    Inserts multiple rows into `tablename`. The `values` must be a list of dictionaries,
    one for each row to be inserted, each with the same set of keys.
    Returns the list of ids of the inserted rows.
    Set `seqname` to the ID if it's not the default, or to `False`
    if there isn't one.

    Raises ValueError('Bad data') when the rows do not all have the same keys.
    With `_test=True` the built SQLQuery is returned instead of being executed.

    >>> db = DB(None, {})
    >>> db.supports_multiple_insert = True
    >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
    >>> db.multiple_insert('person', values=values, _test=True)
    <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
    """
    if not values:
        return []

    if not self.supports_multiple_insert:
        # Backend cannot do a multi-row INSERT: fall back to one insert() per row.
        out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
        if seqname is False:
            return None
        else:
            return out

    # Column order is fixed by the first row and reused for every row below.
    keys = list(values[0].keys())
    #@@ make sure all keys are valid

    # Make sure all rows have the same keys. Compare as sets: two dicts with
    # the same keys are not guaranteed to list them in the same order.
    key_set = set(keys)
    for v in values:
        if set(v.keys()) != key_set:
            raise ValueError('Bad data')

    sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))

    for i, row in enumerate(values):
        if i != 0:
            sql_query.append(", ")
        SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")

    if _test: return sql_query

    db_cursor = self._db_cursor()
    if seqname is not False:
        sql_query = self._process_insert_query(sql_query, tablename, seqname)

    if isinstance(sql_query, tuple):
        # for some databases, a separate query has to be made to find
        # the id of the inserted row.
        q1, q2 = sql_query
        self._db_execute(db_cursor, q1)
        self._db_execute(db_cursor, q2)
    else:
        self._db_execute(db_cursor, sql_query)

    try:
        out = db_cursor.fetchone()[0]
        # Both branches assume the ids were allocated consecutively.
        if getattr(self, 'dbname', None) == 'mysql':
            # MySQL's last_insert_id() reports the id of the FIRST row of a
            # multi-row INSERT, not the last one.
            out = range(out, out + len(values))
        else:
            out = range(out - len(values) + 1, out + 1)
    except Exception:
        # No id available (e.g. no sequence / nothing to fetch).
        out = None

    if not self.ctx.transactions:
        self.ctx.commit()
    return out
|
om@3
|
854 |
|
om@3
|
855 |
|
om@3
|
856 |
def update(self, tables, where, vars=None, _test=False, **values):
    """
    Update `tables` with clause `where` (interpolated using `vars`)
    and setting `values`.

    Returns the cursor's `rowcount` after executing. Commits immediately
    unless a transaction is open on `self.ctx`. With `_test=True` the
    query is returned without being executed.

    >>> db = DB(None, {})
    >>> name = 'Joseph'
    >>> q = db.update('foo', where='name = $name', name='bob', age=2,
    ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
    >>> q
    <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
    >>> q.query()
    'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
    >>> q.values()
    [2, 'bob', 'Joseph']
    """
    if vars is None:
        vars = {}
    where_clause = self._where(where, vars)

    # Built left to right, exactly as a single chained expression would be.
    statement = "UPDATE " + sqllist(tables)
    statement = statement + " SET " + sqlwhere(values, ', ')
    statement = statement + " WHERE " + where_clause

    if _test:
        return statement

    cursor = self._db_cursor()
    self._db_execute(cursor, statement)
    if not self.ctx.transactions:
        self.ctx.commit()
    return cursor.rowcount
|
om@3
|
887 |
|
om@3
|
888 |
def delete(self, table, where, using=None, vars=None, _test=False):
    """
    Deletes from `table` with clauses `where` and `using`.

    The USING and WHERE clauses are added only when their values are
    truthy. Returns the cursor's `rowcount`. Commits immediately unless a
    transaction is open on `self.ctx`. With `_test=True` the query is
    returned without being executed.

    >>> db = DB(None, {})
    >>> name = 'Joe'
    >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
    <sql: "DELETE FROM foo WHERE name = 'Joe'">
    """
    if vars is None:
        vars = {}
    where = self._where(where, vars)

    statement = 'DELETE FROM ' + table
    if using:
        statement += ' USING ' + sqllist(using)
    if where:
        statement += ' WHERE ' + where

    if _test:
        return statement

    cursor = self._db_cursor()
    self._db_execute(cursor, statement)
    if not self.ctx.transactions:
        self.ctx.commit()
    return cursor.rowcount
|
om@3
|
911 |
|
om@3
|
912 |
def _process_insert_query(self, query, tablename, seqname):
|
om@3
|
913 |
return query
|
om@3
|
914 |
|
om@3
|
915 |
def transaction(self):
    """Start a transaction.

    Returns a new Transaction bound to this database's context (`self.ctx`).
    """
    ctx = self.ctx
    return Transaction(ctx)
|
om@3
|
918 |
|
om@3
|
919 |
class PostgresDB(DB):
    """Postgres driver.

    Uses the first importable module among psycopg2, psycopg and pgdb, or
    the module named by the `driver` keyword.
    """
    def __init__(self, **keywords):
        # accept the short `pw` alias for the driver's `password` argument
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')

        db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None))
        if db_module.__name__ == "psycopg2":
            import psycopg2.extensions
            # presumably so text columns come back as unicode objects (psycopg2 UNICODE typecaster)
            psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

        # if db is not provided postgres driver will take it from PGDATABASE environment variable
        if 'db' in keywords:
            keywords['database'] = keywords.pop('db')

        self.dbname = "postgres"
        self.paramstyle = db_module.paramstyle
        DB.__init__(self, db_module, keywords)
        self.supports_multiple_insert = True
        # cache of sequence names, filled lazily by _get_all_sequences()
        self._sequences = None

    def _process_insert_query(self, query, tablename, seqname):
        """Append `SELECT currval(...)` so the insert's cursor yields the new id.

        When `seqname` is None, `<tablename>_id_seq` is guessed and used only
        if such a sequence exists; otherwise `query` is returned unchanged.
        """
        if seqname is None:
            # when seqname is not provided guess the seqname and make sure it exists
            seqname = tablename + "_id_seq"
            if seqname not in self._get_all_sequences():
                seqname = None

        if seqname:
            query += "; SELECT currval('%s')" % seqname

        return query

    def _get_all_sequences(self):
        """Query postgres to find names of all sequences used in this database."""
        if self._sequences is None:
            q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
            self._sequences = set([c.relname for c in self.query(q)])
        return self._sequences

    def _set_client_encoding(self, conn):
        """Switch a raw driver connection to UTF-8.

        psycopg-style connections expose set_client_encoding(); the pgdb
        driver does not, so an SQL SET statement is used as a fallback.
        """
        try:
            conn.set_client_encoding('UTF8')
        except AttributeError:
            # fallback for pgdb driver
            conn.cursor().execute("set client_encoding to 'UTF-8'")

    def _connect(self, keywords):
        conn = DB._connect(self, keywords)
        self._set_client_encoding(conn)
        return conn

    def _connect_with_pooling(self, keywords):
        conn = DB._connect_with_pooling(self, keywords)
        # The pooled object wraps the raw driver connection; `_con._con` is
        # assumed to reach it (DBUtils internals -- TODO confirm). The same
        # pgdb fallback as _connect() applies here.
        self._set_client_encoding(conn._con._con)
        return conn
|
om@3
|
972 |
|
om@3
|
973 |
class MySQLDB(DB):
    """MySQL database, using the MySQLdb driver."""
    def __init__(self, **keywords):
        import MySQLdb as db
        if 'pw' in keywords:
            # MySQLdb takes the password as `passwd`
            keywords['passwd'] = keywords.pop('pw')

        if 'charset' in keywords:
            if keywords['charset'] is None:
                # an explicit charset=None means: pass no charset to the driver
                del keywords['charset']
        else:
            keywords['charset'] = 'utf8'

        self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg
        self.dbname = "mysql"
        DB.__init__(self, db, keywords)
        self.supports_multiple_insert = True

    def _process_insert_query(self, query, tablename, seqname):
        # The new id is read by a second statement run on the same cursor.
        id_query = SQLQuery('SELECT last_insert_id();')
        return query, id_query

    def _get_insert_default_values_query(self, table):
        # MySQL form of an all-defaults insert: empty column and value lists.
        return "INSERT INTO %s () VALUES()" % table
|
om@3
|
995 |
|
om@3
|
996 |
def import_driver(drivers, preferred=None):
    """Import and return the first importable module named in `drivers`.

    If `preferred` is given (truthy), only that module is tried and
    `drivers` is ignored. Raises ImportError listing every tried name when
    none can be imported.
    """
    if preferred:
        candidates = [preferred]
    else:
        candidates = drivers

    for name in candidates:
        try:
            # A non-empty fromlist makes __import__ return the leaf module of a
            # dotted name (e.g. "pysqlite2.dbapi2") instead of the top package.
            return __import__(name, None, None, ['x'])
        except ImportError:
            continue
    raise ImportError("Unable to import " + " or ".join(candidates))
|
om@3
|
1008 |
|
om@3
|
1009 |
class SqliteDB(DB):
    """SQLite database (sqlite3, pysqlite2.dbapi2 or the old sqlite module)."""
    def __init__(self, **keywords):
        driver = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))

        if driver.__name__ in ("sqlite3", "pysqlite2.dbapi2"):
            driver.paramstyle = 'qmark'

            # The sqlite driver only creates datetime objects for timestamp columns
            # when the `detect_types` option is passed. Supported by the sqlite3 and
            # pysqlite2 drivers; not sure about the old sqlite module.
            keywords.setdefault('detect_types', driver.PARSE_DECLTYPES)

        self.paramstyle = driver.paramstyle
        keywords['database'] = keywords.pop('db')
        keywords['pooling'] = False # sqlite connections can't be shared between threads
        self.dbname = "sqlite"
        DB.__init__(self, driver, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        # The new rowid is read by a second statement on the same cursor.
        rowid_query = SQLQuery('SELECT last_insert_rowid();')
        return query, rowid_query

    def query(self, *a, **kw):
        result = DB.query(self, *a, **kw)
        if isinstance(result, iterbetter):
            # Drop the instance-level __len__ so len() is unavailable on SQLite
            # results (presumably because the cursor's rowcount is not
            # meaningful for SELECTs here -- TODO confirm).
            del result.__len__
        return result
|
om@3
|
1034 |
|
om@3
|
1035 |
class FirebirdDB(DB):
    """Firebird Database, using the kinterbasdb driver.
    """
    def __init__(self, **keywords):
        try:
            import kinterbasdb as db
        except Exception:
            # NOTE(review): a missing driver is swallowed and None is handed to
            # DB.__init__; presumably the error only surfaces later -- confirm.
            db = None
        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')
        keywords['database'] = keywords.pop('db')
        DB.__init__(self, db, keywords)

    def delete(self, table, where=None, using=None, vars=None, _test=False):
        # firebird doesn't support the USING clause, so any `using` is discarded
        return DB.delete(self, table, where, None, vars, _test)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        # Firebird expresses LIMIT/OFFSET as FIRST/SKIP directly after SELECT,
        # ahead of the column list.
        clauses = [
            ('SELECT', ''),
            ('FIRST', limit),
            ('SKIP', offset),
            ('', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
        ]
        return tuple(clauses)
|
om@3
|
1067 |
|
om@3
|
1068 |
class MSSQLDB(DB):
    """Microsoft SQL Server database, using the pymssql driver."""
    def __init__(self, **keywords):
        import pymssql as db
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')
        keywords['database'] = keywords.pop('db')
        self.dbname = "mssql"
        DB.__init__(self, db, keywords)

    def _process_query(self, sql_query):
        """Takes the SQLQuery object and returns query string and parameters.
        """
        # MSSQLDB expects params to be a tuple.
        # Overwriting the default implementation to convert params to tuple.
        paramstyle = getattr(self, 'paramstyle', 'pyformat')
        query = sql_query.query(paramstyle)
        params = sql_query.values()
        return query, tuple(params)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        # T-SQL requires TOP before the select list ("SELECT TOP n cols ..."),
        # so SELECT is emitted on its own and the column list follows TOP --
        # the same layout FirebirdDB uses for FIRST/SKIP.
        return (
            ('SELECT', ''),
            ('TOP', limit),
            ('', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            # NOTE(review): T-SQL spells this "OFFSET n ROWS" after ORDER BY and
            # does not allow it together with TOP; left unchanged here.
            ('OFFSET', offset))

    def _test(self):
        """Test LIMIT.

        Fake presence of pymssql module for running tests.
        >>> import sys
        >>> sys.modules['pymssql'] = sys.modules['sys']

        MSSQL has TOP clause instead of LIMIT clause.
        >>> db = MSSQLDB(db='test', user='joe', pw='secret')
        >>> db.select('foo', limit=4, _test=True)
        <sql: 'SELECT TOP 4 * FROM foo'>
        """
        pass
|
om@3
|
1110 |
|
om@3
|
1111 |
class OracleDB(DB):
    """Oracle database, using the cx_Oracle driver."""
    def __init__(self, **keywords):
        import cx_Oracle as db
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')

        #@@ TODO: use db.makedsn if host, port is specified
        keywords['dsn'] = keywords.pop('db')
        self.dbname = 'oracle'
        db.paramstyle = 'numeric'
        self.paramstyle = db.paramstyle

        # oracle doesn't support pooling
        keywords.pop('pooling', None)
        DB.__init__(self, db, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        """Pair the INSERT with a query reading `seqname`'s current value.

        Returns `query` unchanged when `seqname` is None; otherwise an
        (insert, select) tuple.
        """
        if seqname is None:
            # It is not possible to get seq name from table name in Oracle
            return query
        # Oracle executes a single statement per call and rejects
        # "stmt1; stmt2", so return the two separately: insert() and
        # multiple_insert() execute a tuple as two statements on the same
        # cursor and fetch the id from the second one.
        return query, SQLQuery("SELECT %s.currval FROM dual" % seqname)
|
om@3
|
1133 |
|
om@3
|
1134 |
_databases = {}
|
om@3
|
1135 |
def database(dburl=None, **params):
|
om@3
|
1136 |
"""Creates appropriate database using params.
|
om@3
|
1137 |
|
om@3
|
1138 |
Pooling will be enabled if DBUtils module is available.
|
om@3
|
1139 |
Pooling can be disabled by passing pooling=False in params.
|
om@3
|
1140 |
"""
|
om@3
|
1141 |
dbn = params.pop('dbn')
|
om@3
|
1142 |
if dbn in _databases:
|
om@3
|
1143 |
return _databases[dbn](**params)
|
om@3
|
1144 |
else:
|
om@3
|
1145 |
raise UnknownDB, dbn
|
om@3
|
1146 |
|
om@3
|
1147 |
def register_database(name, clazz):
    """
    Make `clazz` available to database() under the `dbn` name `name`.

    Registering an existing name replaces the previous class.

    >>> class LegacyDB(DB):
    ...     def __init__(self, **params):
    ...        pass
    ...
    >>> register_database('legacy', LegacyDB)
    >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret')
    """
    registry = _databases
    registry[name] = clazz
|
om@3
|
1159 |
|
om@3
|
1160 |
# Built-in backends, selectable through database(dbn=...) by these names.
register_database('mysql', MySQLDB)
register_database('postgres', PostgresDB)
register_database('sqlite', SqliteDB)
register_database('firebird', FirebirdDB)
register_database('mssql', MSSQLDB)
register_database('oracle', OracleDB)
|
om@3
|
1166 |
|
om@3
|
1167 |
def _interpolate(format):
    """
    Takes a format string and returns a list of 2-tuples of the form
    (flag, string) where flag (1 or 0) says whether string should be evaled
    or not.

    `$name`, `$name.attr`, `$name[...]`/`$name(...)` and `${expr}` produce
    eval chunks; `$$` is an escaped literal `$`. A `$` that ends the string,
    or is followed by any other character, is kept as literal text.
    Raises _ItplError for an unterminated expression.

    >>> _interpolate("price: $$5 and $")
    [(0, 'price: $'), (0, '5 and $')]

    from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
    """
    from tokenize import tokenprog

    def matchorfail(text, pos):
        # Match one Python token at `pos`; return the match and the end offset.
        match = tokenprog.match(text, pos)
        if match is None:
            raise _ItplError(text, pos)
        return match, match.end()

    namechars = "abcdefghijklmnopqrstuvwxyz" \
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
    chunks = []
    pos = 0

    while 1:
        dollar = format.find("$", pos)
        if dollar < 0:
            break
        # Slice rather than index: a trailing "$" yields "" instead of raising
        # IndexError.
        nextchar = format[dollar + 1:dollar + 2]

        if nextchar == "{":
            # ${expr}: consume tokens until the matching closing brace.
            chunks.append((0, format[pos:dollar]))
            pos, level = dollar + 2, 1
            while level:
                match, pos = matchorfail(format, pos)
                tstart, tend = match.regs[3]
                token = format[tstart:tend]
                if token == "{":
                    level = level + 1
                elif token == "}":
                    level = level - 1
            chunks.append((1, format[dollar + 2:pos - 1]))

        # `"" in namechars` is True, so the non-empty check is required.
        elif nextchar and nextchar in namechars:
            # $name, followed by any run of .attr, (...) or [...] suffixes.
            chunks.append((0, format[pos:dollar]))
            match, pos = matchorfail(format, dollar + 1)
            while pos < len(format):
                if format[pos] == "." and \
                    pos + 1 < len(format) and format[pos + 1] in namechars:
                    match, pos = matchorfail(format, pos + 1)
                elif format[pos] in "([":
                    pos, level = pos + 1, 1
                    while level:
                        match, pos = matchorfail(format, pos)
                        tstart, tend = match.regs[3]
                        token = format[tstart:tend]
                        if token[0] in "([":
                            level = level + 1
                        elif token[0] in ")]":
                            level = level - 1
                else:
                    break
            chunks.append((1, format[dollar + 1:pos]))
        else:
            # Literal "$" (including the "$" of an escaped "$$", whose second
            # "$" is skipped).
            chunks.append((0, format[pos:dollar + 1]))
            pos = dollar + 1 + (nextchar == "$")

    if pos < len(format):
        chunks.append((0, format[pos:]))
    return chunks
|
om@3
|
1234 |
|
om@3
|
1235 |
if __name__ == "__main__":
    # Run this module's doctests when the file is executed directly.
    import doctest
    doctest.testmod()
|