Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/pymysql/connections.py : 13%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Python implementation of the MySQL client-server protocol
2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
3# Error codes:
4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
5import errno
6import os
7import socket
8import struct
9import sys
10import traceback
11import warnings
13from . import _auth
15from .charset import charset_by_name, charset_by_id
16from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
17from . import converters
18from .cursors import Cursor
19from .optionfile import Parser
20from .protocol import (
21 dump_packet,
22 MysqlPacket,
23 FieldDescriptorPacket,
24 OKPacketWrapper,
25 EOFPacketWrapper,
26 LoadLocalPacketWrapper,
27)
28from . import err, VERSION_STRING
try:
    import ssl

    SSL_ENABLED = True
except ImportError:
    # Python built without OpenSSL: TLS connections are unavailable.
    ssl = None
    SSL_ENABLED = False

try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # getpass.getuser() cannot always determine the current user. Before
    # Python 3.13 this surfaced as KeyError (no entry in the OS user database)
    # or ImportError (no pwd module); since 3.13 it raises OSError. In every
    # case fall back to "no default user" instead of failing at import time.
    DEFAULT_USER = None
# Set to True to print/dump protocol traffic while debugging.
DEBUG = False

# Field types grouped as TEXT_TYPES: string, blob, bit and geometry columns.
TEXT_TYPES = {
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.GEOMETRY,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.VAR_STRING,
}

# Charset used when a Connection is created without an explicit charset.
DEFAULT_CHARSET = "utf8mb4"

