Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/cardinal_pythonlib/sqlalchemy/schema.py : 21%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# cardinal_pythonlib/sqlalchemy/schema.py
4"""
5===============================================================================
7 Original code copyright (C) 2009-2021 Rudolf Cardinal (rudolf@pobox.com).
9 This file is part of cardinal_pythonlib.
11 Licensed under the Apache License, Version 2.0 (the "License");
12 you may not use this file except in compliance with the License.
13 You may obtain a copy of the License at
15 https://www.apache.org/licenses/LICENSE-2.0
17 Unless required by applicable law or agreed to in writing, software
18 distributed under the License is distributed on an "AS IS" BASIS,
19 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 See the License for the specific language governing permissions and
21 limitations under the License.
23===============================================================================
25**Functions to work with SQLAlchemy schemas (schemata) directly, via SQLAlchemy
26Core.**
28"""
30import ast
31import contextlib
32import copy
33import csv
34from functools import lru_cache
35import io
36import re
37from typing import Any, Dict, Generator, List, Optional, Type, Union
39from sqlalchemy.dialects import mssql, mysql
40# noinspection PyProtectedMember
41from sqlalchemy.engine import Connection, Engine, ResultProxy
42from sqlalchemy.engine.interfaces import Dialect
43from sqlalchemy.engine.reflection import Inspector
44from sqlalchemy.dialects.mssql.base import TIMESTAMP as MSSQL_TIMESTAMP
45from sqlalchemy.schema import (Column, CreateColumn, DDL, MetaData, Index,
46 Sequence, Table)
47from sqlalchemy.sql import sqltypes, text
48from sqlalchemy.sql.sqltypes import BigInteger, TypeEngine
49from sqlalchemy.sql.visitors import VisitableType
51from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler
52from cardinal_pythonlib.sqlalchemy.dialect import (
53 quote_identifier,
54 SqlaDialectName,
55)
log = get_brace_style_log_with_null_handler(__name__)

# =============================================================================
# Constants
# =============================================================================

# Conventional default schema names, used when a caller gives no schema.
MSSQL_DEFAULT_SCHEMA = 'dbo'  # Microsoft SQL Server
POSTGRES_DEFAULT_SCHEMA = 'public'  # PostgreSQL
67# =============================================================================
68# Inspect tables (SQLAlchemy Core)
69# =============================================================================
def get_table_names(engine: Engine) -> List[str]:
    """
    Reflects the database behind the :class:`Engine` and returns the names of
    its tables.
    """
    return Inspector.from_engine(engine).get_table_names()
def get_view_names(engine: Engine) -> List[str]:
    """
    Reflects the database behind the :class:`Engine` and returns the names of
    its views.
    """
    return Inspector.from_engine(engine).get_view_names()
def table_exists(engine: Engine, tablename: str) -> bool:
    """
    Is there a table called ``tablename`` in the database? (Views are not
    considered.)
    """
    known_tables = get_table_names(engine)
    return tablename in known_tables
def view_exists(engine: Engine, viewname: str) -> bool:
    """
    Is there a view called ``viewname`` in the database? (Tables are not
    considered.)
    """
    known_views = get_view_names(engine)
    return viewname in known_views
def table_or_view_exists(engine: Engine, table_or_view_name: str) -> bool:
    """
    Does the named table/view exist (either as a table or as a view) in the
    database?

    Tables are checked first; the view list is only reflected if no table of
    that name is found, saving a database round trip in the common case.
    """
    return (table_exists(engine, table_or_view_name) or
            view_exists(engine, table_or_view_name))
class SqlaColumnInspectionInfo(object):
    """
    Information about one database column, as obtained by reflection.

    Wraps the plain ``dict`` that SQLAlchemy's inspector produces, exposing
    its entries as named attributes.
    """
    def __init__(self, sqla_info_dict: Dict[str, Any]) -> None:
        """
        Args:
            sqla_info_dict:
                one column dictionary from SQLAlchemy reflection; see

                - https://docs.sqlalchemy.org/en/latest/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_columns
                - https://bitbucket.org/zzzeek/sqlalchemy/issues/4051/sqlalchemyenginereflectioninspectorget_col

        Raises:
            KeyError: if any of ``name``, ``type``, ``nullable`` or
                ``default`` is missing.
        """  # noqa
        info = sqla_info_dict
        # Required entries (looked up strictly, in this order):
        self.name = info['name']  # type: str
        self.type = info['type']  # type: TypeEngine
        self.nullable = info['nullable']  # type: bool
        self.default = info['default']  # type: str  # SQL string expression # noqa
        # Optional entries, with fallbacks when absent:
        self.attrs = info.get('attrs', {})  # type: Dict[str, Any]
        self.comment = info.get('comment', '')
        # NOTE(review): 'comment' is presumably not supplied by every
        # dialect/SQLAlchemy version, hence .get(); confirm.
def gen_columns_info(engine: Engine,
                     tablename: str) -> Generator[SqlaColumnInspectionInfo,
                                                  None, None]:
    """
    Reflects the specified table and yields one
    :class:`SqlaColumnInspectionInfo` per column, in the order the inspector
    reports them.
    """
    # Dictionary structure: see
    # http://docs.sqlalchemy.org/en/latest/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_columns  # noqa
    inspector = Inspector.from_engine(engine)
    yield from map(SqlaColumnInspectionInfo,
                   inspector.get_columns(tablename))
def get_column_info(engine: Engine, tablename: str,
                    columnname: str) -> Optional[SqlaColumnInspectionInfo]:
    """
    Returns the :class:`SqlaColumnInspectionInfo` for the first column of
    ``tablename`` whose name equals ``columnname``, or ``None`` if there is
    no such column.
    """
    matches = (
        colinfo
        for colinfo in gen_columns_info(engine, tablename)
        if colinfo.name == columnname
    )
    return next(matches, None)
def get_column_type(engine: Engine, tablename: str,
                    columnname: str) -> Optional[TypeEngine]:
    """
    Returns the reflected SQLAlchemy type instance of the named column in the
    named table, or ``None`` if the column isn't found.

    For more on :class:`TypeEngine`, see
    :func:`cardinal_pythonlib.orm_inspect.coltype_as_typeengine`.
    """
    colinfo = get_column_info(engine, tablename, columnname)
    return None if colinfo is None else colinfo.type
