Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/pymysql/cursors.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import re
2from . import err
#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
#:
#: Group 1 is the statement prefix up to and including ``VALUES``/``VALUE``,
#: group 2 is the single parenthesised row of ``%s`` / ``%(name)s``
#: placeholders, and group 3 is an optional trailing ``ON DUPLICATE ...``
#: clause (an empty string when absent).
RE_INSERT_VALUES = re.compile(
    # group 1: "INSERT|REPLACE ... VALUES " prefix
    r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
    # group 2: "( placeholder [, placeholder]* )"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    # group 3: optional ON DUPLICATE clause, then an optional ";" at the end
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    flags=re.DOTALL | re.IGNORECASE,
)
class Cursor:
    """
    This is the object you use to interact with the database.

    Do not create an instance of a Cursor yourself. Call
    connections.Connection.cursor().

    See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
    the specification.
    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection):
        self.connection = connection
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        # PEP 249: lastrowid is None until an operation sets it. Without this,
        # reading it on a cursor that has not run a query raised AttributeError
        # (it was only ever assigned in _clear_result / _do_get_result).
        self.lastrowid = None
        self._executed = None
        # Same reasoning: _query assigns this, so give it a defined default.
        self._last_executed = None
        self._result = None
        self._rows = None

    def close(self):
        """
        Closing a cursor just exhausts all remaining data.

        Drains every pending result set via :meth:`nextset`, then detaches the
        cursor from its connection. Calling it again is a no-op.
        """
        conn = self.connection
        if conn is None:
            return
        try:
            while self.nextset():
                pass
        finally:
            # Detach even if draining raised, so the cursor reads as closed.
            self.connection = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        del exc_info
        self.close()

    def _get_db(self):
        """Return the owning connection.

        :raises ProgrammingError: if the cursor has been closed.
        """
        if not self.connection:
            raise err.ProgrammingError("Cursor closed")
        return self.connection

    def _check_executed(self):
        """Raise ProgrammingError unless a query has been executed."""
        if not self._executed:
            raise err.ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Hook for subclasses to convert a raw row; identity by default."""
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    def _nextset(self, unbuffered=False):
        """Get the next query set.

        :param unbuffered: Passed through to ``conn.next_result``.
        :return: ``True`` after loading the next result set; ``None`` if this
            cursor's result is not the connection's current result or there
            is no further result set.
        """
        conn = self._get_db()
        current_result = self._result
        # Only advance when our result is still the connection's current one.
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        self._result = None
        self._clear_result()
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        """Advance to the next result set (buffered); see :meth:`_nextset`."""
        return self._nextset(False)

    def _ensure_bytes(self, x, encoding=None):
        """Encode ``str`` values (recursively inside tuples/lists) to bytes.

        Non-``str`` scalars are returned unchanged; tuples/lists are rebuilt
        with the same container type.
        """
        if isinstance(x, str):
            x = x.encode(encoding)
        elif isinstance(x, (tuple, list)):
            x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
        return x

    def _escape_args(self, args, conn):
        """Escape query parameters with the connection's escaping rules.

        :return: a tuple of literals for a tuple/list, a dict of literals for a
            dict, or ``conn.escape(args)`` for anything else.
        """
        if isinstance(args, (tuple, list)):
            return tuple(conn.literal(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.literal(val) for (key, val) in args.items()}
        else:
            # Not a sequence or mapping: escape it as a single value and let
            # conn.escape raise if it cannot handle the type.
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string that is sent to the database by calling the
        execute() method.

        This method follows the extension to the DB API 2.0 followed by Psycopg.

        :raises ProgrammingError: if the cursor has been closed.
        """
        conn = self._get_db()

        if args is not None:
            query = query % self._escape_args(args, conn)

        return query

    def execute(self, query, args=None):
        """Execute a query

        Any result sets still pending from a previous call are drained first.

        :param str query: Query to execute.

        :param args: parameters used with query. (optional)
        :type args: tuple, list or dict

        :return: Number of affected rows
        :rtype: int

        If args is a list or tuple, %s can be used as a placeholder in the query.
        If args is a dict, %(name)s can be used as a placeholder in the query.
        """
        while self.nextset():
            pass

        query = self.mogrify(query, args)

        result = self._query(query)
        self._executed = query
        return result

    def executemany(self, query, args):
        # type: (str, list) -> int
        """Run several data against one query

        :param query: query to execute on server
        :param args: Sequence of sequences or mappings. It is used as parameter.
        :return: Number of rows affected, if any (``None`` when ``args`` is
            empty).

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().
        """
        if not args:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            # "% ()" turns any "%%" in the prefix into "%" (and raises
            # TypeError if the prefix contains a stray placeholder).
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ""
            assert q_values[0] == "(" and q_values[-1] == ")"
            return self._do_execute_many(
                q_prefix,
                q_values,
                q_postfix,
                args,
                self.max_stmt_length,
                self._get_db().encoding,
            )

        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    def _do_execute_many(
        self, prefix, values, postfix, args, max_stmt_length, encoding
    ):
        """Build and run ``prefix (row),(row),... postfix`` statements.

        Rendered rows are appended to one bytes statement until the next row
        would push it past ``max_stmt_length``; the statement is then executed
        and a new one is started from ``prefix``.

        :return: Sum of the row counts returned by each :meth:`execute` call;
            also stored in ``rowcount``.
        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = iter(args)
        # The first row is always added, even if it alone exceeds the limit.
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, "surrogateescape")
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, "surrogateescape")
            # "+ 1" accounts for the "," separator.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b","
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        procname -- string, name of procedure to execute on server

        args -- Sequence of parameters to use with procedure

        Returns the original args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.
        """
        conn = self._get_db()
        if args:
            # Store each argument in a server variable @_<procname>_<index>.
            fmt = f"@_{procname}_%d=%s"
            self._query(
                "SET %s"
                % ",".join(
                    fmt % (index, conn.escape(arg)) for index, arg in enumerate(args)
                )
            )
            self.nextset()

        q = "CALL %s(%s)" % (
            procname,
            ",".join(["@_%s_%d" % (procname, i) for i in range(len(args))]),
        )
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row, or ``None`` when no rows remain."""
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows.

        :param size: Number of rows to fetch; ``None`` means :attr:`arraysize`.
            ``0`` or a negative value fetches nothing, matching
            ``SSCursor.fetchmany``.
        :return: ``()`` when no rows are buffered, otherwise a slice of the
            buffered rows starting at the current position.
        """
        self._check_executed()
        if self._rows is None:
            return ()
        if size is None:
            size = self.arraysize
        # Clamp: ``size or arraysize`` used to turn an explicit 0 into
        # arraysize, and a negative size sliced up to *before* the current
        # row (fetchmany(-1) returned all-but-last) and moved rownumber back.
        end = self.rownumber + max(size, 0)
        result = self._rows[self.rownumber : end]
        self.rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all remaining rows (``()`` when no rows are buffered)."""
        self._check_executed()
        if self._rows is None:
            return ()
        if self.rownumber:
            result = self._rows[self.rownumber :]
        else:
            # Nothing consumed yet: hand back the buffer itself without slicing.
            result = self._rows
        self.rownumber = len(self._rows)
        return result

    def scroll(self, value, mode="relative"):
        """Move the row position within the buffered rows.

        :param value: Offset for ``mode="relative"``, target index for
            ``mode="absolute"``.
        :raises ProgrammingError: for an unknown ``mode``.
        :raises IndexError: if the target lies outside the buffered rows.
        """
        self._check_executed()
        if mode == "relative":
            r = self.rownumber + value
        elif mode == "absolute":
            r = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        # _rows may be None (fetchone/fetchall handle that case); treat it as
        # zero rows instead of failing with TypeError on len(None).
        nrows = 0 if self._rows is None else len(self._rows)
        if not (0 <= r < nrows):
            raise IndexError("out of range")
        self.rownumber = r

    def _query(self, q):
        """Send ``q`` on the connection (buffered) and load its result."""
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        conn.query(q)
        self._do_get_result()
        return self.rowcount

    def _clear_result(self):
        """Reset per-result state before a new query or result set."""
        self.rownumber = 0
        self._result = None

        self.rowcount = 0
        self.description = None
        self.lastrowid = None
        self._rows = None

    def _do_get_result(self):
        """Copy row count, description, insert id and rows from conn._result."""
        conn = self._get_db()

        self._result = result = conn._result

        self.rowcount = result.affected_rows
        self.description = result.description
        self.lastrowid = result.insert_id
        self._rows = result.rows

    def __iter__(self):
        # Call fetchone() repeatedly until it returns None.
        return iter(self.fetchone, None)

    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError
class DictCursorMixin:
    """Mixin that re-keys every fetched row by column name.

    Place it before a cursor class in the bases (e.g.
    ``class X(DictCursorMixin, Cursor)``) so its ``_do_get_result`` wraps the
    cursor's own implementation via ``super()``.
    """

    #: Mapping type built for each row. You can override this to use
    #: OrderedDict or other dict-like types.
    dict_type = dict

    def _do_get_result(self):
        """Load the result via the next class in the MRO, then convert rows.

        A column name already seen earlier in the same result is keyed as
        ``table_name + "." + name`` instead. ``_fields`` is only (re)assigned
        when the result has a description.
        """
        super()._do_get_result()
        column_keys = []
        if self.description:
            for field in self._result.fields:
                key = field.name
                if key in column_keys:
                    key = field.table_name + "." + key
                column_keys.append(key)
            self._fields = column_keys

        if column_keys and self._rows:
            self._rows = list(map(self._conv_row, self._rows))

    def _conv_row(self, row):
        """Turn one row tuple into ``dict_type`` keyed by ``_fields``."""
        return None if row is None else self.dict_type(zip(self._fields, row))
class DictCursor(DictCursorMixin, Cursor):
    """Buffered cursor whose rows come back as mappings keyed by column name."""
class SSCursor(Cursor):
    """
    Unbuffered cursor that streams rows from the server on demand.

    Rather than copying the whole result into a client-side buffer, rows are
    read one at a time as they are requested. This keeps client memory low and
    gets the first rows back sooner on slow networks or for very large result
    sets.

    Limitations: the MySQL protocol does not report the total number of rows,
    so counting them means reading every row; and since only the current row
    is held in memory, scrolling backwards is not supported.
    """

    def _conv_row(self, row):
        # Rows are passed through unchanged.
        return row

    def close(self):
        """Finish any in-progress unbuffered read, then drain result sets."""
        conn = self.connection
        if conn is None:
            return

        pending = self._result
        if pending is not None and pending is conn._result:
            pending._finish_unbuffered_query()

        try:
            while self.nextset():
                pass
        finally:
            self.connection = None

    __del__ = close

    def _query(self, q):
        """Send ``q`` in unbuffered mode and load the result header."""
        db = self._get_db()
        self._last_executed = q
        self._clear_result()
        db.query(q, unbuffered=True)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        """Advance to the next result set, reading it unbuffered."""
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read the next row from the server and pass it through _conv_row."""
        packet_row = self._result._read_rowdata_packet_unbuffered()
        return self._conv_row(packet_row)

    def fetchone(self):
        """Fetch the next row, or ``None`` once the rows are exhausted."""
        self._check_executed()
        row = self.read_next()
        if row is not None:
            self.rownumber += 1
        return row

    def fetchall(self):
        """
        Fetch every remaining row into a list, as MySQLdb does. This buffers
        the whole remainder, so it defeats the purpose for large queries; use
        fetchall_unbuffered() for a lazy iterator instead.
        """
        return list(self.fetchall_unbuffered())

    def fetchall_unbuffered(self):
        """
        Return an iterator over the remaining rows. Not part of the standard,
        but materialising a list would use far too much memory for large
        result sets.
        """
        return iter(self.fetchone, None)

    def __iter__(self):
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """Fetch up to ``size`` rows (default :attr:`arraysize`) as a list."""
        self._check_executed()
        limit = self.arraysize if size is None else size

        batch = []
        for _ in range(limit):
            row = self.read_next()
            if row is None:
                break
            batch.append(row)
            self.rownumber += 1
        return batch

    def scroll(self, value, mode="relative"):
        """Skip forward by reading and discarding rows.

        :raises NotSupportedError: for any backwards movement.
        :raises ProgrammingError: for an unknown ``mode``.
        """
        self._check_executed()

        backwards_msg = "Backwards scrolling not supported by this cursor"
        if mode == "relative":
            if value < 0:
                raise err.NotSupportedError(backwards_msg)
            skip = value
            target = self.rownumber + value
        elif mode == "absolute":
            if value < self.rownumber:
                raise err.NotSupportedError(backwards_msg)
            skip = value - self.rownumber
            target = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        # rownumber is only updated after every skipped row has been read.
        for _ in range(skip):
            self.read_next()
        self.rownumber = target
class SSDictCursor(DictCursorMixin, SSCursor):
    """Streaming (unbuffered) cursor whose rows come back as column-keyed mappings."""