# Largest payload one protocol packet can carry (its length field is 3 bytes).
MAX_PACKET_LEN = 0xFFFFFF  # == 2 ** 24 - 1
67def _pack_int24(n):
68 return struct.pack("<I", n)[:3]
71# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
72def _lenenc_int(i):
73 if i < 0:
74 raise ValueError(
75 "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i
76 )
77 elif i < 0xFB:
78 return bytes([i])
79 elif i < (1 << 16):
80 return b"\xfc" + struct.pack("<H", i)
81 elif i < (1 << 24):
82 return b"\xfd" + struct.pack("<I", i)[:3]
83 elif i < (1 << 64):
84 return b"\xfe" + struct.pack("<Q", i)
85 else:
86 raise ValueError(
87 "Encoding %x is larger than %x - no representation in LengthEncodedInteger"
88 % (i, (1 << 64))
89 )
92class Connection:
93 """
94 Representation of a socket with a mysql server.
96 The proper way to get an instance of this class is to call
97 connect().
99 Establish a connection to the MySQL database. Accepts several
100 arguments:
102 :param host: Host where the database server is located
103 :param user: Username to log in as
104 :param password: Password to use.
105 :param database: Database to use, None to not use a particular one.
106 :param port: MySQL port to use, default is usually OK. (default: 3306)
107 :param bind_address: When the client has multiple network interfaces, specify
108 the interface from which to connect to the host. Argument can be
109 a hostname or an IP address.
110 :param unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
111 :param read_timeout: The timeout for reading from the connection in seconds (default: None - no timeout)
112 :param write_timeout: The timeout for writing to the connection in seconds (default: None - no timeout)
113 :param charset: Charset you want to use.
114 :param sql_mode: Default SQL_MODE to use.
115 :param read_default_file:
116 Specifies my.cnf file to read these parameters from under the [client] section.
117 :param conv:
118 Conversion dictionary to use instead of the default one.
119 This is used to provide custom marshalling and unmarshalling of types.
120 See converters.
121 :param use_unicode:
122 Whether or not to default to unicode strings.
123 This option defaults to true.
124 :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
125 :param cursorclass: Custom cursor class to use.
126 :param init_command: Initial SQL statement to run when connection is established.
127 :param connect_timeout: Timeout before throwing an exception when connecting.
128 (default: 10, min: 1, max: 31536000)
129 :param ssl:
130 A dict of arguments similar to mysql_ssl_set()'s parameters.
131 :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate
132 :param ssl_cert: Path to the file that contains a PEM-formatted client certificate
133 :param ssl_disabled: A boolean value that disables usage of TLS
134 :param ssl_key: Path to the file that contains a PEM-formatted private key for the client certificate
135 :param ssl_verify_cert: Set to true to check the validity of server certificates
136 :param ssl_verify_identity: Set to true to check the server's identity
137 :param read_default_group: Group to read from in the configuration file.
138 :param autocommit: Autocommit mode. None means use server default. (default: False)
139 :param local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
140 :param max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
141 Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
142 :param defer_connect: Don't explicitly connect on construction - wait for connect call.
143 (default: False)
144 :param auth_plugin_map: A dict of plugin names to a class that processes that plugin.
145 The class will take the Connection object as the argument to the constructor.
146 The class needs an authenticate method taking an authentication packet as
147 an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
148 (if no authenticate method) for returning a string from the user. (experimental)
149 :param server_public_key: SHA256 authentication plugin public key value. (default: None)
150 :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
151 :param compress: Not supported
152 :param named_pipe: Not supported
153 :param db: **DEPRECATED** Alias for database.
154 :param passwd: **DEPRECATED** Alias for password.
156 See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the
157 specification.
158 """
160 _sock = None
161 _auth_plugin_name = ""
162 _closed = False
163 _secure = False
def __init__(
    self,
    *,
    user=None,  # The first four arguments are based on the DB-API 2.0 recommendation.
    password="",
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset="",
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=Cursor,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
):
    """Store connection settings and, unless *defer_connect* is true, connect.

    Parameters are documented in the class docstring. Values read from an
    option file (*read_default_file* / *read_default_group*) only fill in
    options that were not given (falsy) as arguments.

    :raise NotImplementedError: If *compress* or *named_pipe* is requested, or
        SSL is requested but the ``ssl`` module is unavailable.
    :raise ValueError: On an invalid port, connect/read/write timeout.
    """
    if db is not None and database is None:
        # We will raise a warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
        database = db
    if passwd is not None and not password:
        # We will raise a warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn(
        #     "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
        # )
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            "compress and named_pipe arguments are not supported"
        )

    self._local_infile = bool(local_infile)
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES

    # A group without a file means "read the platform's default option file".
    if read_default_group and not read_default_file:
        if sys.platform.startswith("win"):
            read_default_file = "c:\\my.ini"
        else:
            read_default_file = "/etc/my.cnf"

    if read_default_file:
        if not read_default_group:
            read_default_group = "client"

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # A truthy explicit argument always wins over the option file.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config("user", user)
        password = _config("password", password)
        host = _config("host", host)
        database = _config("database", database)
        unix_socket = _config("socket", unix_socket)
        port = int(_config("port", port))
        bind_address = _config("bind-address", bind_address)
        charset = _config("default-character-set", charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            # Merge into a copy: the caller's dict must not be modified (it
            # may be shared between several connections).
            ssl = dict(ssl)
            for key in ["ca", "capath", "cert", "key", "cipher"]:
                value = _config("ssl-" + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # Individual ssl_* arguments replace any `ssl` dict/context.
        if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
            ssl = {
                "ca": ssl_ca,
                "check_hostname": bool(ssl_verify_identity),
                "verify_mode": ssl_verify_cert
                if ssl_verify_cert is not None
                else False,
            }
            if ssl_cert is not None:
                ssl["cert"] = ssl_cert
            if ssl_key is not None:
                ssl["key"] = ssl_key
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError("ssl module not found")
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or "localhost"
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError("port should be of type int")
    self.user = user or DEFAULT_USER
    self.password = password or b""
    if isinstance(self.password, str):
        self.password = self.password.encode("latin1")
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address
    if not (0 < connect_timeout <= 31536000):
        raise ValueError("connect_timeout should be >0 and <=31536000")
    self.connect_timeout = connect_timeout or None
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError("read_timeout should be > 0")
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError("write_timeout should be > 0")
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.use_unicode = use_unicode

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    if self.db:
        client_flag |= CLIENT.CONNECT_WITH_DB

    self.client_flag = client_flag

    self.cursorclass = cursorclass

    self._result = None
    self._affected_rows = 0
    self.host_info = "Not connected"

    # specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions

    # Needed for MySQLdb compatibility: int keys are decoders, others encoders.
    self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
    self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key

    # Sent to the server as connection attributes during the handshake.
    self._connect_attrs = {
        "_client_name": "pymysql",
        "_pid": str(os.getpid()),
        "_client_version": VERSION_STRING,
    }

    if program_name:
        self._connect_attrs["program_name"] = program_name

    if defer_connect:
        self._sock = None
    else:
        self.connect()
355 def __enter__(self):
356 return self
358 def __exit__(self, *exc_info):
359 del exc_info
360 self.close()
362 def _create_ssl_ctx(self, sslp):
363 if isinstance(sslp, ssl.SSLContext):
364 return sslp
365 ca = sslp.get("ca")
366 capath = sslp.get("capath")
367 hasnoca = ca is None and capath is None
368 ctx = ssl.create_default_context(cafile=ca, capath=capath)
369 ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
370 verify_mode_value = sslp.get("verify_mode")
371 if verify_mode_value is None:
372 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
373 elif isinstance(verify_mode_value, bool):
374 ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
375 else:
376 if isinstance(verify_mode_value, str):
377 verify_mode_value = verify_mode_value.lower()
378 if verify_mode_value in ("none", "0", "false", "no"):
379 ctx.verify_mode = ssl.CERT_NONE
380 elif verify_mode_value == "optional":
381 ctx.verify_mode = ssl.CERT_OPTIONAL
382 elif verify_mode_value in ("required", "1", "true", "yes"):
383 ctx.verify_mode = ssl.CERT_REQUIRED
384 else:
385 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
386 if "cert" in sslp:
387 ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
388 if "cipher" in sslp:
389 ctx.set_ciphers(sslp["cipher"])
390 ctx.options |= ssl.OP_NO_SSLv2
391 ctx.options |= ssl.OP_NO_SSLv3
392 return ctx
def close(self):
    """
    Send COM_QUIT to the server and close the socket.

    Sending COM_QUIT is best effort: the socket is closed even if the write
    fails. If there is no socket, only the closed flag is set.

    See `Connection.close() <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    :raise Error: If the connection is already closed.
    """
    if not self._closed:
        self._closed = True
        if self._sock is not None:
            quit_packet = struct.pack("<iB", 1, COMMAND.COM_QUIT)
            try:
                self._write_bytes(quit_packet)
            except Exception:
                pass  # the server may already be gone; close regardless
            finally:
                self._force_close()
        return
    raise err.Error("Already closed")
@property
def open(self):
    """Return True while a socket is attached to this connection."""
    return not (self._sock is None)
421 def _force_close(self):
422 """Close connection without QUIT message"""
423 if self._sock:
424 try:
425 self._sock.close()
426 except: # noqa
427 pass
428 self._sock = None
429 self._rfile = None
431 __del__ = _force_close
def autocommit(self, value):
    """Enable or disable autocommit for this session.

    The requested mode is stored as a bool in ``autocommit_mode``. A
    ``SET AUTOCOMMIT`` command is sent only when it differs from the mode in
    the last server status received (see get_autocommit()).

    :param value: Desired mode; interpreted by truthiness.
    """
    self.autocommit_mode = bool(value)
    current = self.get_autocommit()
    # Compare the normalised bool, not the raw argument: e.g. "on" != True
    # would otherwise trigger a redundant round trip to the server.
    if self.autocommit_mode != current:
        self._send_autocommit_mode()
def get_autocommit(self):
    """Return True if the last server status received has autocommit set."""
    autocommit_bit = self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return autocommit_bit != 0
def _read_ok_packet(self):
    """Read the next packet, which must be an OK packet, and record its status.

    :return: The packet wrapped in OKPacketWrapper.
    :raise OperationalError: (2014) if the packet is not an OK packet.
    """
    pkt = self._read_packet()
    if pkt.is_ok_packet():
        ok = OKPacketWrapper(pkt)
        self.server_status = ok.server_status
        return ok
    raise err.OperationalError(2014, "Command Out of Sync")
def _send_autocommit_mode(self):
    """Send ``SET AUTOCOMMIT`` with the current ``autocommit_mode`` and await OK."""
    statement = "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def begin(self):
    """Start a transaction by sending ``BEGIN``; the server must answer OK."""
    statement = "BEGIN"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def commit(self):
    """
    Commit the current transaction by sending ``COMMIT``.

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.
    """
    statement = "COMMIT"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def rollback(self):
    """
    Roll back the current transaction by sending ``ROLLBACK``.

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.
    """
    statement = "ROLLBACK"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def show_warnings(self):
    """Run ``SHOW WARNINGS`` and return the rows of its (buffered) result."""
    self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
    warnings_result = MySQLResult(self)
    warnings_result.read()
    return warnings_result.rows
def select_db(self, db):
    """
    Change the session's default database with COM_INIT_DB.

    :param db: The name of the db.
    """
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()
def escape(self, obj, mapping=None):
    """Escape whatever value you pass to it.

    str values become quoted, escaped literals; bytes/bytearray are quoted
    via _quote_bytes() (prefixed with ``_binary`` when binary_prefix is on);
    everything else is delegated to converters.escape_item().

    Non-standard, for internal use; do not use this in your applications.
    """
    if isinstance(obj, (bytes, bytearray)):
        quoted = self._quote_bytes(obj)
        return "_binary" + quoted if self._binary_prefix else quoted
    if isinstance(obj, str):
        escaped = self.escape_string(obj)
        return "'" + escaped + "'"
    return converters.escape_item(obj, self.charset, mapping=mapping)
def literal(self, obj):
    """Alias for escape() that always uses this connection's encoders.

    Non-standard, for internal use; do not use this in your applications.
    """
    encoders = self.encoders
    return self.escape(obj, encoders)
def escape_string(self, s):
    """Escape the str *s* for use inside a single-quoted SQL literal.

    When the server reports NO_BACKSLASH_ESCAPES, only quotes are doubled;
    otherwise converters.escape_string() is used.
    """
    if not self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
        return converters.escape_string(s)
    return s.replace("'", "''")
def _quote_bytes(self, s):
    """Return the bytes *s* as a quoted SQL string literal (a str)."""
    if not self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
        return converters.escape_bytes(s)
    # NO_BACKSLASH_ESCAPES: double the quotes; surrogateescape lets arbitrary
    # bytes round-trip through the str literal.
    doubled = s.replace(b"'", b"''")
    return "'%s'" % (doubled.decode("ascii", "surrogateescape"),)
def cursor(self, cursor=None):
    """
    Create a new cursor to execute queries with.

    :param cursor: The type of cursor to create; one of :py:class:`Cursor`,
        :py:class:`SSCursor`, :py:class:`DictCursor`, or :py:class:`SSDictCursor`.
        None means use the connection's ``cursorclass``.
    """
    cursor_factory = cursor if cursor else self.cursorclass
    return cursor_factory(self)
541 # The following methods are INTERNAL USE ONLY (called from Cursor)
def query(self, sql, unbuffered=False):
    """Send *sql* as COM_QUERY and read its result (used by Cursor).

    str queries are encoded with the connection encoding using
    ``surrogateescape``. Returns (and stores) the affected-row count.
    """
    if isinstance(sql, str):
        payload = sql.encode(self.encoding, "surrogateescape")
    else:
        payload = sql
    self._execute_command(COMMAND.COM_QUERY, payload)
    self._affected_rows = self._read_query_result(unbuffered=unbuffered)
    return self._affected_rows
def next_result(self, unbuffered=False):
    """Read the next result set of a multi-result response.

    Returns (and stores) its affected-row count.
    """
    affected = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = affected
    return affected
def affected_rows(self):
    """Return the affected-row count stored by the last query()/next_result()."""
    count = self._affected_rows
    return count
def kill(self, thread_id):
    """Send COM_PROCESS_KILL for *thread_id* and return the server's OK packet."""
    payload = struct.pack("<I", thread_id)
    self._execute_command(COMMAND.COM_PROCESS_KILL, payload)
    ok_packet = self._read_ok_packet()
    return ok_packet
def ping(self, reconnect=True):
    """
    Check if the server is alive.

    At most one reconnect is attempted: either up front when there is no
    socket, or after a failed COM_PING.

    :param reconnect: If the connection is closed, reconnect.
    :raise Error: If the connection is closed and reconnect=False.
    """
    if self._sock is None:
        if not reconnect:
            raise err.Error("Already closed")
        self.connect()
        reconnect = False  # already reconnected once; do not retry below
    try:
        self._execute_command(COMMAND.COM_PING, "")
        self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        self.ping(False)
def set_charset(self, charset):
    """Switch the session charset with ``SET NAMES`` and update the encoding.

    The charset is looked up before anything is sent, so an unsupported
    name fails without touching the session.
    """
    new_encoding = charset_by_name(charset).encoding
    statement = "SET NAMES %s" % self.escape(charset)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_packet()
    self.charset, self.encoding = charset, new_encoding
def connect(self, sock=None):
    """Open the connection and perform the MySQL handshake.

    Opens a UNIX-socket or TCP connection (unless *sock* is given), reads the
    server greeting, authenticates, then applies ``sql_mode``,
    ``init_command`` and the autocommit mode when they are configured.

    If any step fails, every socket installed by this attempt is closed and
    ``self._sock`` is reset, so ``open`` reports False afterwards.

    :param sock: An already-connected socket to use instead of opening one.
    :raise OperationalError: (2003) if an ``OSError`` occurs at any step.
    """
    self._closed = False
    # Remember the socket from before this attempt, so that on failure only a
    # socket installed by this attempt is discarded.
    previous_sock = self._sock
    try:
        if sock is None:
            if self.unix_socket:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = "Localhost via UNIX socket"
                self._secure = True
                if DEBUG:
                    print("connected using unix_socket")
            else:
                kwargs = {}
                if self.bind_address is not None:
                    kwargs["source_address"] = (self.bind_address, 0)
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout, **kwargs
                        )
                        break
                    except (OSError, IOError) as e:
                        if e.errno == errno.EINTR:
                            continue  # interrupted by a signal: retry
                        raise
                self.host_info = "socket %s:%d" % (self.host, self.port)
                if DEBUG:
                    print("connected using socket")
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # connect_timeout only applies to connecting; reads/writes use
            # their own timeouts.
            sock.settimeout(None)

        self._sock = sock
        self._rfile = sock.makefile("rb")
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute("SET sql_mode=%s", (self.sql_mode,))
            c.close()

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()
            self.commit()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)
    except BaseException as e:
        self._rfile = None
        if sock is not None:
            try:
                sock.close()
            except Exception:
                pass
        # self._sock may no longer be `sock`: TLS negotiation replaces it with
        # the object returned by self.ctx.wrap_socket(). Close and drop
        # whatever this attempt installed so it neither leaks nor makes
        # `open` report True for a dead connection.
        if self._sock is not previous_sock:
            failed_sock = self._sock
            self._sock = None
            if failed_sock is not None and failed_sock is not sock:
                try:
                    failed_sock.close()
                except Exception:
                    pass

        if isinstance(e, OSError):  # IOError and socket.error are aliases
            exc = err.OperationalError(
                2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e)
            )
            # Keep original exception and traceback to investigate error.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG:
                print(exc.traceback)
            raise exc

        # If e is neither DatabaseError or IOError, It's a bug.
        # But raising AssertionError hides original error.
        # So just reraise it.
        raise