def get_column_names(engine: Engine, tablename: str) -> List[str]:
    """
    Returns the names of every column in the specified table, in reflection
    order.
    """
    names = []  # type: List[str]
    for colinfo in gen_columns_info(engine, tablename):
        names.append(colinfo.name)
    return names
186# =============================================================================
187# More introspection
188# =============================================================================
def get_pk_colnames(table_: Table) -> List[str]:
    """
    Returns the database column name(s) making up the table's primary key,
    in the order the columns appear in the table.

    Returns:
        a list of column names; an empty list (not ``None``) if the table
        has no primary key
    """
    return [col.name for col in table_.columns if col.primary_key]
def get_single_int_pk_colname(table_: Table) -> Optional[str]:
    """
    Returns the database column name of the table's primary key if that PK
    consists of exactly one column and that column is of an integer type;
    otherwise ``None``.

    A table may legitimately have a composite primary key plus a separate
    ``IDENTITY`` (``AUTOINCREMENT``) integer field; this function does not
    detect such columns.
    """
    pk_cols = [col for col in table_.columns if col.primary_key]
    int_pk_cols = [col for col in pk_cols if is_sqlatype_integer(col.type)]
    # Composite PKs, and single non-integer PKs, both disqualify the table.
    if len(pk_cols) == 1 and len(int_pk_cols) == 1:
        return int_pk_cols[0].name
    return None
def get_single_int_autoincrement_colname(table_: Table) -> Optional[str]:
    """
    If a table has a single integer ``AUTOINCREMENT`` column, this will
    return its name; otherwise, ``None``.

    - It's unlikely that a database has >1 ``AUTOINCREMENT`` field anyway, but
      we should check (a warning is logged if so).
    - SQL Server's ``IDENTITY`` keyword is equivalent to MySQL's
      ``AUTOINCREMENT``.
    - SQLAlchemy (1.1+) defaults :attr:`Column.autoincrement` to the string
      ``"auto"``. That value is truthy, but it does not mark every column as
      autoincrement; it means "autoincrement if this is the table's single
      integer primary key". Such columns are therefore counted here only when
      they are the table's sole primary-key column. Other truthy values (e.g.
      ``True``, ``"ignore_fk"``) always count.
    - Verify against SQL Server:

      .. code-block:: sql

        SELECT table_name, column_name
        FROM information_schema.columns
        WHERE COLUMNPROPERTY(OBJECT_ID(table_schema + '.' + table_name),
                             column_name,
                             'IsIdentity') = 1
        ORDER BY table_name;

      ... https://stackoverflow.com/questions/87747

    - Also:

      .. code-block:: sql

        sp_columns 'tablename';

      ... which is what SQLAlchemy does (``dialects/mssql/base.py``, in
      :func:`get_columns`).
    """
    n_pk_cols = sum(1 for col in table_.columns if col.primary_key)

    def _counts_as_autoincrement(col: Column) -> bool:
        # Decide whether one column should be treated as AUTOINCREMENT.
        flag = col.autoincrement
        if isinstance(flag, str) and flag == "auto":
            # SQLAlchemy's default; only implies autoincrement for a sole PK
            # column. (SQLAlchemy also excludes FK-dependent PKs; that
            # refinement is not replicated here.)
            return bool(col.primary_key) and n_pk_cols == 1
        return bool(flag)

    n_autoinc = 0
    int_autoinc_names = []  # type: List[str]
    for col in table_.columns:
        if _counts_as_autoincrement(col):
            n_autoinc += 1
            if is_sqlatype_integer(col.type):
                int_autoinc_names.append(col.name)
    if n_autoinc > 1:
        log.warning("Table {!r} has {} autoincrement columns",
                    table_.name, n_autoinc)
    if n_autoinc == 1 and len(int_autoinc_names) == 1:
        return int_autoinc_names[0]
    return None
def get_effective_int_pk_col(table_: Table) -> Optional[str]:
    """
    Returns the name of the column that effectively acts as the table's
    integer key: the single integer primary key if there is one, otherwise the
    single integer ``AUTOINCREMENT`` column if there is one, otherwise
    ``None``.
    """
    pk_colname = get_single_int_pk_colname(table_)
    if pk_colname:
        return pk_colname
    # Fall back to an autoincrement column; map any falsy result to None.
    return get_single_int_autoincrement_colname(table_) or None
281# =============================================================================
282# Indexes
283# =============================================================================
def index_exists(engine: Engine, tablename: str, indexname: str) -> bool:
    """
    Returns ``True`` if the inspector reports an index called ``indexname``
    on the table ``tablename``.
    """
    indexes = Inspector.from_engine(engine).get_indexes(tablename)
    for index_info in indexes:
        if index_info['name'] == indexname:
            return True
    return False
def mssql_get_pk_index_name(engine: Engine,
                            tablename: str,
                            schemaname: str = MSSQL_DEFAULT_SCHEMA) -> str:
    """
    Microsoft SQL Server only. Looks up the name of the primary-key
    constraint (index) of ``schemaname.tablename``.

    Returns:
        the PK index name, or ``''`` if the query returns no row
    """
    # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Connection.execute  # noqa
    # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text  # noqa
    # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.TextClause.bindparams  # noqa
    # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.ResultProxy  # noqa
    sql = """
SELECT
    kc.name AS index_name
FROM
    sys.key_constraints AS kc
    INNER JOIN sys.tables AS ta ON ta.object_id = kc.parent_object_id
    INNER JOIN sys.schemas AS s ON ta.schema_id = s.schema_id
WHERE
    kc.[type] = 'PK'
    AND ta.name = :tablename
    AND s.name = :schemaname
    """
    query = text(sql).bindparams(tablename=tablename, schemaname=schemaname)
    with contextlib.closing(
            engine.execute(query)) as result:  # type: ResultProxy  # noqa
        first_row = result.fetchone()
    if not first_row:
        return ''
    return first_row[0]
