Source code for crate_anon.common.sql

#!/usr/bin/env python
# crate_anon/common/sql.py

"""
===============================================================================

    Copyright (C) 2015-2018 Rudolf Cardinal (rudolf@pobox.com).

    This file is part of CRATE.

    CRATE is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    CRATE is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with CRATE. If not, see <http://www.gnu.org/licenses/>.

===============================================================================
"""

import argparse
from collections import OrderedDict
import functools
import logging
import re
from typing import Any, Dict, Iterable, List, Tuple, Union

from cardinal_pythonlib.json.serialize import (
    METHOD_PROVIDES_INIT_KWARGS,
    METHOD_STRIP_UNDERSCORE,
    register_for_json,
)
from cardinal_pythonlib.lists import unique_list
from cardinal_pythonlib.logs import main_only_quicksetup_rootlogger
from cardinal_pythonlib.reprfunc import mapped_repr_stripping_underscores
from cardinal_pythonlib.sizeformatter import sizeof_fmt
from cardinal_pythonlib.sql.literals import (
    sql_date_literal,
    sql_string_literal,
)
from cardinal_pythonlib.sql.sql_grammar import SqlGrammar, text_from_parsed
from cardinal_pythonlib.sql.sql_grammar_factory import (
    make_grammar,
    mysql_grammar,
)
from cardinal_pythonlib.sqlalchemy.core_query import count_star
from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
from cardinal_pythonlib.sqlalchemy.schema import column_creation_ddl
from cardinal_pythonlib.timing import MultiTimerContext, timer
from pyparsing import ParseResults
from sqlalchemy import inspect
from sqlalchemy.dialects.mssql.base import MS_2012_VERSION
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm.session import Session
from sqlalchemy.schema import Column, Table

from crate_anon.common.stringfunc import get_spec_match_regex

log = logging.getLogger(__name__)

# =============================================================================
# Types
# =============================================================================

# A (str, list) pair. Presumably an SQL statement together with the values
# for its placeholders; it is not used in the visible part of this file, so
# confirm against callers.
SqlArgsTupleType = Tuple[str, List[Any]]


# =============================================================================
# Constants
# =============================================================================

# Label string; presumably the name of a timer used with the imported
# MultiTimerContext/timer around commits -- TODO confirm against callers.
TIMING_COMMIT = "commit"

# SQL comparison operators that are complete without a right-hand value:
SQL_OPS_VALUE_UNNECESSARY = ['IS NULL', 'IS NOT NULL']
# SQL comparison operators that take a list of values on the right-hand side:
SQL_OPS_MULTIPLE_VALUES = ['IN', 'NOT IN']

# Upper-case SQL type names, grouped into broad categories. The names carry no
# length/precision suffix, e.g. "VARCHAR" rather than "VARCHAR(32)".
SQLTYPES_INTEGER = [
    "INT", "INTEGER",
    "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT",
    "BIT", "BOOL", "BOOLEAN",
]
SQLTYPES_FLOAT = [
    "DOUBLE", "FLOAT", "DEC", "DECIMAL",
]
SQLTYPES_TEXT = [
    "CHAR", "VARCHAR", "NVARCHAR",
    "TINYTEXT", "TEXT", "NTEXT", "MEDIUMTEXT", "LONGTEXT",
]
SQLTYPES_WITH_DATE = [
    "DATE", "DATETIME", "TIMESTAMP",
]
# SQLTYPES_BINARY = [
#     "BINARY", "BLOB", "IMAGE", "LONGBLOB", "VARBINARY",
# ]

# Datatype tags shared with the JavaScript query builder; the string values
# below are part of that contract, so do not change them on this side alone.
# Must match querybuilder.js:
QB_DATATYPE_INTEGER = "int"
QB_DATATYPE_FLOAT = "float"
QB_DATATYPE_DATE = "date"
QB_DATATYPE_STRING = "string"
QB_DATATYPE_STRING_FULLTEXT = "string_fulltext"
QB_DATATYPE_UNKNOWN = "unknown"
# The query-builder tags that denote string data (plain or full-text):
QB_STRING_TYPES = [QB_DATATYPE_STRING, QB_DATATYPE_STRING_FULLTEXT]

# Matches a column type written as a purely alphabetical name followed by one
# parenthesized integer, e.g. "VARCHAR(32)". Group 1 is the type name and
# group 2 is the integer (as a string).
# The character class is [A-Za-z] rather than [A-z]: the latter is an ASCII
# range that also spans "[", "\", "]", "^", "_" and "`".
COLTYPE_WITH_ONE_INTEGER_REGEX = re.compile(r"^([A-Za-z]+)\((\d+)\)$")
# ... start, group(alphabetical), literal (, group(digit), literal ), end


# def combine_db_schema_table(db: Optional[str],
#                             schema: Optional[str],
#                             table: str) -> str:
#     # ANSI SQL: http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt
#     # <table name>, <qualified name>
#     if not table:
#         raise ValueError("Missing table supplied to combine_db_schema_table")
#     return ".".join(x for x in [db, schema, table] if x)


# =============================================================================
# SQL elements: identifiers
# =============================================================================

@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class SchemaId(object):
    """
    Identifies a schema by a (database name, schema name) pair.

    Instances compare and sort by (db, schema), are hashable, and are truthy
    only when the schema name is non-empty.
    """

    def __init__(self, db: str = '', schema: str = '') -> None:
        """
        Args:
            db: database name; must not contain '.'
            schema: schema name; must not contain '.'

        Raises:
            AssertionError: if either name contains '.', since
                :meth:`from_schema_tag` relies on '.' as the separator.
        """
        assert "." not in db, (
            "Bad database name ({!r}); can't include '.'".format(db)
        )
        assert "." not in schema, (
            "Bad schema name ({!r}); can't include '.'".format(schema)
        )
        self._db = db
        self._schema = schema

    @property
    def schema_tag(self) -> str:
        """
        String suitable for encoding the SchemaId e.g. in a single HTML form.
        The __init__ function checks the assumption of no '.' characters in
        either part.
        """
        return "{}.{}".format(self._db, self._schema)

    @classmethod
    def from_schema_tag(cls, tag: str) -> 'SchemaId':
        """
        Inverse of :attr:`schema_tag`: builds an instance from "db.schema".

        Raises:
            AssertionError: if the tag does not have exactly two
                '.'-separated parts.
        """
        parts = tag.split(".")
        assert len(parts) == 2, "Bad schema tag {!r}".format(tag)
        db, schema = parts
        # Use cls, not the hard-coded class name, so subclasses round-trip.
        return cls(db, schema)

    def __bool__(self) -> bool:
        return bool(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SchemaId):
            # Let Python try the reflected comparison / fall back to
            # identity, rather than failing with AttributeError.
            return NotImplemented
        return (  # ordering is for speed
            self._schema == other._schema and
            self._db == other._db
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, SchemaId):
            return NotImplemented
        return (
            (self._db, self._schema) <
            (other._db, other._schema)
        )

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. Hashing via str(self)
        # would go through make_identifier(), which asserts that at least one
        # element is non-empty, so an empty SchemaId would be unhashable.
        return hash((self._db, self._schema))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns the "db.schema" SQL identifier, with each part quoted as
        required by ``grammar``. Empty parts are omitted.
        """
        return make_identifier(grammar,
                               database=self._db,
                               schema=self._schema)

    def table_id(self, table: str) -> 'TableId':
        """Returns the TableId for the named table within this schema."""
        return TableId(db=self._db, schema=self._schema, table=table)

    def column_id(self, table: str, column: str) -> 'ColumnId':
        """Returns the ColumnId for the named table/column in this schema."""
        return ColumnId(db=self._db, schema=self._schema,
                        table=table, column=column)

    @property
    def db(self) -> str:
        return self._db

    @property
    def schema(self) -> str:
        return self._schema

    def __str__(self) -> str:
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ['_db', '_schema'])


@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class TableId(object):
    """
    Identifies a table by a (database name, schema name, table name) triple.

    Instances compare and sort by (db, schema, table), are hashable, and are
    truthy only when the table name is non-empty.
    """

    def __init__(self, db: str = '', schema: str = '',
                 table: str = '') -> None:
        self._db = db
        self._schema = schema
        self._table = table

    def __bool__(self) -> bool:
        return bool(self._table)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TableId):
            # Let Python try the reflected comparison / fall back to
            # identity, rather than failing with AttributeError.
            return NotImplemented
        return (  # ordering is for speed
            self._table == other._table and
            self._schema == other._schema and
            self._db == other._db
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, TableId):
            return NotImplemented
        return (
            (self._db, self._schema, self._table) <
            (other._db, other._schema, other._table)
        )

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. Hashing via str(self)
        # would go through make_identifier(), which asserts that at least one
        # element is non-empty, so an empty TableId would be unhashable.
        return hash((self._db, self._schema, self._table))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns the "db.schema.table" SQL identifier, with each part quoted
        as required by ``grammar``. Empty parts are omitted.
        """
        return make_identifier(grammar,
                               database=self._db,
                               schema=self._schema,
                               table=self._table)

    @property
    def schema_id(self) -> SchemaId:
        """The SchemaId of the schema containing this table."""
        return SchemaId(db=self._db, schema=self._schema)

    def column_id(self, column: str) -> 'ColumnId':
        """Returns the ColumnId for the named column within this table."""
        return ColumnId(db=self._db, schema=self._schema,
                        table=self._table, column=column)

    def database_schema_part(self, grammar: SqlGrammar) -> str:
        """Returns only the "db.schema" portion of the identifier."""
        return make_identifier(grammar,
                               database=self._db,
                               schema=self._schema)

    def table_part(self, grammar: SqlGrammar) -> str:
        """Returns only the table portion of the identifier."""
        return make_identifier(grammar, table=self._table)

    @property
    def db(self) -> str:
        return self._db

    @property
    def schema(self) -> str:
        return self._schema

    @property
    def table(self) -> str:
        return self._table

    def __str__(self) -> str:
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ['_db', '_schema', '_table'])