def write_packet(self, payload):
    """Send *payload* as one MySQL packet: 3-byte length, sequence id, data.

    The sequence id is advanced (mod 256) after sending. When building
    packets by hand and calling _write_bytes() directly, self._next_seq_id
    must be maintained by the caller.
    """
    header = _pack_int24(len(payload)) + bytes((self._next_seq_id,))
    data = header + payload
    if DEBUG:
        dump_packet(data)
    self._write_bytes(data)
    self._next_seq_id = (self._next_seq_id + 1) % 256
def _read_packet(self, packet_type=MysqlPacket):
    """Read one complete MySQL packet and wrap it in *packet_type*.

    Payloads split across several wire packets (each chunk of exactly
    0xFFFFFF bytes means another chunk follows) are reassembled first.
    An error packet is raised via ``raise_for_error()``, after marking
    any active unbuffered result as inactive.

    :raise OperationalError: If the connection to the MySQL server is lost.
    :raise InternalError: If the packet sequence number is wrong.
    """
    payload = bytearray()
    while True:
        header = self._read_bytes(4)
        length_low, length_high, seq = struct.unpack("<HBB", header)
        chunk_len = length_low | (length_high << 16)
        if seq != self._next_seq_id:
            self._force_close()
            if seq == 0:
                # MariaDB sends error packet with seqno==0 when shutdown
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    "Lost connection to MySQL server during query",
                )
            raise err.InternalError(
                "Packet sequence number wrong - got %d expected %d"
                % (seq, self._next_seq_id)
            )
        self._next_seq_id = (self._next_seq_id + 1) % 256

        chunk = self._read_bytes(chunk_len)
        if DEBUG:
            dump_packet(chunk)
        payload += chunk
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        if chunk_len != 0xFFFFFF:
            break

    packet = packet_type(bytes(payload), self.encoding)
    if packet.is_error_packet():
        current = self._result
        if current is not None and current.unbuffered_active is True:
            current.unbuffered_active = False
        packet.raise_for_error()
    return packet
