Source code for ayx_python_sdk.providers.amp_provider.data_transport.amp_transport

# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arrows stream formatted transport interface."""
import logging
from collections import defaultdict
from typing import Generator, Optional, TYPE_CHECKING, Tuple

from ayx_python_sdk.providers.amp_provider.data_transport.transport_base import (
    TransportBase,
)
from ayx_python_sdk.providers.amp_provider.grpc_helpers.record_transfer_msgs import (
    new_record_transfer_out_msg,
)

import pyarrow as pa

if TYPE_CHECKING:
    from ayx_python_sdk.providers.amp_provider.resources.generated.sdk_tool_service_v2_pb2 import (  # noqa: F401
        RecordTransferIn,
        RecordTransferOut,
    )

# Size of stream chunks, in bytes.
# The desired 1 MB default is causing some failures, most likely in the
# slicing logic, which needs better bounds checks so Arrow doesn't error.
# Deferred to a follow-up ticket once higher-priority work is done.
DEFAULT_CHUNK_SIZE = 1 * (10 ** 6)


class AmpStreamTransport(TransportBase):
    """Transport interface to deliver data in Arrow stream format."""

    logger = logging.getLogger()

    def __init__(self) -> None:
        self.stream_buffer: dict = defaultdict(dict)
        self.chunk_size = DEFAULT_CHUNK_SIZE
        self.incomplete_streams = 0

    # TODO: Rename this to something better for the Transport interface.
    def send_record(
        self,
        record_batch: "pa.RecordBatch",
        anchor_name: str,
        connection_name: str = "",
    ) -> Generator["RecordTransferOut", None, None]:
        """Convert a given record batch into a sequence of RecordTransferOut msgs."""
        stream = self.get_rec_stream(record_batch)
        for chunk, end_of_chunk in self.iter_stream(stream):
            record_out = {
                "anchor_name": anchor_name,
                "data": chunk,
                "connection_name": connection_name,
                "end_of_chunk": end_of_chunk,
            }
            yield new_record_transfer_out_msg(record_out, "outgoing_records")
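
    # Usage sketch (hypothetical names; `grpc_stream` stands in for whatever
    # RecordTransferOut sink the SDK service holds):
    #
    #     transport = AmpStreamTransport()
    #     batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
    #     for msg in transport.send_record(batch, anchor_name="Output"):
    #         grpc_stream.write(msg)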
    def get_rec_stream(self, record_batch: "pa.RecordBatch") -> "pa.Buffer":
        """Get an Arrow-formatted stream buffer."""
        try:
            sink = pa.BufferOutputStream()
            writer = pa.ipc.new_stream(sink, record_batch.schema)
            writer.write(record_batch)
            writer.close()
        except Exception as e:
            self.logger.error(repr(e))
            raise e
        return sink.getvalue()
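
    # Round-trip sketch: the returned buffer is a complete Arrow IPC stream,
    # so it can be read back directly with pyarrow:
    #
    #     buf = transport.get_rec_stream(batch)
    #     table = pa.ipc.open_stream(buf).read_all()
    #     assert table.num_rows == batch.num_rows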
    def iter_stream(
        self, stream: "pa.Buffer"
    ) -> Generator[Tuple[bytes, bool], None, None]:
        """Break the given stream into chunks of at most `chunk_size` bytes."""
        buf = pa.BufferReader(stream)
        end_of_chunk = False
        while not end_of_chunk:
            # The final chunk is whatever remains once fewer than
            # `chunk_size` bytes are left in the reader.
            end_of_chunk = buf.tell() + self.chunk_size >= buf.size()
            chunk = buf.read() if end_of_chunk else buf.read(self.chunk_size)
            yield (chunk, end_of_chunk)
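
    # Chunking sketch: with the default 1 MB chunk size, a 2.5 MB buffer
    # yields three chunks, and only the last carries end_of_chunk=True:
    #
    #     stream = pa.py_buffer(b"\x00" * 2_500_000)
    #     sizes = [(len(c), eoc) for c, eoc in transport.iter_stream(stream)]
    #     # sizes == [(1_000_000, False), (1_000_000, False), (500_000, True)]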
    def _unpack_and_append_rec_in_payload(self, record_in: "RecordTransferIn") -> None:
        # Append the incoming chunk to the buffer for its (anchor, connection)
        # pair, creating the buffer on the first chunk of a new stream.
        anchor_name = record_in.incoming_records.anchor_name
        connection_name = record_in.incoming_records.connection_name
        data = record_in.incoming_records.data
        if not self.stream_buffer[anchor_name].get(connection_name):
            self.incomplete_streams += 1
            self.stream_buffer[anchor_name][connection_name] = pa.BufferOutputStream()
        self.stream_buffer[anchor_name][connection_name].write(data)

    def _handle_completed_stream(self, record_in: "RecordTransferIn") -> "pa.Table":
        # Read the fully buffered stream into a table and reset the buffer.
        self.logger.debug("Stream terminator received, completing batch.")
        anchor_name = record_in.incoming_records.anchor_name
        connection_name = record_in.incoming_records.connection_name
        data = self.stream_buffer[anchor_name][connection_name]
        try:
            reader = pa.ipc.open_stream(data.getvalue())
            self.stream_buffer[anchor_name][connection_name] = pa.BufferOutputStream()
            self.incomplete_streams -= 1
            return reader.read_all()
        except pa.ArrowException as ae:
            self.logger.error(
                "Exception caught reading completed batch chunk: %s", repr(ae)
            )
            raise ae
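
    # Buffer layout sketch: chunks accumulate per (anchor, connection) pair
    # until a message with end_of_chunk=True arrives, e.g.
    #
    #     transport.stream_buffer == {
    #         "Input": {"Connection1": <pa.BufferOutputStream>},
    #     }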
    def receive_record(self, record_in_msg: "RecordTransferIn") -> Optional["pa.Table"]:
        """
        Receive RecordTransferIn messages.

        Returns None if the data received is only a chunk of a full record.
        Returns pa.Table if `end_of_chunk` is True.
        """
        self._unpack_and_append_rec_in_payload(record_in_msg)
        if record_in_msg.incoming_records.end_of_chunk:
            return self._handle_completed_stream(record_in_msg)
        return None
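
# A minimal end-to-end sketch (illustrative only; not part of the original
# module). It assumes the generated sdk_tool_service_v2_pb2 module is
# importable and exercises the chunked send/receive round trip in-process.
if __name__ == "__main__":
    from ayx_python_sdk.providers.amp_provider.resources.generated import (
        sdk_tool_service_v2_pb2 as pb2,
    )

    transport = AmpStreamTransport()
    batch = pa.record_batch([pa.array(range(10))], names=["n"])

    # Re-wrap each outgoing chunk as an incoming message, as a peer would.
    table = None
    for chunk, end_of_chunk in transport.iter_stream(
        transport.get_rec_stream(batch)
    ):
        msg = pb2.RecordTransferIn()
        msg.incoming_records.anchor_name = "Input"
        msg.incoming_records.connection_name = "Connection1"
        msg.incoming_records.data = chunk
        msg.incoming_records.end_of_chunk = end_of_chunk
        table = transport.receive_record(msg)

    assert table is not None and table.num_rows == batch.num_rows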