def mssql_table_has_ft_index(engine: Engine,
                             tablename: str,
                             schemaname: str = MSSQL_DEFAULT_SCHEMA) -> bool:
    """
    For Microsoft SQL Server specifically: does the specified table (in the
    specified schema) have at least one full-text index?

    The query consults ``sys.fulltext_indexes`` directly. It deliberately
    does NOT join to ``sys.key_constraints``. A full-text index's key index
    may be a plain unique index that has no key constraint, and joining on
    constraints would then report "no full-text index" when one exists.
    """
    query = text("""
SELECT
    COUNT(*)
FROM
    sys.fulltext_indexes AS fi
    INNER JOIN sys.tables AS ta ON fi.object_id = ta.object_id
    INNER JOIN sys.schemas AS s ON ta.schema_id = s.schema_id
WHERE
    ta.name = :tablename
    AND s.name = :schemaname
    """).bindparams(
        tablename=tablename,
        schemaname=schemaname,
    )
    with contextlib.closing(
            engine.execute(query)) as result:  # type: ResultProxy  # noqa
        row = result.fetchone()
    # COUNT(*) always yields one row; guard anyway rather than crash on None.
    return row is not None and row[0] > 0
def mssql_transaction_count(
        engine_or_conn: Union[Connection, Engine]) -> Optional[int]:
    """
    For Microsoft SQL Server specifically: fetch the value of the ``TRANCOUNT``
    variable (see e.g.
    https://docs.microsoft.com/en-us/sql/t-sql/functions/trancount-transact-sql?view=sql-server-2017).

    Returns:
        the transaction count, or ``None`` if the query returns no row
        (unlikely?)
    """  # noqa
    sql = "SELECT @@TRANCOUNT"
    with contextlib.closing(
            engine_or_conn.execute(sql)) as result:  # type: ResultProxy  # noqa
        row = result.fetchone()
    return row[0] if row else None
def add_index(engine: Engine,
              sqla_column: Column = None,
              multiple_sqla_columns: List[Column] = None,
              unique: bool = False,
              fulltext: bool = False,
              length: int = None) -> None:
    """
    Adds an index to a database column (or, in restricted circumstances,
    several columns).

    The table name is worked out from the :class:`Column` object.

    Args:
        engine: SQLAlchemy :class:`Engine` object
        sqla_column: single column to index
        multiple_sqla_columns: multiple columns to index (see below)
        unique: make a ``UNIQUE`` index? (Only used for non-full-text
            indexes.)
        fulltext: make a ``FULLTEXT`` index?
        length: index length to use (default ``None``); passed as
            ``mysql_length`` for non-full-text indexes

    Restrictions:

    - Specify exactly one of ``sqla_column`` or ``multiple_sqla_columns``.
    - The normal method is ``sqla_column``.
    - ``multiple_sqla_columns`` is only used for Microsoft SQL Server full-text
      indexing (as this database permits only one full-text index per table,
      though that index can be on multiple columns).

    Behaviour:

    - Nothing is created if a named index of the same name already exists,
      or (SQL Server) if the table already has a full-text index.
    - For full-text indexes on dialects other than MySQL and SQL Server, an
      error is logged and nothing is created.

    Raises:
        ValueError: if the column arguments break the restrictions above, if
            ``multiple_sqla_columns`` span more than one table, or if a SQL
            Server full-text index is requested but the table's PK index name
            can't be found.
    """
    # We used to process a table as a unit; this makes index creation faster
    # (using ALTER TABLE).
    # http://dev.mysql.com/doc/innodb/1.1/en/innodb-create-index-examples.html  # noqa
    # ... ignored in transition to SQLAlchemy

    def quote(identifier: str) -> str:
        # Quote an identifier as appropriate for this engine's dialect.
        return quote_identifier(identifier, engine)

    is_mssql = engine.dialect.name == SqlaDialectName.MSSQL
    is_mysql = engine.dialect.name == SqlaDialectName.MYSQL

    multiple_sqla_columns = multiple_sqla_columns or []  # type: List[Column]
    if multiple_sqla_columns and not (fulltext and is_mssql):
        raise ValueError("add_index: Use multiple_sqla_columns only for mssql "
                         "(Microsoft SQL Server) full-text indexing")
    # This is true both when BOTH arguments are given and when NEITHER is.
    if bool(multiple_sqla_columns) == (sqla_column is not None):
        raise ValueError(
            f"add_index: Use exactly one of sqla_column or "
            f"multiple_sqla_columns (sqla_column = {sqla_column!r}, "
            f"multiple_sqla_columns = {multiple_sqla_columns!r})"
        )
    if sqla_column is not None:
        colnames = [sqla_column.name]
        sqla_table = sqla_column.table
        tablename = sqla_table.name
    else:
        colnames = [c.name for c in multiple_sqla_columns]
        sqla_table = multiple_sqla_columns[0].table
        tablename = sqla_table.name
        if any(c.table.name != tablename for c in multiple_sqla_columns[1:]):
            raise ValueError(
                f"add_index: tablenames are inconsistent in "
                f"multiple_sqla_columns = {multiple_sqla_columns!r}")

    if fulltext:
        if is_mssql:
            idxname = ''  # they are unnamed
        else:
            idxname = "_idxft_{}".format("_".join(colnames))
    else:
        idxname = "_idx_{}".format("_".join(colnames))
    if idxname and index_exists(engine, tablename, idxname):
        # Skip: creation would crash if the index were added again.
        log.info("Skipping creation of index {} on table {}; already exists",
                 idxname, tablename)
        return
    log.info(
        "Creating{ft} index {i} on table {t}, column(s) {c}",
        ft=" full-text" if fulltext else "",
        i=idxname or "<unnamed>",
        t=tablename,
        c=", ".join(colnames),
    )

    if fulltext:
        if is_mysql:
            log.info('OK to ignore this warning, if it follows next: '
                     '"InnoDB rebuilding table to add column FTS_DOC_ID"')
            # https://dev.mysql.com/doc/refman/5.6/en/innodb-fulltext-index.html
            sql = (
                "ALTER TABLE {tablename} "
                "ADD FULLTEXT INDEX {idxname} ({colnames})".format(
                    tablename=quote(tablename),
                    idxname=quote(idxname),
                    colnames=", ".join(quote(c) for c in colnames),
                )
            )
            # DDL(sql, bind=engine).execute_if(dialect=SqlaDialectName.MYSQL)
            DDL(sql, bind=engine).execute()

        elif is_mssql:  # Microsoft SQL Server
            # https://msdn.microsoft.com/library/ms187317(SQL.130).aspx
            # Argh! Complex.
            # Note that the database must also have had a
            #   CREATE FULLTEXT CATALOG somename AS DEFAULT;
            # statement executed on it beforehand.
            schemaname = engine.schema_for_object(
                sqla_table) or MSSQL_DEFAULT_SCHEMA  # noqa
            if mssql_table_has_ft_index(engine=engine,
                                        tablename=tablename,
                                        schemaname=schemaname):
                log.info(
                    "... skipping creation of full-text index on table "
                    "{}; a full-text index already exists for that "
                    "table; you can have only one full-text index per table, "
                    "though it can be on multiple columns", tablename)
                return
            pk_index_name = mssql_get_pk_index_name(
                engine=engine, tablename=tablename, schemaname=schemaname)
            if not pk_index_name:
                raise ValueError(
                    f"To make a FULLTEXT index under SQL Server, we need to "
                    f"know the name of the PK index, but couldn't find one "
                    f"from mssql_get_pk_index_name() for table "
                    f"{tablename!r}")
            # We don't name the FULLTEXT index itself, but it has to relate
            # to an existing unique index.
            sql = (
                "CREATE FULLTEXT INDEX ON {tablename} ({colnames}) "
                "KEY INDEX {keyidxname} ".format(
                    tablename=quote(tablename),
                    keyidxname=quote(pk_index_name),
                    colnames=", ".join(quote(c) for c in colnames),
                )
            )
            # SQL Server won't let you do this inside a transaction:
            # "CREATE FULLTEXT INDEX statement cannot be used inside a user
            # transaction."
            # https://msdn.microsoft.com/nl-nl/library/ms191544(v=sql.105).aspx
            # So let's ensure any preceding transactions are completed, and
            # run the SQL in a raw way:
            # engine.execute(sql).execution_options(autocommit=False)
            # http://docs.sqlalchemy.org/en/latest/core/connections.html#understanding-autocommit
            #
            # ... lots of faff with this (see test code in no_transactions.py)
            # ... ended up using explicit "autocommit=True" parameter (for
            #     pyodbc); see create_indexes()
            transaction_count = mssql_transaction_count(engine)
            if transaction_count != 0:
                log.critical("SQL Server transaction count (should be 0): {}",
                             transaction_count)
                # Executing serial COMMITs or a ROLLBACK won't help here if
                # this transaction is due to Python DBAPI default behaviour.
            DDL(sql, bind=engine).execute()

            # The reversal procedure is DROP FULLTEXT INDEX ON tablename;

        else:
            log.error("Don't know how to make full text index on dialect {}",
                      engine.dialect.name)

    else:
        index = Index(idxname, sqla_column, unique=unique, mysql_length=length)
        index.create(engine)
    # Index creation doesn't require a commit.