728 def _read_bytes(self, num_bytes):
729 self._sock.settimeout(self._read_timeout)
730 while True:
731 try:
732 data = self._rfile.read(num_bytes)
733 break
734 except (IOError, OSError) as e:
735 if e.errno == errno.EINTR:
736 continue
737 self._force_close()
738 raise err.OperationalError(
739 CR.CR_SERVER_LOST,
740 "Lost connection to MySQL server during query (%s)" % (e,),
741 )
742 except BaseException:
743 # Don't convert unknown exception to MySQLError.
744 self._force_close()
745 raise
746 if len(data) < num_bytes:
747 self._force_close()
748 raise err.OperationalError(
749 CR.CR_SERVER_LOST, "Lost connection to MySQL server during query"
750 )
751 return data
753 def _write_bytes(self, data):
754 self._sock.settimeout(self._write_timeout)
755 try:
756 self._sock.sendall(data)
757 except IOError as e:
758 self._force_close()
759 raise err.OperationalError(
760 CR.CR_SERVER_GONE_ERROR, "MySQL server has gone away (%r)" % (e,)
761 )
def _read_query_result(self, unbuffered=False):
    """Read the response to the command just sent and make it the current result.

    :param unbuffered: If true, only the result header is read now
        (``init_unbuffered_query()``); otherwise the whole result is read.
    :return: The affected-row count of the result.
    """
    self._result = None
    # Construct outside the try: if construction itself failed, the handler
    # below would hit an unbound `result` and mask the real error.
    result = MySQLResult(self)
    if unbuffered:
        try:
            result.init_unbuffered_query()
        except BaseException:
            # Detach the half-initialised result so it is not treated as an
            # active unbuffered stream, then propagate the original error.
            result.unbuffered_active = False
            result.connection = None
            raise
    else:
        result.read()
    self._result = result
    if result.server_status is not None:
        self.server_status = result.server_status
    return result.affected_rows