@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class ColumnId(object):
    """
    Identifies a column by a (database, schema, table, column) name
    quadruple.

    Instances compare and sort by (db, schema, table, column), are hashable,
    and are truthy only when the column name is non-empty.
    """

    def __init__(self, db: str = '', schema: str = '',
                 table: str = '', column: str = '') -> None:
        self._db = db
        self._schema = schema
        self._table = table
        self._column = column

    def __bool__(self) -> bool:
        return bool(self._column)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ColumnId):
            # Let Python try the reflected comparison / fall back to
            # identity, rather than failing with AttributeError.
            return NotImplemented
        return (
            self._column == other._column and
            self._table == other._table and
            self._schema == other._schema and
            self._db == other._db
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, ColumnId):
            return NotImplemented
        return (
            (self._db, self._schema, self._table, self._column) <
            (other._db, other._schema, other._table, other._column)
        )

    def __hash__(self) -> int:
        # A class that defines __eq__ without __hash__ is unhashable in
        # Python 3. Provide a hash over the same fields that __eq__ compares,
        # so that ColumnId can be used in sets and as a dict key, like
        # SchemaId and TableId.
        return hash((self._db, self._schema, self._table, self._column))

    @property
    def is_valid(self) -> bool:
        return bool(self._table and self._column)  # the minimum

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns the "db.schema.table.column" SQL identifier, with each part
        quoted as required by ``grammar``. Empty parts are omitted.
        """
        return make_identifier(grammar,
                               database=self._db,
                               schema=self._schema,
                               table=self._table,
                               column=self._column)

    @property
    def db(self) -> str:
        return self._db

    @property
    def schema(self) -> str:
        return self._schema

    @property
    def table(self) -> str:
        return self._table

    @property
    def column(self) -> str:
        return self._column

    @property
    def schema_id(self) -> SchemaId:
        """The SchemaId of the schema containing this column."""
        return SchemaId(db=self._db, schema=self._schema)

    @property
    def table_id(self) -> TableId:
        """The TableId of the table containing this column."""
        return TableId(db=self._db, schema=self._schema, table=self._table)

    @property
    def has_table_and_column(self) -> bool:
        # Same test as is_valid; both names are kept for existing callers.
        return bool(self._table and self._column)

    def __str__(self) -> str:
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ['_db', '_schema', '_table', '_column'])

    # def html(self, grammar: SqlGrammar, bold_column: bool = True) -> str:
    #     components = [
    #         html.escape(grammar.quote_identifier_if_required(x))
    #         for x in [self._db, self._schema, self._table, self._column]
    #         if x]
    #     if not components:
    #         return ''
    #     if bold_column:
    #         components[-1] = "<b>{}</b>".format(components[-1])
    #     return ".".join(components)


def split_db_schema_table(db_schema_table: str) -> TableId:
    """
    Converts a dotted name into a TableId.

    Accepted forms are "db.schema.table", "schema.table" and "table"; any
    leading parts that are absent become empty strings.

    Raises:
        ValueError: if the name has more than three '.'-separated parts.
    """
    parts = db_schema_table.split('.')
    n_missing = 3 - len(parts)
    if n_missing < 0:  # too many components
        raise ValueError("Bad db_schema_table: {}".format(db_schema_table))
    # The rightmost part is always the table, so pad on the left.
    db, schema, table = [''] * n_missing + parts
    return TableId(db=db, schema=schema, table=table)


def split_db_schema_table_column(db_schema_table_col: str) -> ColumnId:
    """
    Converts a dotted name into a ColumnId.

    Accepted forms are "db.schema.table.column", "schema.table.column",
    "table.column" and "column"; any leading parts that are absent become
    empty strings.

    Raises:
        ValueError: if the name has more than four '.'-separated parts.
    """
    parts = db_schema_table_col.split('.')
    n_missing = 4 - len(parts)
    if n_missing < 0:  # too many components
        raise ValueError("Bad db_schema_table_col: {}".format(
            db_schema_table_col))
    # The rightmost part is always the column, so pad on the left.
    db, schema, table, column = [''] * n_missing + parts
    return ColumnId(db=db, schema=schema, table=table, column=column)


def columns_to_table_column_hierarchy(
        columns: List[ColumnId],
        sort: bool = True) -> List[Tuple[TableId, List[ColumnId]]]:
    """
    Groups columns by the table they belong to.

    Args:
        columns: the columns to group
        sort: if true, sort the tables, and the columns within each table;
            otherwise keep first-seen order for tables and input order for
            columns

    Returns:
        a list of ``(table_id, columns_in_that_table)`` tuples
    """
    # Group in a single pass rather than rescanning every column once per
    # table. An OrderedDict keeps tables in first-seen order.
    grouped = OrderedDict()  # type: Dict[TableId, List[ColumnId]]
    for c in columns:
        grouped.setdefault(c.table_id, []).append(c)
    table_column_map = list(grouped.items())
    if sort:
        table_column_map.sort(key=lambda table_cols: table_cols[0])
        for _, t_columns in table_column_map:
            t_columns.sort()
    return table_column_map


# =============================================================================
# Using SQL grammars (but without reference to Django models, for testing)
# =============================================================================

def make_identifier(grammar: SqlGrammar,
                    database: str = None,
                    schema: str = None,
                    table: str = None,
                    column: str = None) -> str:
    """
    Builds a dotted SQL identifier from whichever parts are supplied.

    Parts that are None or empty are skipped; the rest are quoted, where the
    grammar requires it, and joined with '.' in the order database, schema,
    table, column.

    Raises:
        AssertionError: if no non-empty part is supplied.
    """
    quoted_parts = []
    for part in (database, schema, table, column):
        if part:
            quoted_parts.append(grammar.quote_identifier_if_required(part))
    assert quoted_parts, "make_identifier(): No elements passed!"
    return ".".join(quoted_parts)


def dumb_make_identifier(database: str = None,
                         schema: str = None,
                         table: str = None,
                         column: str = None) -> str:
    """
    Builds a dotted identifier from whichever parts are supplied, without
    any quoting.

    Parts that are None or empty are skipped; the rest are joined with '.'
    in the order database, schema, table, column.

    Raises:
        AssertionError: if no non-empty part is supplied.
    """
    # Must be a list: in Python 3, filter() returns an iterator, which is
    # always truthy, so the emptiness check below could never fail.
    elements = [x for x in (database, schema, table, column) if x]
    assert elements, "dumb_make_identifier(): No elements passed!"
    return ".".join(elements)


def parser_add_result_column(parsed: ParseResults,
                             column: str,
                             grammar: SqlGrammar) -> ParseResults:
    """
    Adds a result column to a parsed SELECT statement, unless that column is
    already selected. Presupposes at least one column already in the SELECT
    statement.

    Args:
        parsed: parsed SELECT statement; modified in place
        column: text of the column to add
        grammar: grammar used to parse ``column``

    Returns:
        ``parsed``, for convenience
    """
    already_selected = parsed.select_expression.select_columns.asList()
    if column in already_selected:
        return parsed  # nothing to do
    column_parser = grammar.get_result_column()
    parsed_column = column_parser.parseString(column, parseAll=True)
    parsed.select_expression.extend([",", parsed_column])
    return parsed


class JoinInfo(object):
    """
    Plain record describing one table to be joined into a FROM clause:
    which table, with what kind of join, and (optionally) on what condition.
    """

    def __init__(self,
                 table: str,
                 join_type: str = 'INNER JOIN',
                 join_condition: str = '') -> None:  # e.g. "ON x = y"
        """
        Args:
            table: text of the table to join
            join_type: SQL join operator, e.g. "INNER JOIN"
            join_condition: join constraint, e.g. "ON x = y"; may be empty
        """
        self.table = table
        self.join_type = join_type
        self.join_condition = join_condition


def parser_add_from_tables(parsed: ParseResults,
                           join_info_list: List[JoinInfo],
                           grammar: SqlGrammar) -> ParseResults:
    """
    Adds joined tables to the FROM clause of a parsed SELECT statement.
    Presupposes at least one table already in the FROM clause.

    Args:
        parsed: parsed SELECT statement; modified in place
        join_info_list: tables to join; any whose ``table`` is already in the
            FROM clause is skipped
        grammar: grammar used to parse the join type, table and condition

    Returns:
        ``parsed``, for convenience
    """
    # Snapshot taken once, before any additions; tables added by this call
    # are therefore not checked against each other.
    tables_present = parsed.join_source.from_tables.asList()
    for join_info in join_info_list:
        if join_info.table in tables_present:
            continue  # already there
        join_op = grammar.get_join_op().parseString(
            join_info.join_type, parseAll=True)[0]  # e.g. INNER JOIN
        table_spec = grammar.get_table_spec().parseString(
            join_info.table, parseAll=True)[0]
        additions = [join_op, table_spec]
        if join_info.join_condition:  # e.g. ON x = y
            join_constraint = grammar.get_join_constraint().parseString(
                join_info.join_condition, parseAll=True)[0]
            additions.append(join_constraint)
        parsed.join_source.extend(additions)
    return parsed
def get_first_from_table(parsed: ParseResults,
                         match_db: str = '',
                         match_schema: str = '',
                         match_table: str = '') -> TableId:
    """
    Given a set of parsed results from a SELECT statement, returns the
    TableId of the first table in the FROM clause that satisfies the
    constraints.

    Args:
        parsed: parsed SELECT statement
        match_db: if non-empty, the table's database name must equal this
        match_schema: if non-empty, the table's schema name must equal this
        match_table: if non-empty, the table's name must equal this

    Returns:
        the first matching TableId, or an empty ``TableId()`` if none matches
    """
    for table_text in parsed.join_source.from_tables.asList():
        candidate = split_db_schema_table(table_text)
        db_ok = not match_db or candidate.db == match_db
        schema_ok = not match_schema or candidate.schema == match_schema
        table_ok = not match_table or candidate.table == match_table
        if db_ok and schema_ok and table_ok:
            return candidate
    return TableId()
def set_distinct_within_parsed(p: ParseResults, action: str = 'set') -> None:
    """
    Sets, clears or toggles DISTINCT in a parsed SELECT statement, in place.

    Args:
        p: parsed SELECT statement; its ``select_specifier`` is modified
        action: one of ``'set'``, ``'clear'``, ``'toggle'``

    Raises:
        ValueError: if ``action`` is not one of the above
    """
    ss = p.select_specifier  # type: ParseResults
    if action == 'set':
        if 'DISTINCT' not in ss.asList():
            ss.append('DISTINCT')
    elif action == 'clear':
        if 'DISTINCT' in ss.asList():
            del ss[:]
    elif action == 'toggle':
        if 'DISTINCT' in ss.asList():
            del ss[:]
        else:
            ss.append('DISTINCT')
    else:
        raise ValueError("action must be one of set/clear/toggle")


def set_distinct(sql: str,
                 grammar: SqlGrammar,
                 action: str = 'set',
                 formatted: bool = True,
                 debug: bool = False,
                 debug_verbose: bool = False) -> str:
    """
    Parses a SELECT statement, applies a DISTINCT action to it, and returns
    the resulting SQL text.

    Args:
        sql: SQL text of a SELECT statement
        grammar: grammar used to parse ``sql``
        action: one of ``'set'``, ``'clear'``, ``'toggle'``
        formatted: passed to ``text_from_parsed`` when regenerating the SQL
        debug: log the SQL before and after
        debug_verbose: with ``debug``, also log the parser dumps

    Returns:
        the modified SQL text
    """
    p = grammar.get_select_statement().parseString(sql, parseAll=True)
    if debug:
        log.info("START: {}".format(sql))
        if debug_verbose:
            log.debug("start dump:\n" + p.dump())
    set_distinct_within_parsed(p, action=action)
    result = text_from_parsed(p, formatted=formatted)
    if debug:
        log.info("END: {}".format(result))
        if debug_verbose:
            log.debug("end dump:\n" + p.dump())
    return result


def toggle_distinct(sql: str,
                    grammar: SqlGrammar,
                    formatted: bool = True,
                    debug: bool = False,
                    debug_verbose: bool = False) -> str:
    """
    Shorthand for :func:`set_distinct` with ``action='toggle'``.
    """
    return set_distinct(sql=sql,
                        grammar=grammar,
                        action='toggle',
                        formatted=formatted,
                        debug=debug,
                        debug_verbose=debug_verbose)


# =============================================================================
# SQLAlchemy reflection and DDL
# =============================================================================

# Module-wide switch read by execute(): if true, SQL is printed, not executed.
_print_not_execute = False


def set_print_not_execute(print_not_execute: bool) -> None:
    """
    Sets the module-wide switch that makes :func:`execute` print its SQL
    instead of executing it.
    """
    global _print_not_execute
    _print_not_execute = print_not_execute


def execute(engine: Engine, sql: str) -> None:
    """
    Logs the SQL at debug level, then either prints it (if the
    print-not-execute switch is set) or executes it via the engine.
    """
    log.debug(sql)
    if _print_not_execute:
        print(format_sql_for_print(sql) + "\n;")
        # extra \n in case the SQL ends in a comment
    else:
        engine.execute(sql)


def add_columns(engine: Engine, table: Table, columns: List[Column]) -> None:
    """
    Adds columns to a table, skipping any whose name (compared
    case-insensitively) already exists in the database.

    Args:
        engine: SQLAlchemy engine
        table: table to alter
        columns: columns to add
    """
    existing_column_names = get_column_names(engine, tablename=table.name,
                                             to_lower=True)
    column_defs = []
    for column in columns:
        if column.name.lower() not in existing_column_names:
            column_defs.append(column_creation_ddl(column, engine.dialect))
        else:
            log.debug("Table {}: column {} already exists; not adding".format(
                repr(table.name), repr(column.name)))
    # ANSI SQL: add one column at a time: ALTER TABLE ADD [COLUMN] coldef
    #   - i.e. "COLUMN" optional, one at a time, no parentheses
    #   - http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt
    # MySQL: ALTER TABLE ADD [COLUMN] (a INT, b VARCHAR(32));
    #   - i.e. "COLUMN" optional, parentheses required for >1, multiple OK
    #   - http://dev.mysql.com/doc/refman/5.7/en/alter-table.html
    # MS SQL Server: ALTER TABLE ADD a INT, B VARCHAR(32);
    #   - i.e. no "COLUMN", no parentheses, multiple OK
    #   - https://msdn.microsoft.com/en-us/library/ms190238.aspx
    #   - https://msdn.microsoft.com/en-us/library/ms190273.aspx
    #   - http://stackoverflow.com/questions/2523676
    # SQLAlchemy doesn't provide a shortcut for this.
    # So: one column per statement, without "COLUMN", which suits all three.
    for column_def in column_defs:
        log.info("Table {}: adding column {}".format(
            repr(table.name), repr(column_def)))
        execute(engine, """
            ALTER TABLE {tablename} ADD {column_def}
        """.format(tablename=table.name, column_def=column_def))


def drop_columns(engine: Engine, table: Table,
                 column_names: Iterable[str]) -> None:
    """
    Drops columns from a table, skipping any whose name (compared
    case-insensitively) does not exist in the database.
    """
    existing_column_names = get_column_names(engine, tablename=table.name,
                                             to_lower=True)
    for name in column_names:
        if name.lower() not in existing_column_names:
            log.debug("Table {}: column {} does not exist; not "
                      "dropping".format(repr(table.name), repr(name)))
        else:
            log.info("Table {}: dropping column {}".format(
                repr(table.name), repr(name)))
            sql = "ALTER TABLE {t} DROP COLUMN {c}".format(t=table.name,
                                                           c=name)
            # SQL Server:
            #   http://www.techonthenet.com/sql_server/tables/alter_table.php
            # MySQL:
            #   http://dev.mysql.com/doc/refman/5.7/en/alter-table.html
            execute(engine, sql)


def add_indexes(engine: Engine, table: Table,
                indexdictlist: Iterable[Dict[str, Any]]) -> None:
    """
    Creates indexes on a table, skipping any whose name (compared
    case-insensitively) already exists in the database.

    Args:
        engine: SQLAlchemy engine
        table: table to index
        indexdictlist: one dictionary per index, with keys ``'index_name'``
            (str), ``'column'`` (a column name, or an iterable of column
            names for a multi-column index) and optionally ``'unique'``
            (bool; default False)
    """
    existing_index_names = get_index_names(engine, tablename=table.name,
                                           to_lower=True)
    for idxdefdict in indexdictlist:
        index_name = idxdefdict['index_name']
        column = idxdefdict['column']
        if not isinstance(column, str):
            column = ", ".join(column)  # must be a list
        unique = idxdefdict.get('unique', False)
        if index_name.lower() not in existing_index_names:
            log.info("Table {}: adding index {} on column {}".format(
                repr(table.name), repr(index_name), repr(column)))
            execute(engine, """
                CREATE{unique} INDEX {idxname} ON {tablename} ({column})
            """.format(
                unique=" UNIQUE" if unique else "",
                idxname=index_name,
                tablename=table.name,
                column=column,
            ))
        else:
            log.debug("Table {}: index {} already exists; not adding".format(
                repr(table.name), repr(index_name)))


def drop_indexes(engine: Engine, table: Table,
                 index_names: Iterable[str]) -> None:
    """
    Drops indexes from a table, skipping any whose name (compared
    case-insensitively) does not exist in the database.

    Raises:
        AssertionError: if an index must be dropped and the engine's dialect
            is neither ``'mysql'`` nor ``'mssql'``
    """
    existing_index_names = get_index_names(engine, tablename=table.name,
                                           to_lower=True)
    for index_name in index_names:
        if index_name.lower() not in existing_index_names:
            log.debug("Table {}: index {} does not exist; not dropping".format(
                repr(table.name), repr(index_name)))
        else:
            log.info("Table {}: dropping index {}".format(
                repr(table.name), repr(index_name)))
            if engine.dialect.name == 'mysql':
                sql = "ALTER TABLE {t} DROP INDEX {i}".format(t=table.name,
                                                              i=index_name)
            elif engine.dialect.name == 'mssql':
                sql = "DROP INDEX {t}.{i}".format(t=table.name, i=index_name)
            else:
                # Raised explicitly rather than via "assert False": assert
                # statements are removed under "python -O", which would leave
                # "sql" unbound below.
                raise AssertionError(
                    "Unknown dialect: {}".format(engine.dialect.name))
            execute(engine, sql)


def get_table_names(engine: Engine,
                    to_lower: bool = False,
                    sort: bool = False) -> List[str]:
    """
    Reads table names from the database.

    Args:
        engine: SQLAlchemy engine
        to_lower: convert the names to lower case?
        sort: sort the names case-insensitively?
    """
    inspector = inspect(engine)
    table_names = inspector.get_table_names()
    if to_lower:
        table_names = [x.lower() for x in table_names]
    if sort:
        table_names = sorted(table_names, key=lambda x: x.lower())
    return table_names


def get_view_names(engine: Engine,
                   to_lower: bool = False,
                   sort: bool = False) -> List[str]:
    """
    Reads view names from the database.

    Args:
        engine: SQLAlchemy engine
        to_lower: convert the names to lower case?
        sort: sort the names case-insensitively?
    """
    inspector = inspect(engine)
    view_names = inspector.get_view_names()
    if to_lower:
        view_names = [x.lower() for x in view_names]
    if sort:
        view_names = sorted(view_names, key=lambda x: x.lower())
    return view_names
def get_column_names(engine: Engine,
                     tablename: str,
                     to_lower: bool = False,
                     sort: bool = False) -> List[str]:
    """
    Reads columns names afresh from the database (in case metadata is out of
    date).

    Args:
        engine: SQLAlchemy engine
        tablename: name of the table to inspect
        to_lower: convert the names to lower case?
        sort: sort the names case-insensitively?
    """
    names = [col['name'] for col in inspect(engine).get_columns(tablename)]
    if to_lower:
        names = [name.lower() for name in names]
    if sort:
        names.sort(key=lambda name: name.lower())
    return names
def get_index_names(engine: Engine,
                    tablename: str,
                    to_lower: bool = False) -> List[str]:
    """
    Reads index names from the database.

    Args:
        engine: SQLAlchemy engine
        tablename: name of the table to inspect
        to_lower: convert the names to lower case?
    """
    # http://docs.sqlalchemy.org/en/latest/core/reflection.html
    names = []
    for index_info in inspect(engine).get_indexes(tablename):
        name = index_info['name']
        if not name:
            # ... at least for SQL Server, there always seems to be a blank
            # one with {'name': None, ...}.
            continue
        names.append(name.lower() if to_lower else name)
    return names
def ensure_columns_present(engine: Engine,
                           tablename: str,
                           column_names: Iterable[str]) -> None:
    """
    Ensures that a table contains all the named columns. Names are compared
    case-insensitively against column names read afresh from the database.

    Args:
        engine: database engine
        tablename: table to check
        column_names: names of the columns that must exist

    Raises:
        ValueError: if any of the columns is missing
    """
    existing_column_names = get_column_names(engine, tablename=tablename,
                                             to_lower=True)
    if not column_names:
        return
    for col in column_names:
        if col.lower() not in existing_column_names:
            raise ValueError(
                "Column {} missing from table {}, whose columns are {}".format(
                    repr(col), repr(tablename), repr(existing_column_names)))


def create_view(engine: Engine,
                viewname: str,
                select_sql: str) -> None:
    """
    Creates a view, replacing any existing view of the same name.

    Args:
        engine: database engine
        viewname: name of the view to create
        select_sql: the SELECT statement that defines the view
    """
    if engine.dialect.name == 'mysql':
        # MySQL has CREATE OR REPLACE VIEW.
        sql = "CREATE OR REPLACE VIEW {viewname} AS {select_sql}".format(
            viewname=viewname,
            select_sql=select_sql,
        )
    else:
        # SQL Server doesn't: http://stackoverflow.com/questions/18534919
        drop_view(engine, viewname, quiet=True)
        sql = "CREATE VIEW {viewname} AS {select_sql}".format(
            viewname=viewname,
            select_sql=select_sql,
        )
    log.info("Creating view: {}".format(repr(viewname)))
    execute(engine, sql)


def assert_view_has_same_num_rows(engine: Engine,
                                  basetable: str,
                                  viewname: str) -> None:
    """
    Asserts that a view has the same number of records as its base table.

    Args:
        engine: database engine
        basetable: name of the base table
        viewname: name of the view

    Raises:
        AssertionError: if the two counts differ
    """
    # Note that this relies on the data, i.e. design failures MAY cause this
    # assertion to fail, but won't necessarily (e.g. if the table is empty).
    n_base = count_star(engine, basetable)
    n_view = count_star(engine, viewname)
    assert n_view == n_base, (
        "View bug: view {} has {} records but its base table {} "
        "has {}; they should be equal".format(
            viewname, n_view, basetable, n_base))


def drop_view(engine: Engine,
              viewname: str,
              quiet: bool = False) -> None:
    """
    Drops a view if it exists (checked case-insensitively against the view
    names currently in the database).

    Args:
        engine: database engine
        viewname: name of the view to drop
        quiet: suppress the info-level "Dropping view" log message?
    """
    # MySQL has DROP VIEW IF EXISTS, but SQL Server only has that from
    # SQL Server 2016 onwards.
    # - https://msdn.microsoft.com/en-us/library/ms173492.aspx
    # - http://dev.mysql.com/doc/refman/5.7/en/drop-view.html
    view_names = get_view_names(engine, to_lower=True)
    if viewname.lower() not in view_names:
        log.debug("View {} does not exist; not dropping".format(viewname))
    else:
        if not quiet:
            log.info("Dropping view: {}".format(repr(viewname)))
        sql = "DROP VIEW {viewname}".format(viewname=viewname)
        execute(engine, sql)


# =============================================================================
# View-building assistance class
# =============================================================================

class ViewMaker(object):
    """
    Helper for building the SQL of a view on top of a base table: collects
    SELECT/FROM/WHERE fragments, and records which lookup tables are used
    and which indexes have been requested for them.
    """

    def __init__(self,
                 viewname: str,
                 engine: Engine,
                 basetable: str,
                 existing_to_lower: bool = False,
                 rename: Dict[str, str] = None,
                 progargs: argparse.Namespace = None,
                 enforce_same_n_rows_as_base: bool = True,
                 insert_basetable_columns: bool = True) -> None:
        """
        Args:
            viewname: name of the view
            engine: database engine
            basetable: name of the base table; must not be empty
            existing_to_lower: convert the base table's column names to
                lower case when reading them?
            rename: optional map from base-table column name to the name it
                should have in the view; a column mapped to a false value
                (e.g. ``None`` or ``""``) is left out of the view
            progargs: stored but not used by this class
            enforce_same_n_rows_as_base: after creating the view, assert
                that it has as many rows as the base table?
            insert_basetable_columns: start the SELECT list with the base
                table's columns?
        """
        rename = rename or {}
        assert basetable, "ViewMaker: basetable missing!"
        self.viewname = viewname
        self.engine = engine
        self.basetable = basetable
        self.progargs = progargs  # only for others' benefit
        self.enforce_same_n_rows_as_base = enforce_same_n_rows_as_base
        self.select_elements = []
        self.from_elements = [basetable]
        self.where_elements = []
        self.lookup_tables = []  # type: List[str]
        self.index_requests = OrderedDict()

        if insert_basetable_columns:
            grammar = make_grammar(engine.dialect.name)

            def q(identifier: str) -> str:
                return grammar.quote_identifier_if_required(identifier)

            for colname in get_column_names(engine, tablename=basetable,
                                            to_lower=existing_to_lower):
                if colname in rename:
                    rename_to = rename[colname]
                    if not rename_to:
                        # Renamed to nothing: omit this column from the view.
                        continue
                    as_clause = " AS {}".format(q(rename_to))
                else:
                    as_clause = ""
                self.select_elements.append("{t}.{c}{as_clause}".format(
                    t=q(basetable),
                    c=q(colname),
                    as_clause=as_clause))
            assert self.select_elements, "Must have some active SELECT " \
                                         "elements from base table"

    def add_select(self, clause: str) -> None:
        """Adds an element to the SELECT list."""
        self.select_elements.append(clause)

    def add_from(self, clause: str) -> None:
        """Adds an element (e.g. a JOIN clause) to the FROM section."""
        self.from_elements.append(clause)

    def add_where(self, clause: str) -> None:
        """Adds a condition to the WHERE section (conditions are ANDed)."""
        self.where_elements.append(clause)

    def get_sql(self) -> str:
        """
        Returns the SELECT statement built so far.

        Raises:
            AssertionError: if there are no SELECT elements
        """
        assert self.select_elements, "ViewMaker: no SELECT elements!"
        # NOTE(review): the separator strings below are reproduced as seen in
        # the available copy of this file, in which runs of whitespace may
        # have been collapsed; only SQL layout is affected. Confirm against
        # the repository.
        if self.where_elements:
            where = "\n WHERE {}".format(
                "\n AND ".join(self.where_elements))
        else:
            where = ""
        return (
            "\n SELECT {select_elements}"
            "\n FROM {from_elements}{where}".format(
                select_elements=",\n ".join(self.select_elements),
                from_elements="\n ".join(self.from_elements),
                where=where))

    def create_view(self, engine: Engine) -> None:
        """
        Creates the view in the database and, if requested at construction,
        asserts that it has the same number of rows as the base table.
        """
        create_view(engine, self.viewname, self.get_sql())
        if self.enforce_same_n_rows_as_base:
            assert_view_has_same_num_rows(engine, self.basetable,
                                          self.viewname)

    def drop_view(self, engine: Engine) -> None:
        """Drops the view from the database, if it exists."""
        drop_view(engine, self.viewname)

    def record_lookup_table(self, table: str) -> None:
        """Records a lookup table as being used (duplicates are ignored)."""
        if table not in self.lookup_tables:
            self.lookup_tables.append(table)

    def request_index(self, table: str, column: str) -> None:
        """Records a request for an index on ``table.column``."""
        if table not in self.index_requests:
            self.index_requests[table] = []  # type: List[str]
        self.index_requests[table].append(column)

    def record_lookup_table_keyfield(
            self,
            table: str,
            keyfield: Union[str, Iterable[str]]) -> None:
        """
        Records a lookup table and requests an index on each of its key
        field(s).

        Args:
            table: lookup table name
            keyfield: a single column name, or an iterable of column names
        """
        if isinstance(keyfield, str):
            keyfield = [keyfield]
        self.record_lookup_table(table)
        for kf in keyfield:
            self.request_index(table, kf)

    def record_lookup_table_keyfields(
            self,
            table_keyfield_tuples: Iterable[
                Tuple[str, Union[str, Iterable[str]]]
            ]) -> None:
        """
        Calls :meth:`record_lookup_table_keyfield` for each
        ``(table, keyfield(s))`` tuple.
        """
        for t, k in table_keyfield_tuples:
            self.record_lookup_table_keyfield(t, k)

    def get_lookup_tables(self) -> List[str]:
        """Returns the lookup tables recorded so far."""
        return self.lookup_tables

    def get_index_request_dict(self) -> Dict[str, List[str]]:
        """Returns the index requests: table name -> list of column names."""
        return self.index_requests


# =============================================================================
# Transaction size-limiting class
# =============================================================================

class TransactionSizeLimiter(object):
    """
    Commits a session whenever the number of rows or bytes notified since
    the last commit reaches a configured limit.
    """

    def __init__(self, session: Session,
                 max_rows_before_commit: int = None,
                 max_bytes_before_commit: int = None) -> None:
        """
        Args:
            session: session on which ``commit()`` will be called
            max_rows_before_commit: row limit, or ``None`` for no row limit
            max_bytes_before_commit: byte limit, or ``None`` for no byte
                limit
        """
        self._session = session
        self._max_rows_before_commit = max_rows_before_commit
        self._max_bytes_before_commit = max_bytes_before_commit
        self._bytes_in_transaction = 0
        self._rows_in_transaction = 0

    def commit(self) -> None:
        """Commits the session (timed) and resets the row/byte counters."""
        with MultiTimerContext(timer, TIMING_COMMIT):
            self._session.commit()
        self._bytes_in_transaction = 0
        self._rows_in_transaction = 0

    def notify(self, n_rows: int, n_bytes: int,
               force_commit: bool = False) -> None:
        """
        Tells the limiter that data has been added to the transaction;
        commits if a limit has been reached.

        Args:
            n_rows: number of rows added
            n_bytes: number of bytes added
            force_commit: commit immediately, regardless of the limits?
        """
        if force_commit:
            self.commit()
            return
        self._bytes_in_transaction += n_bytes
        self._rows_in_transaction += n_rows
        # log.critical(
        #     "adding {} rows, {} bytes, "
        #     "to make {} rows, {} bytes so far".format(
        #         n_rows, n_bytes,
        #         self._rows_in_transaction, self._bytes_in_transaction))
        if (self._max_bytes_before_commit is not None and
                self._bytes_in_transaction >= self._max_bytes_before_commit):
            log.info(
                "Triggering early commit based on byte count (reached {}, "
                "limit is {})".format(
                    sizeof_fmt(self._bytes_in_transaction),
                    sizeof_fmt(self._max_bytes_before_commit)))
            self.commit()
        elif (self._max_rows_before_commit is not None and
                self._rows_in_transaction >= self._max_rows_before_commit):
            log.info(
                "Triggering early commit based on row count (reached {} rows, "
                "limit is {})".format(self._rows_in_transaction,
                                      self._max_rows_before_commit))
            self.commit()


# =============================================================================
# Specification matching
# =============================================================================

def _matches_tabledef(table: str, tabledef: str) -> bool:
    """
    Does the table name match a single table specification?

    The specification is converted to a regex by ``get_spec_match_regex()``.
    """
    tr = get_spec_match_regex(tabledef)
    # Coerce to bool: regex match() returns a match object or None.
    return bool(tr.match(table))


def matches_tabledef(table: str, tabledef: Union[str, List[str]]) -> bool:
    """
    Does the table name match the table specification, or any of a list of
    table specifications? An empty/missing specification matches nothing.
    """
    if isinstance(tabledef, str):
        return _matches_tabledef(table, tabledef)
    elif not tabledef:
        return False
    else:  # list
        return any(_matches_tabledef(table, td) for td in tabledef)


def _matches_fielddef(table: str, field: str, fielddef: str) -> bool:
    """
    Does the table/field pair match a single field specification?

    The specification is split by ``split_db_schema_table_column()``. If it
    has no table part, only the field name is tested; otherwise both the
    table and the field must match.
    """
    column_id = split_db_schema_table_column(fielddef)
    cr = get_spec_match_regex(column_id.column)
    if not column_id.table:
        return bool(cr.match(field))
    tr = get_spec_match_regex(column_id.table)
    # Coerce to bool: "match and match" would yield a match object or None.
    return bool(tr.match(table) and cr.match(field))


def matches_fielddef(table: str, field: str,
                     fielddef: Union[str, List[str]]) -> bool:
    """
    Does the table/field pair match the field specification, or any of a
    list of field specifications? An empty/missing specification matches
    nothing.
    """
    if isinstance(fielddef, str):
        return _matches_fielddef(table, field, fielddef)
    elif not fielddef:
        return False
    else:  # list
        return any(_matches_fielddef(table, field, fd) for fd in fielddef)


# =============================================================================
# More SQL
# =============================================================================
def sql_fragment_cast_to_int(expr: str,
                             big: bool = True,
                             dialect: Dialect = None,
                             viewmaker: ViewMaker = None) -> str:
    """
    Returns an SQL fragment that converts an expression to an integer,
    giving NULL where the conversion is not possible.

    For Microsoft SQL Server.

    Args:
        expr: the SQL expression to convert
        big: use BIGINT rather than INTEGER?
        dialect: the database dialect; if ``None``, the dialect of
            ``viewmaker.engine`` is used when a ``viewmaker`` is given
        viewmaker: optional :class:`ViewMaker`, used only to find the
            dialect when ``dialect`` is not given

    Returns:
        the SQL fragment. If no dialect can be determined, SQL Server
        without ``TRY_CAST`` support is assumed.

    Raises:
        ValueError: if the dialect is not SQL Server

    Conversion to INT:

    - http://stackoverflow.com/questions/2000045
    - http://stackoverflow.com/questions/14719760  # this one
    - http://stackoverflow.com/questions/14692131

      - see LIKE example.
      - see ISNUMERIC();
        https://msdn.microsoft.com/en-us/library/ms186272.aspx
        ... but that includes non-integer numerics

    - https://msdn.microsoft.com/en-us/library/ms174214(v=sql.120).aspx
      ... relates to the SQL Server Management Studio "Find and Replace"
      dialogue box, not to SQL itself!
    - http://stackoverflow.com/questions/29206404/mssql-regular-expression

    Note that the regex-like expression supported by LIKE is extremely
    limited.

    - https://msdn.microsoft.com/en-us/library/ms179859.aspx

    The only things supported are:

    .. code-block:: none

        %   any characters
        _   any single character
        []  single character in range or set, e.g. [a-f], [abcdef]
        [^] single character NOT in range or set, e.g. [^a-f], [^abcdef]

    SQL Server does not support a REGEXP command directly.

    So the best bet is to have the LIKE clause check for a non-integer:

    .. code-block:: sql

        CASE
            WHEN something LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    ... which doesn't deal with spaces properly, but there you go.
    Could also strip whitespace left/right:

    .. code-block:: sql

        CASE
            WHEN LTRIM(RTRIM(something)) LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    Only works for positive integers.

    LTRIM/RTRIM are not ANSI SQL.
    Nor are unusual LIKE clauses; see
    http://stackoverflow.com/questions/712580/list-of-special-characters-for-sql-like-clause

    The other, for SQL Server 2012 or higher, is TRY_CAST:

    .. code-block:: sql

        TRY_CAST(something AS BIGINT)

    ... which returns NULL upon failure; see
    https://msdn.microsoft.com/en-us/library/hh974669.aspx
    """  # noqa
    inttype = "BIGINT" if big else "INTEGER"
    if dialect is None and viewmaker is not None:
        dialect = viewmaker.engine.dialect
    if dialect is None:
        # Nothing known about the server: assume SQL Server and use the most
        # widely supported form.
        sql_server = True
        supports_try_cast = False
    else:
        # noinspection PyUnresolvedReferences
        sql_server = dialect.name == 'mssql'
        # The server version may be unknown (None), e.g. if no connection
        # has been made yet. Comparing None with a version tuple would raise
        # TypeError, so treat an unknown version as "no TRY_CAST".
        # noinspection PyUnresolvedReferences
        server_version = dialect.server_version_info
        supports_try_cast = (sql_server and
                             server_version is not None and
                             server_version >= MS_2012_VERSION)
    if supports_try_cast:
        return "TRY_CAST({expr} AS {inttype})".format(expr=expr,
                                                      inttype=inttype)
    elif sql_server:
        return (
            "CASE WHEN LTRIM(RTRIM({expr})) LIKE '%[^0-9]%' "
            "THEN NULL ELSE CAST({expr} AS {inttype}) END".format(
                expr=expr, inttype=inttype)
        )
        # Doesn't support negative integers.
    else:
        # noinspection PyUnresolvedReferences
        raise ValueError("Code not yet written for convert-to-int for "
                         "dialect {}".format(dialect.name))
# =============================================================================
# Abstracted SQL WHERE condition
# =============================================================================

@register_for_json(method=METHOD_PROVIDES_INIT_KWARGS)
@functools.total_ordering
class WhereCondition(object):
    """
    Ancillary class for building SQL WHERE expressions from our web forms.

    A condition is either raw SQL (``raw_sql``), or a column/operator/value
    combination from which SQL is generated by :meth:`sql`.
    """

    def __init__(self,
                 column_id: ColumnId = None,
                 op: str = '',
                 datatype: str = '',
                 value_or_values: Any = None,
                 raw_sql: str = '',
                 from_table_for_raw_sql: TableId = None) -> None:
        """
        Args:
            column_id: the column being tested
            op: SQL operator (stored in upper case)
            datatype: datatype code, used to choose how values are converted
                to SQL literals
            value_or_values: ``None`` for operators needing no value, a list
                for multi-value operators, or a single value otherwise
            raw_sql: raw SQL; if given, it is used verbatim and the
                operator/value checks are skipped
            from_table_for_raw_sql: the table reported by :attr:`table_id`
                when ``raw_sql`` is used

        Raises:
            AssertionError: if the value does not suit the operator
        """
        self._column_id = column_id
        self._op = op.upper()
        self._datatype = datatype
        self._value = value_or_values
        self._no_value = False
        self._multivalue = False
        self._raw_sql = raw_sql
        self._from_table_for_raw_sql = from_table_for_raw_sql

        if not self._raw_sql:
            if self._op in SQL_OPS_VALUE_UNNECESSARY:
                self._no_value = True
                assert value_or_values is None, "Superfluous value passed"
            elif self._op in SQL_OPS_MULTIPLE_VALUES:
                self._multivalue = True
                assert isinstance(value_or_values, list), "Need list"
            else:
                assert not isinstance(value_or_values, list), "Need single value"  # noqa

    def init_kwargs(self) -> Dict:
        """Returns the keyword arguments that would recreate this object."""
        return {
            'column_id': self._column_id,
            'op': self._op,
            'datatype': self._datatype,
            'value_or_values': self._value,
            'raw_sql': self._raw_sql,
            'from_table_for_raw_sql': self._from_table_for_raw_sql,
        }

    def __repr__(self) -> str:
        return (
            "<{qualname}("
            "column_id={column_id}, "
            "op={op}, "
            "datatype={datatype}, "
            "value_or_values={value_or_values}, "
            "raw_sql={raw_sql}, "
            "from_table_for_raw_sql={from_table_for_raw_sql}"
            ") at {addr}>".format(
                qualname=self.__class__.__qualname__,
                column_id=repr(self._column_id),
                op=repr(self._op),
                datatype=repr(self._datatype),
                value_or_values=repr(self._value),
                raw_sql=repr(self._raw_sql),
                from_table_for_raw_sql=repr(self._from_table_for_raw_sql),
                addr=hex(id(self)),
            )
        )

    def __eq__(self, other: 'WhereCondition') -> bool:
        # Defer to Python's fallback comparison for foreign types, rather
        # than failing with AttributeError on their missing attributes.
        if not isinstance(other, WhereCondition):
            return NotImplemented
        return (
            self._raw_sql == other._raw_sql and
            self._column_id == other._column_id and
            self._op == other._op and
            self._value == other._value
        )

    def __lt__(self, other: 'WhereCondition') -> bool:
        if not isinstance(other, WhereCondition):
            return NotImplemented
        return (
            (self._raw_sql, self._column_id, self._op, self._value) <
            (other._raw_sql, other._column_id, other._op, other._value)
        )

    @property
    def column_id(self) -> ColumnId:
        """The column being tested."""
        return self._column_id

    @property
    def table_id(self) -> TableId:
        """
        The table involved: the explicitly specified table for raw SQL, or
        otherwise the table of the column.
        """
        if self._raw_sql:
            return self._from_table_for_raw_sql
        return self.column_id.table_id

    def table_str(self, grammar: SqlGrammar) -> str:
        """Returns the table's identifier, formatted via ``grammar``."""
        return self.table_id.identifier(grammar)

    def sql(self, grammar: SqlGrammar) -> str:
        """
        Returns the SQL for this condition: the raw SQL if any was given,
        and otherwise an expression built from the column, the operator and
        the value(s) converted to SQL literals.
        """
        if self._raw_sql:
            return self._raw_sql

        col = self._column_id.identifier(grammar)
        op = self._op

        if self._no_value:
            return "{col} {op}".format(col=col, op=op)

        if self._datatype in QB_STRING_TYPES:
            element_converter = sql_string_literal
        elif self._datatype == QB_DATATYPE_DATE:
            element_converter = sql_date_literal
        elif self._datatype in (QB_DATATYPE_INTEGER, QB_DATATYPE_FLOAT):
            element_converter = str
        else:
            # Safe default
            element_converter = sql_string_literal

        if self._multivalue:
            literal = "({})".format(", ".join(element_converter(v)
                                              for v in self._value))
        else:
            literal = element_converter(self._value)

        if self._op == 'MATCH':  # MySQL
            return "MATCH ({col}) AGAINST ({val})".format(col=col,
                                                          val=literal)
        elif self._op == 'CONTAINS':  # SQL Server
            return "CONTAINS({col}, {val})".format(col=col, val=literal)
        else:
            return "{col} {op} {val}".format(col=col, op=op, val=literal)