533# =============================================================================
534# More DDL
535# =============================================================================
def make_bigint_autoincrement_column(column_name: str,
                                     dialect: Dialect) -> Column:
    """
    Builds a :class:`Column` of type :class:`BigInteger` that autoincrements
    under the given :class:`Dialect`.

    Only SQL Server is supported. There, a :class:`Sequence` on the column
    makes the DDL compile to ``IDENTITY(1,1)``.

    Raises:
        AssertionError: for any other dialect
    """
    # noinspection PyUnresolvedReferences
    if dialect.name != SqlaDialectName.MSSQL:
        # Column(column_name, BigInteger, autoincrement=True) is not enough
        # for a non-PK column; see https://stackoverflow.com/questions/2937229
        # noinspection PyUnresolvedReferences
        raise AssertionError(
            f"SQLAlchemy doesn't support non-PK autoincrement fields yet for "
            f"dialect {dialect.name!r}")
    identity_sequence = Sequence('dummy_name', start=1, increment=1)
    return Column(column_name, BigInteger, identity_sequence)
def column_creation_ddl(sqla_column: Column, dialect: Dialect) -> str:
    """
    Compiles the column-definition DDL for ``sqla_column`` in the given
    dialect and returns it as a string.

    The column must already belong to a :class:`Table`; the SQL Server
    dialect, for one, refuses to generate DDL otherwise.

    Manual testing:

    .. code-block:: python

        from sqlalchemy.schema import Column, CreateColumn, MetaData, Sequence, Table
        from sqlalchemy.sql.sqltypes import BigInteger
        from sqlalchemy.dialects.mssql.base import MSDialect
        dialect = MSDialect()
        col1 = Column('hello', BigInteger, nullable=True)
        col2 = Column('world', BigInteger, autoincrement=True)  # does NOT generate IDENTITY
        col3 = Column('you', BigInteger, Sequence('dummy_name', start=1, increment=1))
        metadata = MetaData()
        t = Table('mytable', metadata)
        t.append_column(col1)
        t.append_column(col2)
        t.append_column(col3)
        print(str(CreateColumn(col1).compile(dialect=dialect)))  # hello BIGINT NULL
        print(str(CreateColumn(col2).compile(dialect=dialect)))  # world BIGINT NULL
        print(str(CreateColumn(col3).compile(dialect=dialect)))  # you BIGINT NOT NULL IDENTITY(1,1)

    Without the column being appended to a Table, DDL generation fails with:

    .. code-block:: none

        sqlalchemy.exc.CompileError: mssql requires Table-bound columns in order to generate DDL
    """  # noqa
    compiled = CreateColumn(sqla_column).compile(dialect=dialect)
    return str(compiled)
# noinspection PyUnresolvedReferences
def giant_text_sqltype(dialect: Dialect) -> str:
    """
    Returns the SQL column type used to make very large text columns for a
    given dialect.

    Args:
        dialect: a SQLAlchemy :class:`Dialect`

    Returns:
        the SQL data type of "giant text": 'NVARCHAR(MAX)' for SQL Server,
        'LONGTEXT' for MySQL.

    Raises:
        ValueError: for any other dialect
    """
    # Compare against SqlaDialectName.MSSQL, as the rest of this module does.
    if dialect.name == SqlaDialectName.MSSQL:
        return 'NVARCHAR(MAX)'
    if dialect.name == SqlaDialectName.MYSQL:
        return 'LONGTEXT'
    raise ValueError(f"Unknown dialect: {dialect.name}")
613# =============================================================================
614# SQLAlchemy column types
615# =============================================================================
617# -----------------------------------------------------------------------------
618# Reverse a textual SQL column type to an SQLAlchemy column type
619# -----------------------------------------------------------------------------
# Patterns for parsing textual SQL column type names. See also
# http://www.w3schools.com/sql/sql_create_table.asp

