Module pggm_datalab_utils.db
from contextlib import contextmanager
from datetime import datetime
import json
import logging
import os
import pyodbc
import sqlite3
from typing import Dict, List, Optional, Iterator, Any, Union, Hashable
Connection = Union[pyodbc.Connection, sqlite3.Connection] # Todo: unify in own interface to get autocomplete
Cursor = Union[pyodbc.Cursor, sqlite3.Cursor]
Record = Dict[str, Any]
RecordList = List[Record]

def get_db_connection(server: str, database: str) -> Connection:
    """
    Initiate a pyodbc connection to a SQL Server database. This function is intended to be suitable for cloud
    development: in the cloud you use environment variables to log in with a username and password, whereas locally
    you simply log in using your AAD credentials.

    You may set `DB_DRIVER` to a pyodbc-compatible driver name for your system. It defaults to
    `{ODBC Driver 17 for SQL Server}`. If you set it to the value `SQLITE`, this routine will use the built-in
    sqlite3 library to connect instead.

    By default, this connection will try to log into the database with the user account it's executing under. This is
    compatible with AAD login and suitable for local development. If you want to log into a database from another
    environment, you will have to use Windows credentials: save the username in the environment variable `DB_UID`
    and the password in `DB_PASSWORD`.
    """
    driver = os.environ.get('DB_DRIVER', '{ODBC Driver 17 for SQL Server}')
    if driver == 'SQLITE':
        logging.info(f'Connecting to SQLite database {database} because driver=SQLITE. Ignoring server {server}.')
        # Make sqlite3 somewhat well-behaved.
        sqlite3.register_converter('datetime', lambda b: datetime.fromisoformat(b.decode()))
        sqlite3.register_converter('json', json.loads)
        sqlite3.register_adapter(list, json.dumps)
        sqlite3.register_adapter(dict, json.dumps)
        return sqlite3.connect(database, detect_types=sqlite3.PARSE_DECLTYPES)
    elif user := os.environ.get('DB_UID'):
        logging.info(f'Logging into {database}/{server} as {user}.')
        pwd = os.environ['DB_PASSWORD']
        return pyodbc.connect(
            f'Driver={driver};'
            f'Server={server};'
            f'Database={database};'
            f'Uid={user};'
            f'Pwd={pwd};')
    else:
        logging.info(f'Logging into {database}/{server} as program user.')
        return pyodbc.connect(
            f'Driver={driver};'
            f'Server={server};'
            f'Database={database};'
            f'trusted_connection=yes;')

@contextmanager
def cursor(db_server: str, db_name: str) -> Iterator[Cursor]:
    """
    Obtain a cursor for a certain database server and database name. Internally uses `get_db_connection`. Use this as
    a context manager, which will handle closing the cursor and the connection. NOTE: this does not handle
    transactions: most of the time that means you need to commit your transactions yourself!

    Example usage:
    ```
    with cursor('my_server.net', 'test') as c:
        my_data = c.execute('select * from test_database').fetchall()
    ```
    """
    conn = get_db_connection(db_server, db_name)
    c = conn.cursor()
    try:
        yield c
    finally:
        c.close()
        conn.close()

def query(c: Cursor, sql: str, data: Optional[tuple] = None) -> RecordList:
    """
    Call `c.execute(sql, data).fetchall()` and format the resulting rowset as a list of records of the form
    `[{colname: value}]`.
    """
    if data is None:
        result = c.execute(sql).fetchall()
    else:
        result = c.execute(sql, data).fetchall()
    headers = [name for name, *_ in c.description]
    return [dict(zip(headers, r)) for r in result]

def get_all(c: Cursor, table_name: str) -> RecordList:
    """
    Get all current data from table `table_name`.

    IMPORTANT WARNING: `table_name` is not sanitized. Don't pass untrusted table names to this function!
    """
    return query(c, f'select * from {table_name}')

def validate(data: RecordList):
    """Check that `data` is a uniform list of records: all records share the same keys, and all keys are strings."""
    assert len(unique := set(tuple(sorted(r.keys())) for r in data)) == 1, \
        f'Non-uniform list of dictionaries passed, got differing keys {unique}.'
    assert not (non_str := {k: type(k) for k in data[0].keys() if not isinstance(k, str)}), \
        f'Non-string keys in data, got keys with types {non_str}.'

def insert_with_return(
        c: Cursor, table_name: str, data: Record, return_columns: Optional[Union[str, tuple]] = None
) -> Record:
    """
    Insert data into the database, returning a set of return columns. The primary use for this is if you have columns
    generated by your database, like an identity. Returns the input record with the returned columns added (if any).

    IMPORTANT WARNING: `table_name` is not sanitized. Don't pass untrusted table names to this function!
    """
    validate([data])
    return_columns = (return_columns,) if isinstance(return_columns, str) else return_columns
    columns = data.keys()
    insert_data = tuple(data[col] for col in columns)
    text_columns = ', '.join(columns)
    placeholders = ', '.join('?' for _ in columns)
    # Dispatch on cursor type for now; the pyodbc branch is MSSQL-only (OUTPUT clause).
    if return_columns is None:
        sql = f'insert into {table_name}({text_columns}) values ({placeholders})'
        c.execute(sql, insert_data)
        return data
    elif isinstance(c, sqlite3.Cursor):
        text_return_columns = ', '.join(return_columns)
        sql = f'insert into {table_name}({text_columns}) values ({placeholders}) returning {text_return_columns}'
        output = query(c, sql, insert_data)[0]
        return {**data, **output}
    else:
        text_return_columns = ', '.join(f'Inserted.{col}' for col in return_columns)
        sql = f'insert into {table_name}({text_columns}) output {text_return_columns} values ({placeholders})'
        output = query(c, sql, insert_data)[0]
        return {**data, **output}

def write(c: Cursor, table_name: str, data: RecordList, primary_key: Optional[Union[str, tuple]] = None, *,
          update=True, insert=True, delete=True):
    """
    Update data in a database table. Row identity is determined by `primary_key`.

    `update`, `insert`, and `delete` control which actions to take. By default, this function emits the update,
    insert, and delete queries needed to make the database table equal to the in-memory table.

    - `update=True` means rows already in the database will be updated with the in-memory data
    - `insert=True` means rows not already in the database will be added from the in-memory data
    - `delete=True` means rows present in the database but not in the in-memory data will be deleted

    If `primary_key` is None, only inserting is supported.

    IMPORTANT WARNING: `table_name` is not sanitized. Don't pass untrusted table names to this function!
    """
    validate(data)
    # Deal with the primary key, the list of writeable columns, the indexed data, and the data already in the db.
    if primary_key is None:
        assert not update and not delete, 'updating and deleting without specifying a primary key is not supported'
        primary_key = tuple()
        data = {i: r for i, r in enumerate(data)}
        columns = tuple(k for k in data[0].keys())
        in_db = set()
    else:
        primary_key = (primary_key,) if isinstance(primary_key, str) else tuple(primary_key)
        assert all(isinstance(r[k], Hashable) for r in data for k in primary_key)
        if any(empty_strings := [name for name in data[0].keys() if any(r[name] == '' for r in data)]):
            logging.warning(f'Columns {empty_strings} contain empty strings. '
                            f'Generally inserting empty strings into a database is a bad idea.')
        # List of writeable columns (for updates we don't try to overwrite the primary key)
        columns = tuple(k for k in data[0].keys() if k not in primary_key)
        # Index the data on the primary key
        data = {tuple(r[i] for i in primary_key): r for r in data}
        # Keys already present in the database
        sql = f'select {", ".join(primary_key)} from {table_name}'
        in_db = {tuple(r[k] for k in primary_key) for r in query(c, sql)}
    if update and (update_keys := data.keys() & in_db):
        update_data = [
            (tuple(data[k][col] for col in columns) + tuple(data[k][col] for col in primary_key)) for k in update_keys
        ]
        # Cannot use keyword placeholders because pyodbc doesn't support the named paramstyle. Would be better.
        assignment = ', '.join(f'{col}=?' for col in columns)
        pk_cols = ' AND '.join(f'{col}=?' for col in primary_key)
        sql = f'update {table_name} set {assignment} where {pk_cols}'
        c.executemany(sql, update_data)
    if insert and (insert_keys := data.keys() - in_db):
        insert_data = [tuple(data[k][col] for col in columns + primary_key) for k in insert_keys]
        placeholders = ', '.join('?' for _ in columns + primary_key)
        text_columns = ', '.join(columns + primary_key)
        sql = f'insert into {table_name}({text_columns}) values ({placeholders})'
        c.executemany(sql, insert_data)
    if delete and (delete_keys := in_db - data.keys()):
        condition = ' AND '.join(f'{k}=?' for k in primary_key)
        sql = f'delete from {table_name} where {condition}'
        c.executemany(sql, list(delete_keys))

Functions
def cursor(db_server: str, db_name: str) -> Iterator[Union[pyodbc.Cursor, sqlite3.Cursor]]
Obtain a cursor for a certain database server and database name. Internally uses `get_db_connection`. Use this as a context manager, which will handle closing the cursor and the connection. NOTE: this does not handle transactions: most of the time that means you need to commit your transactions yourself!

Example usage:

```
with cursor('my_server.net', 'test') as c:
    my_data = c.execute('select * from test_database').fetchall()
```
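Because nothing is committed for you, a write typically ends with an explicit commit. A minimal sketch (server, database, and table names are hypothetical); both pyodbc and sqlite3 cursors expose their owning connection as `.connection`:

```
from pggm_datalab_utils.db import cursor

with cursor('my_server.net', 'test') as c:
    c.execute('delete from staging_trades where processed = ?', (1,))
    c.connection.commit()  # without this, uncommitted work is rolled back when the connection closes
```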
def get_all(c: Union[pyodbc.Cursor, sqlite3.Cursor], table_name: str) -> List[Dict[str, Any]]
Get all current data from table `table_name`.

IMPORTANT WARNING: `table_name` is not sanitized. Don't pass untrusted table names to this function!
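For instance, fetching a whole table as records (names hypothetical):

```
from pggm_datalab_utils.db import cursor, get_all

with cursor('my_server.net', 'test') as c:
    positions = get_all(c, 'positions')  # e.g. [{'id': 1, 'isin': 'NL0000009165'}, ...]
```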
def get_db_connection(server: str, database: str) -> Union[pyodbc.Connection, sqlite3.Connection]
Initiate a pyodbc connection to a SQL Server database. This function is intended to be suitable for cloud development: in the cloud you use environment variables to log in with a username and password, whereas locally you simply log in using your AAD credentials.

You may set `DB_DRIVER` to a pyodbc-compatible driver name for your system. It defaults to `{ODBC Driver 17 for SQL Server}`. If you set it to the value `SQLITE`, this routine will use the built-in sqlite3 library to connect instead.

By default, this connection will try to log into the database with the user account it's executing under. This is compatible with AAD login and suitable for local development. If you want to log into a database from another environment, you will have to use Windows credentials: save the username in the environment variable `DB_UID` and the password in `DB_PASSWORD`.
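The connection modes in practice; a minimal sketch where the server name and service account are placeholders. The SQLite branch runs with only the standard library, which makes it handy for local tests:

```
import os
from pggm_datalab_utils.db import get_db_connection

# Local testing: driver SQLITE ignores the server argument entirely.
os.environ['DB_DRIVER'] = 'SQLITE'
conn = get_db_connection('ignored', ':memory:')

# Cloud-style login with explicit credentials (placeholder values):
# os.environ['DB_DRIVER'] = '{ODBC Driver 17 for SQL Server}'
# os.environ['DB_UID'] = 'svc_datalab'
# os.environ['DB_PASSWORD'] = '<secret>'
# conn = get_db_connection('my_server.net', 'test')
```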
def insert_with_return(c: Union[pyodbc.Cursor, sqlite3.Cursor], table_name: str, data: Dict[str, Any], return_columns: Optional[Union[str, tuple]] = None) -> Dict[str, Any]
Insert data into the database, returning a set of return columns. The primary use for this is if you have columns generated by your database, like an identity. Returns the input record with the returned columns added (if any).

IMPORTANT WARNING: `table_name` is not sanitized. Don't pass untrusted table names to this function!
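For example, reading back a database-generated identity (table and column names are hypothetical):

```
from pggm_datalab_utils.db import cursor, insert_with_return

with cursor('my_server.net', 'test') as c:
    record = insert_with_return(
        c, 'orders', {'isin': 'NL0000009165', 'amount': 100.0}, return_columns='order_id')
    c.connection.commit()
    print(record['order_id'])  # value generated by the database
```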
def query(c: Union[pyodbc.Cursor, sqlite3.Cursor], sql: str, data: Optional[tuple] = None) -> List[Dict[str, Any]]
Call `c.execute(sql, data).fetchall()` and format the resulting rowset as a list of records of the form `[{colname: value}]`.
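Parameters are passed positionally through `?` placeholders (the qmark paramstyle, which pyodbc and sqlite3 both use). Names below are hypothetical:

```
from pggm_datalab_utils.db import cursor, query

with cursor('my_server.net', 'test') as c:
    large = query(c, 'select id, amount from orders where amount > ?', (100.0,))
    # e.g. [{'id': 1, 'amount': 250.0}, {'id': 7, 'amount': 410.5}]
```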
def validate(data: List[Dict[str, Any]])
Check that `data` is a uniform list of records: every record must have the same keys, and all keys must be strings. Raises `AssertionError` otherwise. `write` and `insert_with_return` call this on their input before building SQL.
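A quick illustration of the two checks:

```
from pggm_datalab_utils.db import validate

validate([{'id': 1, 'name': 'a'}, {'id': 2, 'name': 'b'}])  # passes silently
validate([{'id': 1}, {'name': 'b'}])  # AssertionError: differing keys
```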
def write(c: Union[pyodbc.Cursor, sqlite3.Cursor], table_name: str, data: List[Dict[str, Any]], primary_key: Optional[Union[str, tuple]] = None, *, update=True, insert=True, delete=True)
Update data in a database table. Row identity is determined by `primary_key`.

`update`, `insert`, and `delete` control which actions to take. By default, this function emits the update, insert, and delete queries needed to make the database table equal to the in-memory table.

- `update=True` means rows already in the database will be updated with the in-memory data
- `insert=True` means rows not already in the database will be added from the in-memory data
- `delete=True` means rows present in the database but not in the in-memory data will be deleted

If `primary_key` is None, only inserting is supported.

IMPORTANT WARNING: `table_name` is not sanitized. Don't pass untrusted table names to this function!
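A minimal sketch of the default full-sync behaviour (names hypothetical); pass `delete=False` when the in-memory list is only a partial view of the table:

```
from pggm_datalab_utils.db import cursor, write

rows = [
    {'trade_id': 1, 'isin': 'NL0000009165', 'amount': 100.0},
    {'trade_id': 2, 'isin': 'US0378331005', 'amount': 250.0},
]
with cursor('my_server.net', 'test') as c:
    # Updates matching keys, inserts new ones, deletes everything else.
    write(c, 'trades', rows, primary_key='trade_id')
    c.connection.commit()
```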