# =============================================================================
# SQL formatting
# =============================================================================

def format_sql_for_print(sql: str) -> str:
    """
    Tidies SQL for display: replaces tabs, strips trailing whitespace,
    removes blank lines, and removes the left padding common to all lines.
    """
    # Remove blank lines and trailing spaces
    # NOTE(review): the tab replacement string is reproduced as seen in the
    # available copy of this file, in which runs of spaces may have been
    # collapsed; confirm its width against the repository.
    lines = list(filter(None, [x.replace("\t", " ").rstrip()
                               for x in sql.splitlines()]))
    # Shift all lines left if they're left-padded
    firstleftpos = float('inf')
    for line in lines:
        leftpos = len(line) - len(line.lstrip())
        firstleftpos = min(firstleftpos, leftpos)
    if firstleftpos > 0:
        # If there are no lines, firstleftpos is still infinite, but then
        # there is nothing to slice either.
        lines = [x[firstleftpos:] for x in lines]
    return "\n".join(lines)


# =============================================================================
# Plain SQL types
# =============================================================================

def is_sql_column_type_textual(column_type: str,
                               min_length: int = 1) -> bool:
    """
    Is the SQL column type textual, and at least ``min_length`` long?

    Args:
        column_type: SQL column type, as a string (case-insensitive)
        min_length: minimum length required of types that specify a length

    Returns:
        ``True`` for a text type with no length specified, or for a text
        type whose specified length is at least ``min_length``
    """
    column_type = column_type.upper()
    if column_type in SQLTYPES_TEXT:
        # A text type without a specific length
        return True
    try:
        m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
        basetype = m.group(1)
        length = int(m.group(2))
    except (AttributeError, ValueError):
        # AttributeError: no regex match (m is None).
        return False
    return length >= min_length and basetype in SQLTYPES_TEXT