# e.g. "ENUM('a', 'b')"; captures the quoted value list as "valuelist".
RE_MYSQL_ENUM_COLTYPE = re.compile(r'ENUM\((?P<valuelist>.+)\)')
# e.g. "VARCHAR(32) COLLATE xyz"; captures the part before COLLATE.
RE_COLTYPE_WITH_COLLATE = re.compile(r'(?P<maintype>.+) COLLATE .+')
# e.g. "VARCHAR(10)"
RE_COLTYPE_WITH_ONE_PARAM = re.compile(r'(?P<type>\w+)\((?P<size>\w+)\)')
# e.g. "DECIMAL(10, 2)"
RE_COLTYPE_WITH_TWO_PARAMS = re.compile(
    r'(?P<type>\w+)\((?P<size>\w+),\s*(?P<dp>\w+)\)')
def _get_sqla_coltype_class_from_str(coltype: str,
                                     dialect: Dialect) -> Type[TypeEngine]:
    """
    Looks up the SQLAlchemy type class for a SQL type name in the dialect's
    ``ischema_names`` mapping.

    The upper-case spelling is tried first, then the lower-case one; dialects
    differ in which they use (e.g. SQLite upper, MySQL lower).

    Raises:
        KeyError: if neither spelling is present
    """
    # noinspection PyUnresolvedReferences
    name_to_class = dialect.ischema_names
    upper_name = coltype.upper()
    if upper_name in name_to_class:
        return name_to_class[upper_name]
    return name_to_class[coltype.lower()]
def get_list_of_sql_string_literals_from_quoted_csv(x: str) -> List[str]:
    """
    Used to extract SQL column type parameters. For example, MySQL has column
    types that look like ``ENUM('a', 'b', 'c', 'd')``. This function takes the
    ``"'a', 'b', 'c', 'd'"`` and converts it to ``['a', 'b', 'c', 'd']``.

    SQL-style doubled quotes inside a literal (``'it''s'``) are unescaped by
    the CSV reader's ``doublequote`` handling.

    Returns:
        the list of literal values from the first CSV record; an empty list
        if ``x`` contains no record (e.g. is empty)
    """
    reader = csv.reader(io.StringIO(x), delimiter=',', quotechar="'",
                        quoting=csv.QUOTE_ALL, skipinitialspace=True)
    # Only the first record matters (there should only be one).
    return list(next(reader, []))
@lru_cache(maxsize=None)
def get_sqla_coltype_from_dialect_str(coltype: str,
                                      dialect: Dialect) -> TypeEngine:
    """
    Returns an SQLAlchemy column type, given a column type name (a string) and
    an SQLAlchemy dialect. For example, this might convert the string
    ``VARCHAR(32)`` to an SQLAlchemy ``VARCHAR(length=32)``.

    NOTE that the reverse operation is performed by ``str(coltype)`` or
    ``coltype.compile()`` or ``coltype.compile(dialect)``; see
    :class:`TypeEngine`.

    Results are memoized with :func:`functools.lru_cache`, so ``dialect``
    must be hashable.

    Args:
        coltype: a ``str()`` representation, e.g. from ``str(c['type'])`` where
            ``c`` is an instance of :class:`sqlalchemy.sql.schema.Column`.
        dialect: a SQLAlchemy :class:`Dialect` class

    Returns:
        a Python object that is a subclass of
        :class:`sqlalchemy.types.TypeEngine`

    Raises:
        ValueError: if the conversion fails for any reason; the underlying
            exception is chained as ``__cause__``.

    Example:

    .. code-block:: python

        get_sqla_coltype_from_dialect_str('INTEGER(11)', engine.dialect)
        # gives an INTEGER type instance; the "(11)" is dropped if the type
        # class does not accept it (see Caveats)

    Notes:

    - :class:`sqlalchemy.engine.default.DefaultDialect` is the dialect base
      class

    - a dialect contains these things of interest:

      - ``ischema_names``: string-to-class dictionary
      - ``type_compiler``: instance of e.g.
        :class:`sqlalchemy.sql.compiler.GenericTypeCompiler`. This has a
        ``process()`` method, but that operates on :class:`TypeEngine` objects.
      - ``get_columns``: takes a table name, inspects the database

    - example of the dangers of ``eval``:
      http://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html
      (hence ``ast.literal_eval`` below)

    - An example of a function doing the reflection/inspection within
      SQLAlchemy is
      :func:`sqlalchemy.dialects.mssql.base.MSDialect.get_columns`,
      which has this lookup: ``coltype = self.ischema_names.get(type, None)``

    Caveats:

    - the parameters, e.g. ``DATETIME(6)``, do NOT necessarily either work at
      all or work correctly. For example, SQLAlchemy will happily spit out
      ``'INTEGER(11)'`` but its :class:`sqlalchemy.sql.sqltypes.INTEGER` class
      takes no parameters, so you get the error ``TypeError: object() takes no
      parameters``. Similarly, MySQL's ``DATETIME(6)`` uses the 6 to refer to
      precision, but the ``DATETIME`` class in SQLAlchemy takes only a boolean
      parameter (timezone).
    - However, sometimes we have to have parameters, e.g. ``VARCHAR`` length.
    - Hence a few special cases below, plus a fallback: if the type class
      rejects the parsed arguments with ``TypeError``, it is constructed with
      no arguments.
    """
    size = None  # type: Optional[int]
    dp = None  # type: Optional[int]
    args = []  # type: List[Any]
    kwargs = {}  # type: Dict[str, Any]
    basetype = ''

    try:
        # Split e.g. "VARCHAR(32) COLLATE blah" into "VARCHAR(32)", "who cares"
        m = RE_COLTYPE_WITH_COLLATE.match(coltype)
        if m is not None:
            coltype = m.group('maintype')

        found = False

        # Deal with ENUM('a', 'b', 'c', ...)
        m = RE_MYSQL_ENUM_COLTYPE.match(coltype)
        if m is not None:
            # Convert to VARCHAR with max size being that of largest enum
            basetype = 'VARCHAR'
            values = get_list_of_sql_string_literals_from_quoted_csv(
                m.group('valuelist'))
            length = max(len(x) for x in values)
            kwargs = {'length': length}
            found = True

        if not found:
            # Split e.g. "DECIMAL(10, 2)" into DECIMAL, 10, 2
            m = RE_COLTYPE_WITH_TWO_PARAMS.match(coltype)
            if m is not None:
                basetype = m.group('type').upper()
                size = ast.literal_eval(m.group('size'))
                dp = ast.literal_eval(m.group('dp'))
                found = True

        if not found:
            # Split e.g. "VARCHAR(32)" into VARCHAR, 32
            m = RE_COLTYPE_WITH_ONE_PARAM.match(coltype)
            if m is not None:
                basetype = m.group('type').upper()
                size_text = m.group('size').strip().upper()
                if size_text != 'MAX':  # e.g. NVARCHAR(MAX): no size
                    size = ast.literal_eval(size_text)
                found = True

        if not found:
            basetype = coltype.upper()

        # Special cases: pre-processing
        # noinspection PyUnresolvedReferences
        if (dialect.name == SqlaDialectName.MSSQL and
                basetype.lower() == 'integer'):
            basetype = 'int'

        cls = _get_sqla_coltype_class_from_str(basetype, dialect)

        # Special cases: post-processing
        if basetype == 'DATETIME' and size:
            # First argument to DATETIME() is timezone, so don't pass the
            # size positionally; MySQL takes it as fractional-seconds
            # precision instead.
            # noinspection PyUnresolvedReferences
            if dialect.name == SqlaDialectName.MYSQL:
                kwargs = {'fsp': size}
        else:
            args = [x for x in (size, dp) if x is not None]

        try:
            return cls(*args, **kwargs)
        except TypeError:
            # Type class doesn't accept these parameters; use its defaults.
            return cls()

    except Exception as exc:
        # noinspection PyUnresolvedReferences
        raise ValueError(f"Failed to convert SQL type {coltype!r} in dialect "
                         f"{dialect.name!r} to an SQLAlchemy type") from exc
