# Source code for ayx_python_sdk.core.testing

# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convenience functions, classes, and fixtures for user-generated unit testing."""
import asyncio
from collections import namedtuple
from dataclasses import dataclass
from typing import Dict, List, Optional, Type

from ayx_python_sdk.core import Anchor, PluginV2
from ayx_python_sdk.providers.amp_provider import AMPProviderV2
from ayx_python_sdk.providers.amp_provider.sdk_tool_service_v2 import SdkToolServiceV2

import pyarrow as pa


@dataclass
class _IncomingConnection:
    """A single mocked inbound connection on an incoming anchor.

    Shaped to mirror the connection records the AMP provider reads from
    the plugin-initialization payload (see ``_PluginInitializationData``).
    """

    # Connection name; the test harness assigns stringified indices.
    name: str
    # Arrow schema of the incoming data, or None when no metadata is mocked.
    # Fixed: annotate with pa.Schema (the class) — pa.schema is the factory
    # function, not a type.
    metadata: Optional[pa.Schema]


@dataclass
class _IncomingAnchor:
    """An incoming anchor and its mocked connections, for provider initialization."""

    # Anchor name (a key of the input_anchor_config dict in SdkToolTestService).
    name: str
    # Mocked connections; the harness builds exactly one per anchor.
    connections: List[_IncomingConnection]


@dataclass
class _OutgoingAnchor:
    """An outgoing anchor record, for provider initialization.

    Shaped to mirror the outgoing-anchor records the AMP provider reads
    from the plugin-initialization payload.
    """

    # Anchor name (a key of the output_anchor_config dict in SdkToolTestService).
    name: str
    # Number of downstream connections; the harness always uses 1.
    num_connections: int
    # Arrow schema of the outgoing data, or None when no metadata is mocked.
    # Fixed: annotate with pa.Schema (the class) — pa.schema is the factory
    # function, not a type.
    metadata: Optional[pa.Schema]


@dataclass
class _PluginInitializationData:
    """Payload mimicking the plugin-initialization message fed to provider.set_anchors.

    Field names are camelCase (hence the noqa markers) — presumably to match
    the wire/message field names the AMP provider expects; verify against
    AMPProviderV2.set_anchors.
    """

    configXml: str  # noqa  # raw tool-configuration XML
    incomingAnchors: List[_IncomingAnchor]  # noqa
    outgoingAnchors: List[_OutgoingAnchor]  # noqa


@dataclass
class _DummyCtrlIn:
    """Stand-in for a control-in message: carries only the initialization data."""

    plugin_initialization_data: _PluginInitializationData


# Pairs a batch of input data with the output a test expects it to produce.
BatchTuple = namedtuple("BatchTuple", "input_data expected_output_data")