def insert_id(self):
    """Return the last insert id of the current result, or 0 if there is none."""
    current = self._result
    return current.insert_id if current else 0
def _execute_command(self, command, sql):
    """Send a client command packet whose argument is *sql*.

    The first packet carries the command byte. An argument too large for one
    packet continues in follow-up packets of at most MAX_PACKET_LEN bytes; a
    final chunk of exactly MAX_PACKET_LEN bytes is followed by an empty
    packet so the server can tell the payload has ended.

    :param command: A COMMAND.* code.
    :param sql: The argument, as bytes or as str (encoded with self.encoding).
    :raise InterfaceError: If the connection is closed.
    """
    if not self._sock:
        raise err.InterfaceError(0, "")

    # If the last query was unbuffered, make sure it finishes before sending
    # new commands. next_result() replaces self._result, so the loop has to
    # re-read the attribute on every iteration.
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn("Previous unbuffered result was left incomplete")
            self._result._finish_unbuffered_query()
        while self._result.has_next:
            self.next_result()
        self._result = None

    payload = sql.encode(self.encoding) if isinstance(sql, str) else sql

    first_size = min(MAX_PACKET_LEN, len(payload) + 1)  # +1 for the command byte
    # Hand-built first packet: "<i" packs the 3-byte length followed by a zero
    # sequence id (first_size never exceeds 24 bits), then the command byte.
    first_packet = struct.pack("<iB", first_size, command) + payload[: first_size - 1]
    self._write_bytes(first_packet)
    if DEBUG:
        dump_packet(first_packet)
    self._next_seq_id = 1

    if first_size < MAX_PACKET_LEN:
        return

    offset = first_size - 1
    while True:
        chunk = payload[offset : offset + MAX_PACKET_LEN]
        self.write_packet(chunk)
        offset += len(chunk)
        if len(chunk) < MAX_PACKET_LEN:
            break