def escape_quote_in_literal(s: str) -> str:
    r"""
    Escapes single quotes (``'``) for use inside an SQL string literal.

    We could use ``''`` or ``\'``. Let's use ``\'`` for consistency with
    percent escaping.

    Args:
        s: text to escape

    Returns:
        the text with every ``'`` replaced by ``\'``
    """
    return s.replace("'", r"\'")
def escape_percent_in_literal(sql: str) -> str:
    r"""
    Escapes ``%`` by converting it to ``\%``. Use this for LIKE clauses.

    - http://dev.mysql.com/doc/refman/5.7/en/string-literals.html

    Args:
        sql: text to escape

    Returns:
        the text with every ``%`` replaced by ``\%``
    """
    return sql.replace('%', r'\%')
def escape_percent_for_python_dbapi(sql: str) -> str:
    """
    Doubles every percent sign (``%`` becomes ``%%``).

    Use this for SQL within Python where % characters are used for argument
    placeholders.

    Args:
        sql: SQL text to escape

    Returns:
        the SQL with every ``%`` doubled
    """
    doubling = str.maketrans({'%': '%%'})
    return sql.translate(doubling)
def escape_sql_string_literal(s: str) -> str:
    """
    Escapes an SQL string literal fragment: first its single quotes, then
    its percent signs.

    Args:
        s: text to escape

    Returns:
        the escaped text (without surrounding quotes)
    """
    quotes_escaped = escape_quote_in_literal(s)
    return escape_percent_in_literal(quotes_escaped)