class SdkToolTestService:
    """A minimal in-process client for testing user plugins."""

    # Labels for ctrl_out message_type values; the numeric message_type is
    # used to index this list when formatting io_stream lines.
    MAPPING = ["NONE", "INFO", "WARN", "ERROR", "FILEINPUT", "FILEOUTPUT"]

    def __init__(
        self,
        plugin_class: Type[PluginV2],
        input_anchor_config: Dict[str, Optional[pa.Schema]],
        output_anchor_config: Dict[str, Optional[pa.Schema]],
        config_mock: str,
    ):
        """
        Build the service/provider pair and instantiate the plugin under test.

        Parameters
        ----------
        plugin_class
            The user plugin class (not an instance) to drive.
        input_anchor_config
            Maps each input anchor name to its Arrow schema (or None).
        output_anchor_config
            Maps each output anchor name to its Arrow schema (or None).
        config_mock
            Raw tool-configuration XML handed to the provider environment.
        """
        self.service = SdkToolServiceV2()
        provider = AMPProviderV2()
        # Build one mocked connection per input anchor.
        # NOTE(review): the connection name str(i + 1) is derived from the
        # anchor's position in the dict, not from a per-anchor connection
        # count — confirm this matches real AMP initialization.
        incoming_anchors = [
            _IncomingAnchor(
                name=key,
                connections=[
                    _IncomingConnection(
                        name=str(i + 1), metadata=input_anchor_config.get(key)
                    )
                ],
            )
            for i, key in enumerate(input_anchor_config.keys())
        ]
        outgoing_anchors = [
            _OutgoingAnchor(
                name=key, num_connections=1, metadata=output_anchor_config.get(key)
            )
            for key in output_anchor_config.keys()
        ]
        # TODO: generate/pass config_mock in as a dict/json, convert to xml here
        init_data = _DummyCtrlIn(
            _PluginInitializationData(
                configXml=config_mock,
                incomingAnchors=incoming_anchors,
                outgoingAnchors=outgoing_anchors,
            )
        )
        provider.set_anchors(init_data)
        provider.environment.set_tool_config(config_mock)
        self.plugin = plugin_class(provider)
        self.service.driver.plugin = self.plugin
        # Captured outputs; overwritten each time _flush_and_save_streams runs.
        self._data_streams: Dict[str, List[pa.RecordBatch]] = {}
        self._io_stream: List[str] = []
        # TODO: auto-generate plugin initialization from configXML

    @property
    def io_stream(self) -> List[str]:
        """
        Store any strings sent over through provider.io.

        This is roughly the format the resulting strings follow:
        "(INFO|WARN|ERROR):(.*)"

        For example, this provider.io call in the plugin code:
            self.provider.io.info("Test Code")
        would show up as "INFO:Test Code" in self.io_stream.

        Returns
        -------
        List[str]
            All strings generated by provider.io calls within the plugin code.
        """
        return self._io_stream

    @property
    def data_streams(self) -> Dict[str, List[pa.RecordBatch]]:
        """
        Capture any data written via provider.io.write_to_anchor().

        The completed stream is represented by a dictionary, typed like so:
        {
            "<Output Anchor name>": [pa.RecordBatch]
        }

        If provider.io.write_to_anchor is never called, the data_streams
        attribute should be an empty dictionary.

        Returns
        -------
        Dict[str, List[pa.RecordBatch]]
            A mapping between output anchor names and the corresponding list
            of RecordBatches written to that output anchor.
        """
        return self._data_streams

    def run_on_record_batch(self, batch: pa.RecordBatch, anchor: Anchor) -> None:
        """Prepare and run the plugin's on_record_batch method, and capture the I/O and data streams."""
        # Flush before AND after so the captured snapshots reflect this call.
        self._flush_and_save_streams()
        self.service.driver.record_batch_received(batch, anchor)
        self._flush_and_save_streams()

    def run_on_incoming_connection_complete(self, anchor: Anchor) -> None:
        """Prepare and run the plugin's on_incoming_connection_complete method, and capture the I/O and data streams."""
        self._flush_and_save_streams()
        self.service.driver.incoming_connection_complete_callback(anchor)
        self._flush_and_save_streams()

    def run_on_complete(self) -> None:
        """Prepare and run the plugin's on_complete method, and capture the I/O and data streams."""
        self._flush_and_save_streams()
        self.service.driver.on_complete_callback()
        self._flush_and_save_streams()

    def _flush_and_save_streams(self) -> None:
        # Drain the driver's control and record queues, then replace the
        # cached _io_stream / _data_streams snapshots with what was drained.
        outputs: Dict[str, List[pa.Table]] = {}
        io_outputs: List[str] = []
        # Control messages (provider.io info/warn/error) -> "LEVEL:text" lines.
        while not self.service.driver.ctrl_io.ctrl_out.empty():
            item = asyncio.get_event_loop().run_until_complete(
                self.service.driver.ctrl_io.ctrl_out.get()
            )
            io_outputs.append(
                f"{self.MAPPING[item.output_message.message_type]}:{item.output_message.message}"
            )
        # Record writes (write_to_anchor) -> per-anchor lists of Tables.
        while not self.service.driver.record_io.pending_writes.empty():
            item = asyncio.get_event_loop().run_until_complete(
                self.service.driver.record_io.pending_writes.get()
            )
            if item["anchor_name"] not in outputs.keys():
                outputs[item["anchor_name"]] = []
            outputs[item["anchor_name"]].append(
                pa.Table.from_pydict(item["data"].to_pydict())
            )
        # Concatenate each anchor's tables and re-split into record batches.
        self._data_streams = {
            key: pa.concat_tables(value).to_batches()
            for key, value in outputs.items()
        }
        self._io_stream = io_outputs