806# get_sqla_coltype_from_dialect_str("INTEGER", engine.dialect)
807# get_sqla_coltype_from_dialect_str("INTEGER(11)", engine.dialect)
808# get_sqla_coltype_from_dialect_str("VARCHAR(50)", engine.dialect)
809# get_sqla_coltype_from_dialect_str("DATETIME", engine.dialect)
810# get_sqla_coltype_from_dialect_str("DATETIME(6)", engine.dialect)
813# =============================================================================
814# Do special dialect conversions on SQLAlchemy SQL types (of class type)
815# =============================================================================
def remove_collation(coltype: TypeEngine) -> TypeEngine:
    """
    Returns ``coltype`` itself if it has no (truthy) ``collation`` attribute;
    otherwise returns a shallow copy whose ``collation`` is set to ``None``.
    The original object is never modified.
    """
    if getattr(coltype, 'collation', None):
        stripped = copy.copy(coltype)
        stripped.collation = None
        return stripped
    return coltype
@lru_cache(maxsize=None)
def convert_sqla_type_for_dialect(
        coltype: TypeEngine,
        dialect: Dialect,
        strip_collation: bool = True,
        convert_mssql_timestamp: bool = True,
        expand_for_scrubbing: bool = False) -> TypeEngine:
    """
    Converts an SQLAlchemy column type from one SQL dialect to another.

    Args:
        coltype: SQLAlchemy column type in the source dialect

        dialect: destination :class:`Dialect`

        strip_collation: remove any ``COLLATION`` information?

        convert_mssql_timestamp:
            since you cannot write to a SQL Server ``TIMESTAMP`` field, setting
            this option to ``True`` (the default) converts such types to
            something equivalent but writable.

        expand_for_scrubbing:
            The purpose of expand_for_scrubbing is that, for example, a
            ``VARCHAR(200)`` field containing one or more instances of
            ``Jones``, where ``Jones`` is to be replaced with ``[XXXXXX]``,
            will get longer (by an unpredictable amount). So, better to expand
            to unlimited length.

    Returns:
        an SQLAlchemy column type instance, in the destination dialect

    Results are memoized with :func:`functools.lru_cache`, so all arguments
    must be hashable.
    """
    assert coltype is not None

    # noinspection PyUnresolvedReferences
    target_is_mysql = dialect.name == SqlaDialectName.MYSQL
    # noinspection PyUnresolvedReferences
    target_is_mssql = dialect.name == SqlaDialectName.MSSQL

    # -------------------------------------------------------------------------
    # Text. Order matters: Enum, UnicodeText and Text are all String
    # subclasses, so they must be tested before the general String case.
    # -------------------------------------------------------------------------
    if isinstance(coltype, sqltypes.Enum):
        return sqltypes.String(length=coltype.length)
    if isinstance(coltype, sqltypes.UnicodeText):
        # Unbounded Unicode text, including derived classes such as
        # mssql.base.NTEXT.
        return sqltypes.UnicodeText()
    if isinstance(coltype, sqltypes.Text):
        # Unbounded text more generally (includes sqltypes.TEXT).
        return sqltypes.Text()
    if isinstance(coltype, sqltypes.String):
        # Anything inheriting from String has a length, which may be None:
        # e.g. SQL Server VARCHAR(MAX)/NVARCHAR(MAX) arrive as VARCHAR() /
        # NVARCHAR(). MySQL needs a length for those (otherwise e.g.
        # 'NVARCHAR requires a length on dialect mysql'), so they become
        # unbounded text there; scrubbing also wants unbounded text.
        go_unbounded = expand_for_scrubbing or (
            coltype.length is None and target_is_mysql)
        if go_unbounded:
            if isinstance(coltype, sqltypes.Unicode):
                return sqltypes.UnicodeText()
            return sqltypes.Text()
        return remove_collation(coltype) if strip_collation else coltype

    # -------------------------------------------------------------------------
    # BIT: exact SQL Server BIT class only, when going to MySQL.
    # -------------------------------------------------------------------------
    source_class = type(coltype)
    if target_is_mysql and source_class == mssql.base.BIT:
        # MySQL BIT objects have a length attribute.
        return mysql.base.BIT()

    # -------------------------------------------------------------------------
    # TIMESTAMP
    # -------------------------------------------------------------------------
    if (convert_mssql_timestamp and target_is_mssql and
            isinstance(coltype, MSSQL_TIMESTAMP)):
        # You cannot write explicitly to a TIMESTAMP field in SQL Server; it's
        # used for autogenerated values only, so store the value as BINARY(8).
        # - https://stackoverflow.com/questions/10262426/sql-server-cannot-insert-an-explicit-value-into-a-timestamp-column  # noqa
        # - https://social.msdn.microsoft.com/Forums/sqlserver/en-US/5167204b-ef32-4662-8e01-00c9f0f362c2/how-to-tranfer-a-column-with-timestamp-datatype?forum=transactsql  # noqa
        # MySQL is more helpful:
        # - https://stackoverflow.com/questions/409286/should-i-use-field-datetime-or-timestamp  # noqa
        return mssql.base.BINARY(8)

    # Any other type passes through unchanged.
    return coltype