def make_string_literal(s: str) -> str:
    """
    Escapes the text and wraps it in single quotes, giving an SQL string
    literal.
    """
    return "'" + escape_sql_string_literal(s) + "'"


def escape_sql_string_or_int_literal(s: Union[str, int]) -> str:
    """
    Converts a value to an SQL literal: an integer becomes its plain decimal
    representation; anything else is treated as text and becomes a quoted,
    escaped string literal.
    """
    return str(s) if isinstance(s, int) else make_string_literal(s)
def translate_sql_qmark_to_percent(sql: str) -> str:
    """
    Translates SQL using ``?`` placeholders to SQL using ``%s`` placeholders,
    without breaking literal ``?`` or ``%``, e.g. inside string literals.

    MySQL likes '?' as a placeholder.

    - https://dev.mysql.com/doc/refman/5.7/en/sql-syntax-prepared-statements.html

    Python DBAPI allows several: '%s', '?', ':1', ':name', '%(name)s'.

    - https://www.python.org/dev/peps/pep-0249/#paramstyle

    Django uses '%s'.

    - https://docs.djangoproject.com/en/1.8/topics/db/sql/

    Microsoft like '?', '@paramname', and ':paramname'.

    - https://msdn.microsoft.com/en-us/library/yy6y35y8(v=vs.110).aspx

    We need to parse SQL with argument placeholders.

    - See SqlGrammar classes, particularly: bind_parameter

    I prefer ?, because % is used in LIKE clauses, and the databases we're
    using like it.

    So:

    - We use %s when using cursor.execute() directly, via Django.
    - We use ? when talking to users, and SqlGrammar objects, so that the
      visual appearance matches what they expect from their database.

    Args:
        sql: SQL text using ``?`` placeholders

    Returns:
        SQL text in which every ``%`` is doubled and every ``?`` outside
        single quotes is replaced by ``%s``
    """  # noqa
    # 1. Escape % characters
    sql = escape_percent_for_python_dbapi(sql)
    # 2. Replace ? characters that are not within quotes with %s.
    #    Pieces are collected in a list and joined once, rather than growing
    #    a string one character at a time.
    pieces = []  # type: List[str]
    in_quotes = False
    for c in sql:
        if c == "'":
            in_quotes = not in_quotes
        if c == '?' and not in_quotes:
            pieces.append('%s')
        else:
            pieces.append(c)
    return "".join(pieces)
