Source code for myscaledb.client

import asyncio
import functools
import json as json_
import logging
import sys
import warnings
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from types import TracebackType
from typing import Any, AsyncGenerator, Dict, List, Optional, Type

import aiohttp
from aiobotocore.session import get_session

from myscaledb.exceptions import ClientError, GetObjectException
from myscaledb.http_clients.abc import HttpClientABC
from myscaledb.records import FromJsonFabric, Record, RecordsFabric
from myscaledb.sql import sqlparse

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.WARN)

# Optional cython extension:
try:
    from myscaledb._types import rows2ch, json2ch, py2ch, list2ch, ObjectToFetch
except ImportError:
    from myscaledb.types import rows2ch, json2ch, py2ch, list2ch, ObjectToFetch


def force_async(fn):
    """
    Turn a sync function into an async function by running it in a thread pool.
    """
    pool = ThreadPoolExecutor()

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        future = pool.submit(fn, *args, **kwargs)
        return asyncio.wrap_future(future)  # make it awaitable

    return wrapper
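

# Illustrative sketch, not part of the original module: force_async lets a
# blocking helper be awaited from a coroutine by running it in the decorator's
# thread pool. The file name below is hypothetical.
async def _example_force_async_usage():
    @force_async
    def _read_file(path):
        with open(path, "rb") as f:
            return f.read()

    # The wrapped call returns a future, so the event loop is not blocked while
    # the file is read in a worker thread.
    return await _read_file("rows.csv")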


def force_sync(fn):
    """
    Turn an async function into a sync function by running it on an event loop.
    """

    def handle_exception(loop, context):
        # context["message"] will always be there; but context["exception"] may not
        msg = context.get("exception", context["message"])
        logger.error("Caught exception in async method: %s", msg)
        raise ClientError(msg)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        res = fn(*args, **kwargs)
        if asyncio.iscoroutine(res):
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            loop.set_exception_handler(handle_exception)
            return loop.run_until_complete(res)
        return res

    return wrapper
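

# Illustrative sketch, not part of the original module: force_sync is the same
# mechanism the Client below uses to expose blocking variants of its coroutine
# methods; any coroutine can be driven this way from plain synchronous code.
def _example_force_sync_usage():
    @force_sync
    async def _select_one(url):
        # Hypothetical probe against a ClickHouse HTTP endpoint.
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params={"query": "SELECT 1"}) as resp:
                return await resp.text()

    # No await needed: the decorator runs the coroutine to completion on a loop.
    return _select_one("http://localhost:8123/")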


def streamFile(file_name, offset):
    """Read up to ``max_batch_size`` lines from ``file_name`` starting at ``offset``.

    Returns ``(data, finished, new_offset)`` so the caller can feed the new
    offset back in for the next batch.
    """
    batch = []
    max_batch_size = 100000
    line = 0
    rows = ""

    csvfile = open(file_name, newline='')
    csvfile.seek(offset)
    if line == 1:  # VALUES mode: wrap each row in parentheses and join with commas
        for i in range(max_batch_size):
            row = csvfile.readline().rstrip('\r\n').replace('\"', '\'').encode()
            if len(row) != 0:
                row = b''.join([b'(', row, b')'])
                batch.append(row)
            else:
                batch_str = b",".join(row for row in batch)
                csvfile.close()
                return batch_str, True, offset

    else:  # raw mode: return the next batch of lines unchanged
        for i in range(max_batch_size):
            row = csvfile.readline()
            rows += row
        new_offset = csvfile.tell()
        csvfile.close()
        return rows, False, new_offset

    new_offset = csvfile.tell()
    csvfile.close()
    batch_str = b",".join(row for row in batch)
    return batch_str, False, new_offset
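

# Illustrative sketch, not part of the original module: a hypothetical caller
# walks a file in batches by feeding the returned offset back into streamFile.
def _example_stream_file_usage(file_name="rows.csv"):
    offset = 0
    batches = []
    while True:
        data, finished, offset = streamFile(file_name, offset)
        if finished or not data:
            break
        batches.append(data)  # e.g. hand each batch to the HTTP client
    return batches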


class QueryTypes(Enum):
    FETCH = 0
    INSERT = 1
    OTHER = 2


class Client:
    """Client connection class.

    Usage:

    .. code-block:: python

        async with aiohttp.ClientSession() as s:
            client = Client(s, compress_response=True)
            nums = await client.fetch("SELECT number FROM system.numbers LIMIT 100")

    :param aiohttp.ClientSession session: aiohttp client session. Please use one
        session and one Client for all connections in your app.
    :param str url: Clickhouse server url. Use the full path, like "http://localhost:8123/".
    :param str user: User name for authorization.
    :param str password: Password for authorization.
    :param str database: Database name.
    :param bool compress_response: Pass True if you want Clickhouse to compress its
        responses with gzip. They will be decompressed automatically, but overall it
        will be slightly slower.
    :param **settings: Any settings from https://clickhouse.yandex/docs/en/operations/settings
    """

    __slots__ = (
        "_session",
        "url",
        "params",
        "_json",
        "_http_client",
        "stream_batch_size",
        "connection_map",
        "aws_session_map",
    )

    @force_sync
    async def generate_http_session(self):
        session = aiohttp.ClientSession()
        return session

    def __init__(
        self,
        session=None,
        url: str = "http://localhost:8123/",
        user: str = None,
        password: str = None,
        database: str = "default",
        compress_response: bool = False,
        stream_batch_size: int = 1000000,
        json=json_,  # type: ignore
        **settings,
    ):
        if session:
            _http_client = HttpClientABC.choose_http_client(session)
            self._http_client = _http_client(session)
        self.url = url
        self.params = {}
        if user:
            self.params["user"] = user
        if password:
            self.params["password"] = password
        if database:
            self.params["database"] = database
        if compress_response:
            self.params["enable_http_compression"] = 1
        self._json = json
        self.params.update(settings)
        self.stream_batch_size = stream_batch_size
        self.connection_map = {}

    async def __aenter__(self) -> 'Client':
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.close_async()

    async def close_async(self) -> None:
        """Close the session."""
        await self._http_client.close()

    def close(self) -> None:
        @force_sync
        async def sync_run():
            await self.close_async()

        return sync_run()

    async def is_alive_async(self) -> bool:
        """Check if the connection is OK.

        Usage:

        .. code-block:: python

            assert await client.is_alive()

        :return: True if the connection is OK, False otherwise.
        """
        try:
            await self._http_client.get(
                url=self.url, params={**self.params, "query": "SELECT 1"}
            )
        except ClientError:
            return False
        return True

    def is_alive(self) -> bool:
        @force_sync
        async def sync_run():
            async with aiohttp.ClientSession() as session:
                _http_client = HttpClientABC.choose_http_client(session)
                self._http_client = _http_client(session)
                return await self.is_alive_async()

        return sync_run()

    @staticmethod
    def _prepare_query_params(params: Optional[Dict[str, Any]] = None):
        if params is None:
            return {}
        if not isinstance(params, dict):
            raise TypeError('Query params must be a Dict[str, Any]')
        prepared_query_params = {}
        for key, value in params.items():
            prepared_query_params[key] = py2ch(value).decode('utf-8')
        return prepared_query_params

    async def _execute(
        self,
        query: str,
        *args,
        json: bool = False,
        query_params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
        decode: bool = True,
    ) -> AsyncGenerator[Record, None]:
        query_params = self._prepare_query_params(query_params)
        if query_params:
            query = query.format(**query_params)
        need_fetch, is_json, is_csv, is_get_object, statement_type = self._parse_squery(
            query
        )

        if not is_json and json:
            query += " FORMAT JSONEachRow"
            is_json = True
        if not is_json and need_fetch:
            query += " FORMAT TSVWithNamesAndTypes"

        if args:
            if statement_type != 'INSERT':
                raise ClientError(
                    "It is possible to pass arguments only for INSERT queries"
                )
            params = {**self.params, "query": query}
            if is_json:
                data = json2ch(*args, dumps=self._json.dumps)
            elif is_csv:
                # we'll fill the data incrementally from the file
                if len(args) > 1:
                    raise ClientError("only one argument is accepted in file read mode")
                data = []
            elif isinstance(args[0], list):
                data = list2ch(args[0])
            else:
                data = rows2ch(*args)
        else:
            params = {**self.params}
            data = query.encode()

        if query_id is not None:
            params["query_id"] = query_id

        if is_csv:
            sent = False
            rows_read = 0
            retry = 0
            max_batch_size = self.stream_batch_size
            csvfile = open(args[0], newline='')
            while True:
                rows = "".join(csvfile.readlines(max_batch_size))
                if len(rows) == 0:
                    csvfile.close()
                    break
                rows_read += max_batch_size
                while not sent:
                    if retry >= 3:
                        logger.error("pipe broke too many times, exiting")
                        sys.exit(1)
                    try:
                        await self._http_client.post_no_return(
                            url=self.url, params=params, data=rows
                        )
                        sent = True
                    except aiohttp.ClientOSError as e:
                        if e.errno == 32:
                            logger.warning("broken pipe, retrying")
                            retry += 1
                        else:
                            raise e
                retry = 0
                sent = False
        elif is_get_object:
            response = self._http_client.post_return_lines(
                url=self.url, params=params, data=data
            )
            rf = RecordsFabric(
                names=await response.__anext__(),
                tps=await response.__anext__(),
                convert=decode,
            )
            async for line in response:
                yield rf.new(line)
        elif need_fetch:
            response = self._http_client.post_return_lines(
                url=self.url, params=params, data=data
            )
            if is_json:
                rf = FromJsonFabric(loads=self._json.loads)
                async for line in response:
                    yield rf.new(line)
            else:
                rf = RecordsFabric(
                    names=await response.__anext__(),
                    tps=await response.__anext__(),
                    convert=decode,
                )
                async for line in response:
                    yield rf.new(line)
        else:
            await self._http_client.post_no_return(
                url=self.url, params=params, data=data
            )
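
    # Illustrative sketch, not part of the original module: when the statement is
    # an INSERT with FORMAT CSV, _execute treats the single positional argument as
    # a file name and streams it in chunks bounded by stream_batch_size, retrying
    # up to three times on broken pipes. "t" and "rows.csv" below are hypothetical.
    #
    #     await client.execute_async("INSERT INTO t FORMAT CSV", "rows.csv")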

    async def execute_async(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
    ) -> None:
        """Execute query. Returns None.

        :param str query: Clickhouse query string.
        :param args: Arguments for insert queries.
        :param bool json: Execute query in JSONEachRow mode.
        :param Optional[Dict[str, Any]] params: Params to escape inside query string.
        :param str query_id: Clickhouse query_id.

        Usage:

        .. code-block:: python

            await client.execute(
                "CREATE TABLE t (a UInt8, b Tuple(Date, Nullable(Float32))) ENGINE = Memory"
            )
            await client.execute(
                "INSERT INTO t VALUES",
                (1, (dt.date(2018, 9, 7), None)),
                (2, (dt.date(2018, 9, 8), 3.14)),
            )
            await client.execute(
                "INSERT INTO {table_name} VALUES",
                (1, (dt.date(2018, 9, 7), None)),
                (2, (dt.date(2018, 9, 8), 3.14)),
                params={"table_name": "t"}
            )

        :return: Nothing.
        """
        async for _ in self._execute(
            query, *args, json=json, query_params=params, query_id=query_id
        ):
            return None

    def execute(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
    ):
        @force_sync
        async def sync_run(
            query: str,
            *args,
            json: bool = False,
            params: Optional[Dict[str, Any]] = None,
            query_id: str = None,
        ):
            async with aiohttp.ClientSession() as session:
                _http_client = HttpClientABC.choose_http_client(session)
                self._http_client = _http_client(session)
                await self.execute_async(
                    query, *args, json=json, params=params, query_id=query_id
                )

        sync_run(query, *args, json=json, params=params, query_id=query_id)

    async def fetch_async(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
        decode: bool = True,
    ) -> List[Record]:
        """Execute query and fetch all rows from the query result at once in a list.

        :param query: Clickhouse query string.
        :param bool json: Execute query in JSONEachRow mode.
        :param Optional[Dict[str, Any]] params: Params to escape inside query string.
        :param str query_id: Clickhouse query_id.
        :param decode: Decode to python types. If False, returns bytes for each field instead.

        Usage:

        .. code-block:: python

            all_rows = await client.fetch("SELECT * FROM t")

        :return: All rows from query.
        """
        return [
            row
            async for row in self._execute(
                query,
                *args,
                json=json,
                query_params=params,
                query_id=query_id,
                decode=decode,
            )
        ]

    def fetch(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
        decode: bool = True,
    ) -> List[Record]:
        @force_sync
        async def sync_run() -> List[Record]:
            async with aiohttp.ClientSession() as session:
                _http_client = HttpClientABC.choose_http_client(session)
                self._http_client = _http_client(session)
                return await self.fetch_async(
                    query,
                    *args,
                    json=json,
                    params=params,
                    query_id=query_id,
                    decode=decode,
                )

        return sync_run()

    async def fetchrow(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
        decode: bool = True,
    ) -> Optional[Record]:
        """Execute query and fetch the first row from the query result, or None.

        :param query: Clickhouse query string.
        :param bool json: Execute query in JSONEachRow mode.
        :param Optional[Dict[str, Any]] params: Params to escape inside query string.
        :param str query_id: Clickhouse query_id.
        :param decode: Decode to python types. If False, returns bytes for each field instead.

        Usage:

        .. code-block:: python

            row = await client.fetchrow("SELECT * FROM t WHERE a=1")
            assert row[0] == 1
            assert row["b"] == (dt.date(2018, 9, 7), None)

        :return: First row from query, or None if there are no results.
        """
        async for row in self._execute(
            query,
            *args,
            json=json,
            query_params=params,
            query_id=query_id,
            decode=decode,
        ):
            return row
        return None

    async def fetchone(self, query: str, *args) -> Optional[Record]:
        """Deprecated. Use ``fetchrow`` method instead."""
        warnings.warn(
            "'fetchone' method is deprecated. Use 'fetchrow' method instead",
            PendingDeprecationWarning,
        )
        return await self.fetchrow(query, *args)

    async def fetchval(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
        decode: bool = True,
    ) -> Any:
        """Execute query and fetch the first value of the first row from the query result, or None.

        :param query: Clickhouse query string.
        :param bool json: Execute query in JSONEachRow mode.
        :param Optional[Dict[str, Any]] params: Params to escape inside query string.
        :param str query_id: Clickhouse query_id.
        :param decode: Decode to python types. If False, returns bytes for each field instead.

        Usage:

        .. code-block:: python

            val = await client.fetchval("SELECT b FROM t WHERE a=2")
            assert val == (dt.date(2018, 9, 8), 3.14)

        :return: First value of the first row, or None if there are no results.
        """
        async for row in self._execute(
            query,
            *args,
            json=json,
            query_params=params,
            query_id=query_id,
            decode=decode,
        ):
            if row:
                return row[0]
        return None

    async def iterate(
        self,
        query: str,
        *args,
        json: bool = False,
        params: Optional[Dict[str, Any]] = None,
        query_id: str = None,
        decode: bool = True,
    ) -> AsyncGenerator[Record, None]:
        """Async generator over all rows from the query result.

        :param str query: Clickhouse query string.
        :param bool json: Execute query in JSONEachRow mode.
        :param Optional[Dict[str, Any]] params: Params to escape inside query string.
        :param str query_id: Clickhouse query_id.
        :param decode: Decode to python types. If False, returns bytes for each field instead.

        Usage:

        .. code-block:: python

            async for row in client.iterate(
                "SELECT number, number*2 FROM system.numbers LIMIT 10000"
            ):
                assert row[0] * 2 == row[1]

            async for row in client.iterate(
                "SELECT number, number*2 FROM system.numbers LIMIT {numbers_limit}",
                params={"numbers_limit": 10000}
            ):
                assert row[0] * 2 == row[1]

        :return: Rows one by one.
        """
        async for row in self._execute(
            query,
            *args,
            json=json,
            query_params=params,
            query_id=query_id,
            decode=decode,
        ):
            yield row

    async def get_objects_async(self, records: List[Record]):
        """Process each row and retrieve the binary data it references."""
        if len(records) == 0:
            return
        logger.info("total number of files to download: %d", len(records))
        temp = records[0]
        get_object_columns = []
        for i in range(len(temp)):
            if isinstance(temp[i], ObjectToFetch):
                get_object_columns.append(i)
            # if keys[i].startswith('getObject(') and isinstance(temp[i], tuple):
            #     get_object_name = keys[i]
            #     break
        if not get_object_columns:
            return

        # asyncio is annoying, had to do two loops to bypass the "async with" statement
        for get_object_column in get_object_columns:
            id_list = {}
            boto_clients = {}
            for i in range(len(records)):
                row = records[i]
                credential_string = str(row[get_object_column][2])
                credential_strings = credential_string.split('&')
                credential = {}
                for c in credential_strings:
                    credential[c.split('=')[0]] = c.split('=')[1]
                if boto_clients.get(credential_string) is None:
                    boto_clients[credential_string] = credential
                    id_list[credential_string] = []
                id_list[credential_string].append(i)

            for credential_string in boto_clients.keys():
                retries = 3
                element_to_get = True
                ToDoList = id_list[credential_string]
                credential = boto_clients[credential_string]
                async with get_session().create_client(
                    's3',
                    aws_access_key_id=credential["AccessKeyId"],
                    aws_secret_access_key=credential["SecretAccessKey"],
                    aws_session_token=credential["SessionToken"],
                ) as client:
                    while retries > 0 and element_to_get:
                        tasks = []
                        for i in ToDoList:
                            tasks.append(
                                asyncio.create_task(
                                    self._getObject(
                                        records[i], get_object_column, client, i
                                    )
                                )
                            )
                        await asyncio.gather(*tasks, return_exceptions=True)
                        ToDoList = []
                        for one_task in tasks:
                            if isinstance(one_task.exception(), GetObjectException):
                                ToDoList.append(one_task.exception().index)
                        if len(ToDoList) != 0:
                            retries -= 1
                            logger.warning(
                                "%d elements need to retry get object", len(ToDoList)
                            )
                        else:
                            element_to_get = False
                    if element_to_get:
                        raise ClientError(tasks[0].exception())

    def get_objects(self, records: List[Record]):
        @force_sync
        async def sync_run():
            await self.get_objects_async(records)

        return sync_run()
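
    # Illustrative sketch, not part of the original module: after fetching rows
    # from a query that uses getObject, the ObjectToFetch columns can be resolved
    # in place by downloading the referenced S3 objects. The query variable below
    # is hypothetical.
    #
    #     rows = await client.fetch_async(query_with_get_object)
    #     await client.get_objects_async(rows)  # downloads and decodes in place
    #     # or, from synchronous code:
    #     client.get_objects(rows)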

    async def cursor(self, query: str, *args) -> AsyncGenerator[Record, None]:
        """Deprecated. Use ``iterate`` method instead."""
        warnings.warn(
            "'cursor' method is deprecated. Use 'iterate' method instead",
            PendingDeprecationWarning,
        )
        async for row in self.iterate(query, *args):
            yield row

    @staticmethod
    def _parse_squery(query):
        statement = sqlparse.parse(query)[0]
        statement_type = statement.get_type()
        if statement_type in ('SELECT', 'SHOW', 'DESCRIBE', 'EXISTS'):
            need_fetch = True
        else:
            need_fetch = False

        fmt = statement.token_matching(
            (lambda tk: tk.match(sqlparse.tokens.Keyword, 'FORMAT'),), 0
        )
        if fmt:
            is_json = statement.token_matching(
                (lambda tk: tk.match(None, ['JSONEachRow']),),
                statement.token_index(fmt) + 1,
            )
        else:
            is_json = False

        fmt2 = statement.token_matching(
            (lambda tk: tk.match(sqlparse.tokens.Keyword, 'FORMAT'),), 0
        )
        if fmt2:
            is_csv = statement.token_matching(
                (lambda tk: tk.match(None, ['CSV']),),
                statement.token_index(fmt2) + 1,
            )
        else:
            is_csv = False

        is_get_object = False
        fmt3 = statement.token_matching(
            (lambda tk: tk.match(sqlparse.tokens.Keyword, 'getObject'),), 0
        )
        if fmt3:
            is_get_object = True

        return need_fetch, is_json, is_csv, is_get_object, statement_type

    async def _getObject(self, row: Record, get_object_column: int, client, index: int):
        try:
            # sample url: https://region-1/myurlbucket/layer1/object1
            # sample url2: s3://myurlbucket/layer1/object1
            object_url = row[get_object_column][1]
            if object_url.find("http") != -1:
                object_urls = object_url.split('//')[1].split('/')
                bucket_name = object_urls[1]
                key = '/'.join(object_urls[2:]).lstrip('/')
            else:
                object_urls = object_url.split('//')[1].split('/')
                bucket_name = object_urls[0]
                key = '/'.join(object_urls[1:]).lstrip('/')
            obj = await client.get_object(Bucket=bucket_name, Key=key)
            body = await obj['Body'].read()
            row.decode_again(get_object_column, body)
        except Exception as e:
            raise GetObjectException(e, index=index)
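

# Illustrative sketch, not part of the original module: the synchronous wrappers
# (fetch, execute, is_alive, ...) create their own aiohttp session, so the client
# can be used without an event loop. Assumes a ClickHouse server on localhost;
# the table name "t" is hypothetical.
def _example_sync_client_usage():
    client = Client(url="http://localhost:8123/", user="default")
    if not client.is_alive():
        return []
    client.execute("CREATE TABLE IF NOT EXISTS t (a UInt8) ENGINE = Memory")
    client.execute("INSERT INTO t VALUES", (1,))
    return client.fetch("SELECT a FROM t")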