930# =============================================================================
931# Questions about SQLAlchemy column types
932# =============================================================================
934# Note:
935# x = String } type(x) == VisitableType # metaclass
936# x = BigInteger }
937# but:
938# x = String() } type(x) == TypeEngine
939# x = BigInteger() }
940#
941# isinstance also cheerfully handles multiple inheritance, i.e. if you have
942# class A(object), class B(object), and class C(A, B), followed by x = C(),
943# then all of isinstance(x, A), isinstance(x, B), isinstance(x, C) are True
def _coltype_to_typeengine(coltype: Union[TypeEngine,
                                          VisitableType]) -> TypeEngine:
    """
    Normalizes a column type to a :class:`TypeEngine` instance.

    - A type *class* such as ``Integer`` (whose metaclass is
      :class:`VisitableType`) is instantiated with no arguments, giving
      ``Integer()``.
    - A type *instance* such as ``Integer()`` is returned unchanged.

    In both cases the result is asserted to be a :class:`TypeEngine`.

    See also
    :func:`cardinal_pythonlib.sqlalchemy.orm_inspect.coltype_as_typeengine`.
    """
    instance = coltype() if isinstance(coltype, VisitableType) else coltype
    assert isinstance(instance, TypeEngine)
    return instance
def is_sqlatype_binary(coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Is the SQLAlchemy column type a binary type?

    Accepts either a type class or a type instance. The test is against
    SQLAlchemy's internal ``sqltypes._Binary`` class, which several binary
    types inherit from, so one check covers them all.
    """
    # noinspection PyProtectedMember
    return isinstance(_coltype_to_typeengine(coltype), sqltypes._Binary)
def is_sqlatype_date(coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Is the SQLAlchemy column type a date or date/time type?

    Accepts either a type class or a type instance.
    """
    instance = _coltype_to_typeengine(coltype)
    # Checking sqltypes._DateAffinity no longer works (SQLAlchemy 1.2.11), so
    # test against the concrete base classes instead.
    return isinstance(instance, (sqltypes.DateTime, sqltypes.Date))
def is_sqlatype_integer(coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Is the SQLAlchemy column type an integer type, i.e. an instance of (a
    subclass of) ``sqltypes.Integer``? Accepts a type class or instance.
    """
    instance = _coltype_to_typeengine(coltype)
    integer_base = sqltypes.Integer
    return isinstance(instance, integer_base)
def is_sqlatype_numeric(coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Does the SQLAlchemy column type derive from :class:`Numeric`?

    That covers e.g. :class:`Float` and :class:`Decimal`. Accepts a type
    class or instance.
    """
    instance = _coltype_to_typeengine(coltype)
    numeric_base = sqltypes.Numeric  # includes Float, Decimal
    return isinstance(instance, numeric_base)
def is_sqlatype_string(coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Is the SQLAlchemy column type a string type, i.e. an instance of (a
    subclass of) ``sqltypes.String``? Accepts a type class or instance.
    """
    instance = _coltype_to_typeengine(coltype)
    string_base = sqltypes.String
    return isinstance(instance, string_base)
def is_sqlatype_text_of_length_at_least(
        coltype: Union[TypeEngine, VisitableType],
        min_length: int = 1000) -> bool:
    """
    Is the SQLAlchemy column type a string type whose length is at least
    ``min_length``?

    Args:
        coltype: SQLAlchemy column type (class or instance)
        min_length: minimum length required

    Returns:
        ``False`` for non-string types. ``True`` for string types with no
        length set (``length is None``, i.e. unlimited). Otherwise, whether
        the declared length is ``>= min_length``.
    """
    instance = _coltype_to_typeengine(coltype)
    if not isinstance(instance, sqltypes.String):
        return False  # not a string/text type at all
    declared_length = instance.length
    # No declared length means an unbounded string, which always qualifies.
    return declared_length is None or declared_length >= min_length
def is_sqlatype_text_over_one_char(
        coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Is the SQLAlchemy column type a string type that's more than one character
    long (including strings of unlimited length)?

    Equivalent to :func:`is_sqlatype_text_of_length_at_least` with
    ``min_length=2``.

    Args:
        coltype: SQLAlchemy column type (class or instance)
    """
    # No separate class -> instance conversion is needed here: the delegate
    # performs it itself.
    return is_sqlatype_text_of_length_at_least(coltype, min_length=2)
def does_sqlatype_merit_fulltext_index(
        coltype: Union[TypeEngine, VisitableType],
        min_length: int = 1000) -> bool:
    """
    Is the SQLAlchemy column type a type that might merit a ``FULLTEXT``
    index (meaning a string type of at least ``min_length``, or of unlimited
    length)?

    Args:
        coltype: SQLAlchemy column type (class or instance)
        min_length: minimum string length that merits a ``FULLTEXT`` index

    Returns:
        the result of :func:`is_sqlatype_text_of_length_at_least`
    """
    # No separate class -> instance conversion is needed here: the delegate
    # performs it itself.
    return is_sqlatype_text_of_length_at_least(coltype, min_length)
def does_sqlatype_require_index_len(
        coltype: Union[TypeEngine, VisitableType]) -> bool:
    """
    Must indexes on a column of this SQLAlchemy type specify a length?

    True for text (``sqltypes.Text``) and large-binary
    (``sqltypes.LargeBinary``) types, false otherwise. MySQL, at least,
    requires an index length for ``BLOB`` and ``TEXT`` columns; see
    https://dev.mysql.com/doc/refman/5.7/en/create-index.html.

    Args:
        coltype: SQLAlchemy column type (class or instance)
    """
    instance = _coltype_to_typeengine(coltype)
    needs_length = (sqltypes.Text, sqltypes.LargeBinary)
    return isinstance(instance, needs_length)
1066# =============================================================================
1067# Hack in new type
1068# =============================================================================
def hack_in_mssql_xml_type():
    r"""
    Teaches SQLAlchemy's Microsoft SQL Server (``mssql``) dialect to reflect
    columns of type ``xml``.

    SQLAlchemy has no XML type for SQL Server. Reflecting such a column
    therefore produces a warning like:

    .. code-block:: none

        sqlalchemy\dialects\mssql\base.py:1921: SAWarning: Did not recognize type 'xml' of column '...'

    This function registers ``xml`` in ``mssql.base.ischema_names``, a
    mapping on the dialect module, so that such columns are reflected as
    ``TEXT``. The modification is global.

    See
    https://stackoverflow.com/questions/32917867/sqlalchemy-making-schema-reflection-find-use-a-custom-type-for-all-instances
    """  # noqa
    log.debug("Adding type 'xml' to SQLAlchemy reflection for dialect 'mssql'")
    replacement_type = mssql.base.TEXT
    mssql.base.ischema_names['xml'] = replacement_type
1092# =============================================================================
1093# Check column definition equality
1094# =============================================================================
def column_types_equal(a_coltype: TypeEngine, b_coltype: TypeEngine) -> bool:
    """
    Do two SQLAlchemy column types match?

    The comparison uses the ``str()`` forms of the two types. This is an
    imperfect test; see
    https://stackoverflow.com/questions/34787794/sqlalchemy-column-type-comparison.
    """  # noqa
    a_text = str(a_coltype)
    b_text = str(b_coltype)
    return a_text == b_text
def columns_equal(a: Column, b: Column) -> bool:
    """
    Are two SQLAlchemy columns equal? They must match on all of:

    - column ``name``
    - column ``type`` (see :func:`column_types_equal`)
    - ``nullable``
    """
    if a.name != b.name:
        return False
    if not column_types_equal(a.type, b.type):
        return False
    return a.nullable == b.nullable
def column_lists_equal(a: List[Column], b: List[Column]) -> bool:
    """
    Do two lists of columns match pairwise, as per :func:`columns_equal`?

    Lists of different lengths never match. The first mismatching pair, if
    any, is logged at debug level.
    """
    if len(a) != len(b):
        return False
    for col_a, col_b in zip(a, b):
        if not columns_equal(col_a, col_b):
            log.debug("Mismatch: {!r} != {!r}", col_a, col_b)
            return False
    return True
def indexes_equal(a: Index, b: Index) -> bool:
    """
    Are two indexes equal? The test compares their ``str()`` forms.

    NOTE(review): it is unclear whether ``str()`` captures every relevant
    property of an index, so treat this as an approximate check.
    """
    a_text, b_text = str(a), str(b)
    return a_text == b_text
def index_lists_equal(a: List[Index], b: List[Index]) -> bool:
    """
    Do two lists of indexes match pairwise, as per :func:`indexes_equal`?

    Lists of different lengths never match. The first mismatching pair, if
    any, is logged at debug level.
    """
    if len(a) != len(b):
        return False
    for idx_a, idx_b in zip(a, b):
        if not indexes_equal(idx_a, idx_b):
            log.debug("Mismatch: {!r} != {!r}", idx_a, idx_b)
            return False
    return True
1161# =============================================================================
1162# Tests
1163# =============================================================================
def test_assert(x, y) -> None:
    """
    Checks that ``x == y``. If not, prints both values and raises
    :exc:`AssertionError`.

    The check uses an explicit ``raise`` rather than an ``assert`` statement.
    Running Python with ``-O`` strips ``assert`` statements, which would
    otherwise make this check silently do nothing.

    Raises:
        AssertionError: if ``x == y`` is false
    """
    if not (x == y):
        message = f"{x!r} should have been {y!r}"
        print(message)
        raise AssertionError(message)


# This is a helper that takes arguments, not a test in its own right. Stop
# pytest from collecting it (and calling it with no arguments) if it is ever
# imported into a test module.
test_assert.__test__ = False
def unit_tests() -> None:
    """
    Self-checks for this module, run when it is executed as a script.

    - Checks the DDL generated for some ``BIGINT`` columns under the SQL
      Server and MySQL dialects. Each check uses :func:`test_assert`, which
      prints the values and raises on a mismatch.
    - Prints the SQLAlchemy type parsed from several SQL type strings in each
      dialect. These are for visual inspection and are not asserted.
    """
    from sqlalchemy.dialects.mssql.base import MSDialect
    from sqlalchemy.dialects.mysql.base import MySQLDialect
    d_mssql = MSDialect()
    d_mysql = MySQLDialect()
    col1 = Column('hello', BigInteger, nullable=True)
    col2 = Column('world', BigInteger,
                  autoincrement=True)  # does NOT generate IDENTITY
    col3 = make_bigint_autoincrement_column('you', d_mssql)
    metadata = MetaData()
    t = Table('mytable', metadata)
    t.append_column(col1)
    t.append_column(col2)
    t.append_column(col3)

    print("Checking Column -> DDL: SQL Server (mssql)")
    test_assert(column_creation_ddl(col1, d_mssql), "hello BIGINT NULL")
    test_assert(column_creation_ddl(col2, d_mssql), "world BIGINT NULL")
    test_assert(column_creation_ddl(col3, d_mssql),
                "you BIGINT NOT NULL IDENTITY(1,1)")

    print("Checking Column -> DDL: MySQL (mysql)")
    test_assert(column_creation_ddl(col1, d_mysql), "hello BIGINT")
    test_assert(column_creation_ddl(col2, d_mysql), "world BIGINT")
    # not col3; unsupported

    print("Checking SQL type -> SQL Alchemy type")
    to_check = [
        # mssql
        ("BIGINT", d_mssql),
        ("NVARCHAR(32)", d_mssql),
        ("NVARCHAR(MAX)", d_mssql),
        ('NVARCHAR(160) COLLATE "Latin1_General_CI_AS"', d_mssql),
        # mysql (every entry in this group must use the MySQL dialect)
        ("BIGINT", d_mysql),
        ("LONGTEXT", d_mysql),
        ("ENUM('red','green','blue')", d_mysql),
    ]
    for coltype, dialect in to_check:
        print(f"... {coltype!r} -> dialect {dialect.name!r} -> "
              f"{get_sqla_coltype_from_dialect_str(coltype, dialect)!r}")
if __name__ == "__main__":
    # Executing this module directly runs its self-checks.
    unit_tests()