_ = """
_SQLTEST1 = "SELECT a FROM b WHERE c=? AND d LIKE 'blah%' AND e='?'"
_SQLTEST2 = "SELECT a FROM b WHERE c=%s AND d LIKE 'blah%%' AND e='?'"
_SQLTEST3 = translate_sql_qmark_to_percent(_SQLTEST1)
"""


# =============================================================================
# Tests
# =============================================================================

def unit_tests():
    """
    Quick self-tests: table/field specification matching, then parsing a
    simple SELECT with the MySQL grammar and logging its first FROM table.
    """
    # Table specifications that should match "sometable"...
    for matching_tabledef in ("sometable", "some*", "*table", "*", "s*e"):
        assert matches_tabledef("sometable", matching_tabledef)
    # ... and one that should not.
    assert not matches_tabledef("sometable", "x*y")

    # Field specifications that should match "sometable.somefield".
    for matching_fielddef in ("*.somefield",
                              "sometable.somefield",
                              "sometable.*",
                              "somefield"):
        assert matches_fielddef("sometable", "somefield", matching_fielddef)

    mysql = make_grammar(SqlaDialectName.MYSQL)
    select_sql = ("SELECT t1.c1, t2.c2 "
                  "FROM t1 INNER JOIN t2 ON t1.k = t2.k")
    parse_result = mysql.get_select_statement().parseString(select_sql,
                                                            parseAll=True)
    first_table = get_first_from_table(parse_result)  # noqa
    log.info(repr(first_table))


if __name__ == '__main__':
    main_only_quicksetup_rootlogger()
    unit_tests()