def _request_authentication(self):
    """Send the HandshakeResponse and complete authentication.

    Upgrades the socket to TLS first when SSL was requested and the server
    advertises it, then answers any AuthSwitchRequest or extra-auth-data
    packet the server sends back.

    :raise ValueError: If no username was specified.
    :raise OperationalError: If the server sends extra auth data for a plugin
        with no follow-up handling here.
    """
    # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
    if int(self.server_version.split(".", 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError("Did not specify a username")

    charset_id = charset_by_name(self.charset).id
    if isinstance(self.user, str):
        self.user = self.user.encode(self.encoding)

    # capability flags, max packet size, charset id, 23 zero filler bytes
    data_init = struct.pack(
        "<iIB23s", self.client_flag, MAX_PACKET_LEN, charset_id, b""
    )

    if self.ssl and self.server_capabilities & CLIENT.SSL:
        # Send the short SSL request, then switch to TLS before credentials.
        self.write_packet(data_init)

        self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
        self._rfile = self._sock.makefile("rb")
        self._secure = True

    data = data_init + self.user + b"\0"

    authresp = b""
    plugin_name = None

    if self._auth_plugin_name == "":
        plugin_name = b""
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == "mysql_native_password":
        plugin_name = b"mysql_native_password"
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == "caching_sha2_password":
        plugin_name = b"caching_sha2_password"
        if self.password:
            if DEBUG:
                print("caching_sha2: trying fast path")
            authresp = _auth.scramble_caching_sha2(self.password, self.salt)
        else:
            if DEBUG:
                print("caching_sha2: empty password")
    elif self._auth_plugin_name == "sha256_password":
        plugin_name = b"sha256_password"
        if self.ssl and self.server_capabilities & CLIENT.SSL:
            authresp = self.password + b"\0"
        elif self.password:
            authresp = b"\1"  # request public key
        else:
            authresp = b"\0"  # empty password

    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += _lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack("B", len(authresp)) + authresp
    else:  # pragma: no cover - not testing against servers without secure auth (>=5.0)
        data += authresp + b"\0"

    if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        if isinstance(self.db, str):
            self.db = self.db.encode(self.encoding)
        data += self.db + b"\0"

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        data += (plugin_name or b"") + b"\0"

    if self.server_capabilities & CLIENT.CONNECT_ATTRS:
        # Keys, values and the total are length-encoded integers/strings. For
        # lengths below 251 this is the same single byte as before; larger
        # values (e.g. a long program_name) previously produced a malformed
        # packet or a struct.error.
        connect_attrs = b""
        for k, v in self._connect_attrs.items():
            k = k.encode("utf-8")
            connect_attrs += _lenenc_int(len(k)) + k
            v = v.encode("utf-8")
            connect_attrs += _lenenc_int(len(v)) + v
        data += _lenenc_int(len(connect_attrs)) + connect_attrs

    self.write_packet(data)
    auth_packet = self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        if DEBUG:
            print("received auth switch")
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8()  # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if (
            self.server_capabilities & CLIENT.PLUGIN_AUTH
            and plugin_name is not None
        ):
            auth_packet = self._process_auth(plugin_name, auth_packet)
        else:
            # send legacy handshake
            data = _auth.scramble_old_password(self.password, self.salt) + b"\0"
            self.write_packet(data)
            auth_packet = self._read_packet()
    elif auth_packet.is_extra_auth_data():
        if DEBUG:
            print("received extra data")
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        if self._auth_plugin_name == "caching_sha2_password":
            auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
        elif self._auth_plugin_name == "sha256_password":
            auth_packet = _auth.sha256_password_auth(self, auth_packet)
        else:
            # Interpolate the plugin name; previously the format string and
            # the name were passed as two separate exception arguments.
            raise err.OperationalError(
                "Received extra packet for auth method %r"
                % (self._auth_plugin_name,)
            )

    if DEBUG:
        print("Succeed to auth")
def _process_auth(self, plugin_name, auth_packet):
    """Continue authentication after the server sent an AuthSwitchRequest.

    A handler from ``auth_plugin_map`` for *plugin_name* takes precedence (its
    ``authenticate(auth_packet)`` result is returned directly); otherwise a
    built-in implementation is selected by plugin name.

    :param plugin_name: Plugin name read from the auth-switch packet; it is
        compared against bytes literals below (presumably bytes — confirm).
    :param auth_packet: The auth-switch packet. Its remaining data
        (``read_all()``) is used as the plugin's challenge/salt.
    :return: The last packet read from the server.
    :raise OperationalError: 2059 if the plugin is unknown, not configured, or
        its handler lacks the needed method; 2061 if a dialog handler's
        prompt() result cannot be sent.
    """
    handler = self._get_auth_plugin_handler(plugin_name)
    if handler:
        try:
            return handler.authenticate(auth_packet)
        except AttributeError:
            # No authenticate() method: acceptable only for "dialog", whose
            # handler may instead provide prompt() (used further below).
            if plugin_name != b"dialog":
                raise err.OperationalError(
                    2059,
                    "Authentication plugin '%s'"
                    " not loaded: - %r missing authenticate method"
                    % (plugin_name, type(handler)),
                )
    if plugin_name == b"caching_sha2_password":
        return _auth.caching_sha2_password_auth(self, auth_packet)
    elif plugin_name == b"sha256_password":
        return _auth.sha256_password_auth(self, auth_packet)
    elif plugin_name == b"mysql_native_password":
        data = _auth.scramble_native_password(self.password, auth_packet.read_all())
    elif plugin_name == b"client_ed25519":
        data = _auth.ed25519_password(self.password, auth_packet.read_all())
    elif plugin_name == b"mysql_old_password":
        data = (
            _auth.scramble_old_password(self.password, auth_packet.read_all())
            + b"\0"
        )
    elif plugin_name == b"mysql_clear_password":
        # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
        data = self.password + b"\0"
    elif plugin_name == b"dialog":
        # Question/answer loop: each server packet carries a flag byte and a
        # prompt; answer it and read the next packet until the server sends
        # an OK packet or flags the question as the last one.
        pkt = auth_packet
        while True:
            flag = pkt.read_uint8()
            # bits 1-2 == 0b01 -> passed to prompt() as echo=True
            echo = (flag & 0x06) == 0x02
            # bit 0 set -> stop after answering this question
            last = (flag & 0x01) == 0x01
            prompt = pkt.read_all()

            if prompt == b"Password: ":
                # The literal password prompt is answered without the handler.
                self.write_packet(self.password + b"\0")
            elif handler:
                resp = "no response - TypeError within plugin.prompt method"
                try:
                    resp = handler.prompt(echo, prompt)
                    # resp must be bytes; a str makes this raise TypeError.
                    self.write_packet(resp + b"\0")
                except AttributeError:
                    raise err.OperationalError(
                        2059,
                        "Authentication plugin '%s'"
                        " not loaded: - %r missing prompt method"
                        % (plugin_name, handler),
                    )
                except TypeError:
                    raise err.OperationalError(
                        2061,
                        "Authentication plugin '%s'"
                        " %r didn't respond with string. Returned '%r' to prompt %r"
                        % (plugin_name, handler, resp, prompt),
                    )
            else:
                raise err.OperationalError(
                    2059,
                    "Authentication plugin '%s' (%r) not configured"
                    % (plugin_name, handler),
                )
            pkt = self._read_packet()
            pkt.check_error()
            if pkt.is_ok_packet() or last:
                break
        return pkt
    else:
        raise err.OperationalError(
            2059, "Authentication plugin '%s' not configured" % plugin_name
        )

    # Single-response plugins: send the computed answer, return the reply.
    self.write_packet(data)
    pkt = self._read_packet()
    pkt.check_error()
    return pkt
1022 def _get_auth_plugin_handler(self, plugin_name):
1023 plugin_class = self._auth_plugin_map.get(plugin_name)
1024 if not plugin_class and isinstance(plugin_name, bytes):
1025 plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii"))
1026 if plugin_class:
1027 try:
1028 handler = plugin_class(self)
1029 except TypeError:
1030 raise err.OperationalError(
1031 2059,
1032 "Authentication plugin '%s'"
1033 " not loaded: - %r cannot be constructed with connection object"
1034 % (plugin_name, plugin_class),
1035 )
1036 else:
1037 handler = None
1038 return handler
# _mysql support (compatibility accessors)
def thread_id(self):
    """Return the connection (thread) id received in the server handshake.

    ``server_thread_id`` holds the 1-tuple produced by ``struct.unpack``;
    this returns its first element.
    """
    unpacked_ids = self.server_thread_id
    return unpacked_ids[0]
def character_set_name(self):
    """Return ``self.charset``, the connection's character set name."""
    charset_name = self.charset
    return charset_name
def get_host_info(self):
    """Return the value stored in ``self.host_info`` (_mysql-style accessor)."""
    info = self.host_info
    return info
def get_proto_info(self):
    """Return the protocol version byte read from the server handshake."""
    version = self.protocol_version
    return version
1053 def _get_server_information(self):
1054 i = 0
1055 packet = self._read_packet()
1056 data = packet.get_all_data()
1058 self.protocol_version = data[i]
1059 i += 1
1061 server_end = data.find(b"\0", i)
1062 self.server_version = data[i:server_end].decode("latin1")
1063 i = server_end + 1
1065 self.server_thread_id = struct.unpack("<I", data[i : i + 4])
1066 i += 4
1068 self.salt = data[i : i + 8]
1069 i += 9 # 8 + 1(filler)
1071 self.server_capabilities = struct.unpack("<H", data[i : i + 2])[0]
1072 i += 2
1074 if len(data) >= i + 6:
1075 lang, stat, cap_h, salt_len = struct.unpack("<BHHB", data[i : i + 6])
1076 i += 6
1077 # TODO: deprecate server_language and server_charset.
1078 # mysqlclient-python doesn't provide it.
1079 self.server_language = lang
1080 try:
1081 self.server_charset = charset_by_id(lang).name
1082 except KeyError:
1083 # unknown collation
1084 self.server_charset = None
1086 self.server_status = stat
1087 if DEBUG:
1088 print("server_status: %x" % stat)
1090 self.server_capabilities |= cap_h << 16
1091 if DEBUG:
1092 print("salt_len:", salt_len)
1093 salt_len = max(12, salt_len - 9)
1095 # reserved
1096 i += 10
1098 if len(data) >= i + salt_len:
1099 # salt_len includes auth_plugin_data_part_1 and filler
1100 self.salt += data[i : i + salt_len]
1101 i += salt_len
1103 i += 1
1104 # AUTH PLUGIN NAME may appear here.
1105 if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
1106 # Due to Bug#59453 the auth-plugin-name is missing the terminating
1107 # NUL-char in versions prior to 5.5.10 and 5.6.2.
1108 # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
1109 # didn't use version checks as mariadb is corrected and reports
1110 # earlier than those two.
1111 server_end = data.find(b"\0", i)
1112 if server_end < 0: # pragma: no cover - very specific upstream bug
1113 # not found \0 and last field so take it all
1114 self._auth_plugin_name = data[i:].decode("utf-8")
1115 else:
1116 self._auth_plugin_name = data[i:server_end].decode("utf-8")
def get_server_info(self):
    """Return the server version string decoded from the handshake."""
    version_string = self.server_version
    return version_string
# DB-API 2.0 (PEP 249) optional extension: the exception classes from the
# ``err`` module are re-exposed as attributes of the enclosing connection
# class, so code holding a connection can write e.g. ``conn.OperationalError``.
Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError
DatabaseError = err.DatabaseError
DataError = err.DataError
OperationalError = err.OperationalError
IntegrityError = err.IntegrityError
InternalError = err.InternalError
ProgrammingError = err.ProgrammingError
NotSupportedError = err.NotSupportedError
class MySQLResult:
    """One query result read from a connection's packet stream.

    The result is filled in one of two ways:

    - Buffered: ``read()`` reads everything at once.
    - Unbuffered: ``init_unbuffered_query()`` reads the header and column
      definitions, then each ``_read_rowdata_packet_unbuffered()`` call
      reads one row.

    ``connection`` is reset to ``None`` once the result is complete. This
    releases the reference cycle between the connection and its result.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        # True only while an unbuffered result set still has unread rows.
        self.unbuffered_active = False

    def __del__(self):
        # Drain the unread rows of an abandoned unbuffered result, so the
        # connection's packet stream ends up past this result's EOF packet.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read the complete result (buffered mode).

        The first packet selects the path: an OK packet (no result set), a
        LOAD DATA LOCAL request, or a result-set header followed by column
        definitions and rows. ``connection`` is dropped afterwards, even
        on error.
        """
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """Read the result header for unbuffered (row-by-row) fetching.

        An OK or LOAD DATA LOCAL response completes the result immediately.
        Otherwise the column definitions are read, ``unbuffered_active`` is
        set, and the rows are left on the wire for
        ``_read_rowdata_packet_unbuffered()``.

        ``unbuffered_active`` is set only after the header has been read
        successfully. If reading fails, ``__del__`` therefore does not try
        to drain rows from a stream in an unknown state.

        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError:
        """
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            try:
                self._read_ok_packet(first_packet)
            finally:
                self.connection = None
        elif first_packet.is_load_local_packet():
            try:
                self._read_load_local_packet(first_packet)
            finally:
                self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615
            # Rows remain to be read: only now is the result "active".
            self.unbuffered_active = True

    def _read_ok_packet(self, first_packet):
        """Copy the status fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """Answer a LOAD DATA LOCAL INFILE request by sending a local file.

        :raise RuntimeError: If the connection's local_infile option is off.
        :raise OperationalError: (2014) If the server's final reply is not OK.
        """
        # SECURITY NOTE: the file name comes from the server's packet
        # (load_packet.filename), so the server chooses which local file is
        # read. The request is refused unless local_infile is enabled.
        if not self.connection._local_infile:
            raise RuntimeError(
                "**WARN**: Received LOAD_LOCAL packet but local_infile option is false."
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except BaseException:
            # The server still sends a reply. Consume it so the packet
            # stream stays in sync, then propagate the original error.
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (and record warning_count/has_next) if *packet* is EOF."""
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read a full result set whose header packet is *first_packet*."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Read the next row of an active unbuffered result.

        :return: the row tuple, or None if the result is not active or the
            EOF packet was reached (which also ends the result).
        """
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows must be a tuple of rows for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        """Discard the remaining rows of an unbuffered result up to EOF."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row packet into a tuple.

        Each column is decoded using the ``(encoding, converter)`` pairs
        in ``self.converters``. NULL columns (``None``) are left as-is.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if DEBUG:
                    print("DEBUG: DATA = ", data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result.

        Fills ``fields``, ``converters`` and ``description``, then consumes
        the EOF packet that ends the column definitions.

        :raise AssertionError: If that trailing packet is not EOF.
        """
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for _ in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = "ascii"
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f"DEBUG: field={field}, converter={converter}")
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        # Explicit check instead of `assert`, which `python -O` strips;
        # AssertionError is kept as the exception type for compatibility.
        if not eof_packet.is_eof_packet():
            raise AssertionError("Protocol error, expecting EOF")
        self.description = tuple(description)
class LoadLocalFile:
    """Sends a local file to the server for LOAD DATA LOCAL INFILE."""

    def __init__(self, filename, connection):
        # NOTE: in this file, filename is taken from the server's LOAD_LOCAL
        # request (MySQLResult._read_load_local_packet), so the server
        # chooses which local file is read.
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server.

        The file is sent in chunks of at most 16 KiB (or
        ``max_allowed_packet`` if that is smaller). A final empty packet is
        always written to mark the end of the data, even when sending fails.

        :raise InterfaceError: If the connection has no socket.
        :raise OperationalError: (1017) On any IOError while opening,
            reading or sending the file.
        """
        if not self.connection._sock:
            raise err.InterfaceError(0, "")
        conn = self.connection

        try:
            with open(self.filename, "rb") as source:
                # 16KB is efficient enough
                chunk_size = min(conn.max_allowed_packet, 16 * 1024)
                # Call source.read(chunk_size) until it returns b"" (EOF).
                for chunk in iter(lambda: source.read(chunk_size), b""):
                    conn.write_packet(chunk)
        except IOError:
            raise err.OperationalError(1017, f"Can't find file '{self.filename}'")
        finally:
            # An empty packet tells the server the file data is complete.
            conn.write_packet(b"")