#!/usr/bin/env python
# crate_anon/crateweb/research/sql_writer.py
"""
===============================================================================
Copyright (C) 2015-2018 Rudolf Cardinal (rudolf@pobox.com).
This file is part of CRATE.
CRATE is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
CRATE is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with CRATE. If not, see <http://www.gnu.org/licenses/>.
===============================================================================
"""
import logging
from typing import List, Optional
from cardinal_pythonlib.logs import main_only_quicksetup_rootlogger
from cardinal_pythonlib.sql.sql_grammar import (
format_sql,
SqlGrammar,
text_from_parsed,
)
from cardinal_pythonlib.sql.sql_grammar_factory import make_grammar
from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
from pyparsing import ParseResults
from crate_anon.common.sql import (
ColumnId,
get_first_from_table,
JoinInfo,
parser_add_result_column,
parser_add_from_tables,
set_distinct_within_parsed,
TableId,
WhereCondition,
)
from crate_anon.crateweb.research.research_db_info import (
research_database_info,
)
log = logging.getLogger(__name__)
# =============================================================================
# Automagically create/manipulate SQL statements based on our extra knowledge
# of the fields that can be used to link across tables/databases.
# =============================================================================
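# Illustrative sketch only (table and column names here are hypothetical):
# given an existing query
#     SELECT t1.a FROM t1
# and a request to add column t2.c with a "magic" join, the helpers below aim
# to produce something like
#     SELECT t1.a, t2.c
#     FROM t1
#     INNER JOIN t2 ON t2.trid = t1.trid
# where the linking column (here a TRID) is looked up via
# research_database_info rather than supplied by the user.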
def get_join_info(grammar: SqlGrammar,
parsed: ParseResults,
jointable: TableId,
magic_join: bool = False,
nonmagic_join_type: str = "INNER JOIN",
nonmagic_join_condition: str = '') -> List[JoinInfo]:
    """
    Works out how to join a new table (jointable) into the query represented
    by "parsed".

    Returns a list of JoinInfo objects, each giving a join type (e.g.
    "INNER JOIN"; ANSI SQL), a table, and a join condition such as
    "ON newtable.trid = existingtable.trid".
    """
first_from_table = get_first_from_table(parsed)
from_table_in_join_schema = get_first_from_table(
parsed,
match_db=jointable.db,
match_schema=jointable.schema)
exact_match_table = get_first_from_table(
parsed,
match_db=jointable.db,
match_schema=jointable.schema,
match_table=jointable.table)
if not first_from_table:
# No tables in query yet.
# This should not happen; this function is to help with adding
# new FROM tables to existing FROM clauses.
log.warning("get_join_info: no tables in query")
return []
if exact_match_table:
# This table is already in the query. No JOIN should be required.
# log.critical("get_join_info: same table already in query")
return []
if not magic_join:
# log.critical("get_join_info: non-magic join")
return [JoinInfo(join_type=nonmagic_join_type,
table=jointable.identifier(grammar),
join_condition=nonmagic_join_condition)]
if from_table_in_join_schema:
# Another table from the same database is present. Link on the
# TRID field.
# log.critical("get_join_info: joining to another table in same DB")
return [JoinInfo(
join_type='INNER JOIN',
table=jointable.identifier(grammar),
join_condition="ON {new} = {existing}".format(
new=research_database_info.get_trid_column(
jointable).identifier(grammar),
existing=research_database_info.get_trid_column(
from_table_in_join_schema).identifier(grammar),
)
)]
# OK. So now we're building a cross-database join.
existing_family = research_database_info.get_dbinfo_by_schema_id(
first_from_table.schema_id).rid_family
new_family = research_database_info.get_dbinfo_by_schema_id(
jointable.schema_id).rid_family
# log.critical("existing_family={}, new_family={}".format(
# existing_family, new_family))
if existing_family and existing_family == new_family:
# log.critical("get_join_info: new DB, same RID family")
return [JoinInfo(
join_type='INNER JOIN',
table=jointable.identifier(grammar),
join_condition="ON {new} = {existing}".format(
new=research_database_info.get_rid_column(
jointable).identifier(grammar),
existing=research_database_info.get_rid_column(
first_from_table).identifier(grammar),
)
)]
# If we get here, we have to do a complicated join via the MRID.
# log.critical("get_join_info: new DB, different RID family, using MRID")
existing_mrid_column = research_database_info.get_mrid_column_from_table(
first_from_table)
existing_mrid_table = existing_mrid_column.table_id
if not existing_mrid_table:
raise ValueError(
"No MRID table available (in the same database as table {}; "
"cannot link)".format(first_from_table))
new_mrid_column = research_database_info.get_mrid_column_from_table(
jointable)
new_mrid_table = new_mrid_column.table_id
existing_mrid_table_in_query = bool(get_first_from_table(
parsed,
match_db=existing_mrid_table.db,
match_schema=existing_mrid_table.schema,
match_table=existing_mrid_table.table))
joins = []
if not existing_mrid_table_in_query:
joins.append(JoinInfo(
join_type='INNER JOIN',
table=existing_mrid_table.identifier(grammar),
join_condition="ON {m1_trid1} = {t1_trid1}".format(
m1_trid1=research_database_info.get_trid_column(
existing_mrid_table).identifier(grammar),
t1_trid1=research_database_info.get_trid_column(
first_from_table).identifier(grammar),
)
))
joins.append(JoinInfo(
join_type='INNER JOIN',
table=new_mrid_table.identifier(grammar),
join_condition="ON {m2_mrid2} = {m1_mrid1}".format(
m2_mrid2=new_mrid_column.identifier(grammar),
m1_mrid1=existing_mrid_column.identifier(grammar),
)
))
if jointable != new_mrid_table:
joins.append(JoinInfo(
join_type='INNER JOIN',
table=jointable.identifier(grammar),
join_condition="ON {t2_trid2} = {m2_trid2}".format(
t2_trid2=research_database_info.get_trid_column(
jointable).identifier(grammar),
m2_trid2=research_database_info.get_trid_column(
new_mrid_table).identifier(grammar),
)
))
return joins
class SelectElement(object):
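    """
    Represents one element of an SQL SELECT clause: either a specific column
    (column_id) or raw SQL (raw_select, optionally with
    from_table_for_raw_select indicating the table it requires), plus an
    optional alias.

    Illustrative example (exact quoting depends on the SQL dialect):
        SelectElement(column_id=ColumnId(table="t1", column="a"),
                      alias="x").sql_select_column(grammar)
        # -> something like "t1.a AS x"
    """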
def __init__(self,
column_id: ColumnId = None,
raw_select: str = '',
from_table_for_raw_select: TableId = None,
alias: str = ''):
self.column_id = column_id
self.raw_select = raw_select
self.from_table_for_raw_select = from_table_for_raw_select
self.alias = alias
def __repr__(self) -> str:
return (
"<{qualname}("
"column_id={column_id}, "
"raw_select={raw_select}, "
"from_table_for_raw_select={from_table_for_raw_select}, "
"alias={alias}) "
"at {addr}>".format(
qualname=self.__class__.__qualname__,
column_id=repr(self.column_id),
raw_select=repr(self.raw_select),
from_table_for_raw_select=repr(self.from_table_for_raw_select),
alias=repr(self.alias),
addr=hex(id(self)),
)
)
def sql_select_column(self, grammar: SqlGrammar) -> str:
result = self.raw_select or self.column_id.identifier(grammar)
if self.alias:
result += " AS " + self.alias
return result
def from_table(self) -> Optional[TableId]:
if self.raw_select:
return self.from_table_for_raw_select
return self.column_id.table_id
def from_table_str(self, grammar: SqlGrammar) -> str:
table_id = self.from_table()
if not table_id:
return ''
return table_id.identifier(grammar)
def sql_select_from(self, grammar: SqlGrammar) -> str:
sql = "SELECT " + self.sql_select_column(grammar=grammar)
from_table = self.from_table()
if from_table:
sql += " FROM " + from_table.identifier(grammar)
return sql
def reparse_select(p: ParseResults, grammar: SqlGrammar) -> ParseResults:
"""
Internal function for when we get desperate trying to hack around
the results of pyparsing's efforts.
"""
return grammar.get_select_statement().parseString(
text_from_parsed(p, formatted=False),
parseAll=True
)
def add_to_select(sql: str,
grammar: SqlGrammar,
select_elements: List[SelectElement] = None,
where_conditions: List[WhereCondition] = None,
# For SELECT:
distinct: bool = None, # True, False, or None to leave as is
# For WHERE:
where_type: str = "AND",
bracket_where: bool = False,
# For either, for JOIN:
magic_join: bool = True,
join_type: str = "NATURAL JOIN",
join_condition: str = '',
# General:
formatted: bool = True,
debug: bool = False,
debug_verbose: bool = False) -> str:
"""
This function encapsulates our query builder's common operations.
One premise is that SQL parsing is relatively slow, so we should do this
only once. We parse; add bits to the parsed structure as required; then
re-convert to text.
If you specify table/column, elements will be added to SELECT and FROM
unless they already exist.
If you specify where_expression, elements will be added to WHERE.
In this situation, you should also specify where_table; if the where_table
isn't yet in the FROM clause, this will be added as well.
Parsing is SLOW, so we should do as much as possible in a single call to
this function.
"""
select_elements = select_elements or [] # type: List[SelectElement]
where_conditions = where_conditions or [] # type: List[WhereCondition]
if debug:
log.info("START: {}".format(sql))
log.debug("select_elements: {}".format(select_elements))
log.debug("where_conditions: {}".format(where_conditions))
log.debug("where_type: {}".format(where_type))
log.debug("join_type: {}".format(join_type))
log.debug("join_condition: {}".format(join_condition))
# -------------------------------------------------------------------------
# Get going. We have to handle a fresh SQL statement in a slightly
# different way.
# -------------------------------------------------------------------------
if not sql:
if not select_elements:
raise ValueError("Fresh SQL statements must include a SELECT "
"element")
# ---------------------------------------------------------------------
# Fresh SQL statement
# ---------------------------------------------------------------------
first_select = select_elements[0]
select_elements = select_elements[1:]
sql = first_select.sql_select_from(grammar)
# log.debug("Starting SQL from scratch as: " + sql)
# -------------------------------------------------------------------------
# Parse what we have (which is now, at a minimum, SELECT ... FROM ...).
# -------------------------------------------------------------------------
p = grammar.get_select_statement().parseString(sql, parseAll=True)
if debug and debug_verbose:
log.debug("start dump:\n" + p.dump())
existing_tables = p.join_source.from_tables.asList() # type: List[str]
new_tables = [] # type: List[TableId]
def add_new_table(_table_id: TableId) -> None:
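        # Queue a table for addition to the FROM clause, unless it's already
        # in the query or already queued.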
if (_table_id and
_table_id not in new_tables and
_table_id.identifier(grammar) not in existing_tables):
new_tables.append(_table_id)
# -------------------------------------------------------------------------
# DISTINCT?
# -------------------------------------------------------------------------
if distinct is True:
set_distinct_within_parsed(p, action='set')
elif distinct is False:
set_distinct_within_parsed(p, action='clear')
# -------------------------------------------------------------------------
# Process all the (other?) SELECT clauses
# -------------------------------------------------------------------------
for se in select_elements:
p = parser_add_result_column(p, se.sql_select_column(grammar),
grammar=grammar)
add_new_table(se.from_table())
# -------------------------------------------------------------------------
# Process all the WHERE clauses
# -------------------------------------------------------------------------
for wc in where_conditions:
where_expression = wc.sql(grammar)
if bracket_where:
where_expression = '(' + where_expression + ')'
# The tricky bit: inserting it.
# We use the [0] to overcome the effects of defining these things
# as a pyparsing Group(), which encapsulates the results in a list.
if p.where_clause:
cond = grammar.get_expr().parseString(where_expression,
parseAll=True)[0]
extra = [where_type, cond]
p.where_clause.where_expr.extend(extra)
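            # Schematically, where_expr's tokens now read:
            #     [<existing expression>, where_type, <new expression>]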
else:
# No WHERE as yet
# Doing this properly is a nightmare.
# It's hard to add a *named* ParseResults element to another.
# So it's very hard to alter p.where_clause.where_expr such that
# we can continue adding more WHERE clauses if we want.
# This is the inefficient, cop-out method:
# (1) Add as plain text
p.where_clause.append("WHERE " + where_expression)
# (2) Reparse...
p = reparse_select(p, grammar=grammar)
add_new_table(wc.table_id)
# -------------------------------------------------------------------------
# Process all the FROM clauses, autojoining as necessary
# -------------------------------------------------------------------------
for table_id in new_tables:
p = parser_add_from_tables(
p,
get_join_info(grammar=grammar,
parsed=p,
jointable=table_id,
magic_join=magic_join,
nonmagic_join_type=join_type,
nonmagic_join_condition=join_condition),
grammar=grammar)
if debug and debug_verbose:
log.debug("end dump:\n" + p.dump())
result = text_from_parsed(p, formatted=False)
if formatted:
result = format_sql(result)
if debug:
log.info("END: {}".format(result))
return result
# =============================================================================
# Unit tests
# =============================================================================
def unit_tests() -> None:
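    """
    Ad hoc demonstrations of add_to_select(), using the MySQL grammar.
    These log the resulting SQL rather than asserting anything.
    """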
grammar = make_grammar(SqlaDialectName.MYSQL)
log.info(add_to_select(
"SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
grammar=grammar,
select_elements=[SelectElement(
column_id=ColumnId(table="t2", column="c")
)],
magic_join=False # magic_join requires DB knowledge hence Django
))
log.info(add_to_select(
"SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
grammar=grammar,
select_elements=[SelectElement(
column_id=ColumnId(table="t1", column="a")
)]
))
log.info(add_to_select(
"",
grammar=grammar,
select_elements=[SelectElement(
column_id=ColumnId(table="t2", column="c")
)]
))
log.info(add_to_select(
"SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
grammar=grammar,
where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")]
))
log.info(add_to_select(
"SELECT t1.a, t1.b FROM t1",
grammar=grammar,
where_conditions=[WhereCondition(raw_sql="t1.col1 > 5")]
))
log.info(add_to_select(
"SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t3.col99 = 100",
grammar=grammar,
where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")]
))
# Multiple WHEREs where before there were none:
log.info(add_to_select(
"SELECT t1.a, t1.b FROM t1",
grammar=grammar,
where_conditions=[WhereCondition(raw_sql="t1.col1 > 99"),
WhereCondition(raw_sql="t1.col2 < 999")]
))
if __name__ == '__main__':
main_only_quicksetup_rootlogger()
unit_tests()