# Source code for ayx_python_sdk.core.testing
# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convenience functions, classes, and fixtures for user-generated unit testing."""
import asyncio
from collections import namedtuple
from dataclasses import dataclass
from typing import Dict, List, Optional, Type
from ayx_python_sdk.core import Anchor, PluginV2
from ayx_python_sdk.providers.amp_provider import AMPProviderV2
from ayx_python_sdk.providers.amp_provider.sdk_tool_service_v2 import SdkToolServiceV2
import pyarrow as pa
@dataclass
class _IncomingConnection:
    """
    A single incoming connection on an input anchor.

    ``metadata`` is the pyarrow schema of the incoming record stream, or
    ``None`` when the connection is configured without metadata.
    """

    name: str
    # NOTE: the annotation is pa.Schema (the type); the original used
    # pa.schema, which is the schema *factory function*, not a class.
    metadata: Optional[pa.Schema]
@dataclass
class _IncomingAnchor:
    """
    An input anchor and the incoming connections attached to it.

    Mirrors the anchor shape consumed by ``AMPProviderV2.set_anchors``.
    """

    name: str
    connections: List[_IncomingConnection]
@dataclass
class _OutgoingAnchor:
    """
    An output anchor description: its name, connection count, and schema.

    ``metadata`` is the pyarrow schema of records written to this anchor,
    or ``None`` when no schema is configured.
    """

    name: str
    num_connections: int
    # NOTE: the annotation is pa.Schema (the type); the original used
    # pa.schema, which is the schema *factory function*, not a class.
    metadata: Optional[pa.Schema]
@dataclass
class _PluginInitializationData:
    """
    Anchor/configuration payload handed to the provider at plugin start-up.

    Field names are camelCase — presumably to mirror the wire-format message
    fields (hence the ``noqa`` markers); do not rename them.
    """

    configXml: str  # noqa
    incomingAnchors: List[_IncomingAnchor]  # noqa
    outgoingAnchors: List[_OutgoingAnchor]  # noqa
@dataclass
class _DummyCtrlIn:
    """Stand-in for the control-channel "in" message consumed by set_anchors."""

    plugin_initialization_data: _PluginInitializationData
# Pairs a batch of input data with the output a plugin is expected to produce.
# The typename matches the bound name (the original said "Batchtuple"), so
# repr() reads correctly and instances can be pickled by qualified name.
BatchTuple = namedtuple("BatchTuple", ["input_data", "expected_output_data"])
class SdkToolTestService:
    """A minimal in-process client for testing user plugins."""

    # Human-readable prefixes for captured I/O strings. The numeric
    # item.output_message.message_type is used as an index into this list.
    MAPPING = ["NONE", "INFO", "WARN", "ERROR", "FILEINPUT", "FILEOUTPUT"]

    def __init__(
        self,
        plugin_class: Type[PluginV2],
        input_anchor_config: Dict[str, Optional[pa.Schema]],
        output_anchor_config: Dict[str, Optional[pa.Schema]],
        config_mock: str,
    ):
        """
        Wire ``plugin_class`` into an in-process SDK service for testing.

        Parameters
        ----------
        plugin_class
            The user plugin class under test; it is instantiated with the
            test provider built here.
        input_anchor_config
            Maps each input anchor name to its (optional) pyarrow schema.
            Every input anchor is given exactly one incoming connection.
        output_anchor_config
            Maps each output anchor name to its (optional) pyarrow schema.
        config_mock
            Raw tool-configuration XML handed to the provider environment.
        """
        self.service = SdkToolServiceV2()
        provider = AMPProviderV2()
        # One connection per input anchor. Connection names follow the
        # anchor enumeration order: anchor i's connection is str(i + 1).
        incoming_anchors = [
            _IncomingAnchor(
                name=anchor_name,
                connections=[
                    _IncomingConnection(
                        name=str(idx + 1),
                        metadata=input_anchor_config.get(anchor_name),
                    )
                ],
            )
            for idx, anchor_name in enumerate(input_anchor_config)
        ]
        outgoing_anchors = [
            _OutgoingAnchor(
                name=anchor_name,
                num_connections=1,
                metadata=output_anchor_config.get(anchor_name),
            )
            for anchor_name in output_anchor_config
        ]
        # TODO: generate/pass config_mock in as a dict/json, convert to xml here
        init_data = _DummyCtrlIn(
            _PluginInitializationData(
                configXml=config_mock,
                incomingAnchors=incoming_anchors,
                outgoingAnchors=outgoing_anchors,
            )
        )
        provider.set_anchors(init_data)
        provider.environment.set_tool_config(config_mock)
        self.plugin = plugin_class(provider)
        self.service.driver.plugin = self.plugin
        self._data_streams: Dict[str, List[pa.RecordBatch]] = {}
        self._io_stream: List[str] = []
        # TODO: auto-generate plugin initialization from configXML

    @property
    def io_stream(self) -> List[str]:
        """
        Store any strings sent over through provider.io.

        This is roughly the format the resulting strings follow:
            "(INFO|WARN|ERROR):(.*)"

        For example, this provider.io call in the plugin code:
            self.provider.io.info("Test Code")
        would show up as "INFO:Test Code" in self.io_stream.

        Returns
        -------
        List[str]
            All strings generated by provider.io calls within the plugin code.
        """
        return self._io_stream

    @property
    def data_streams(self) -> Dict[str, List[pa.RecordBatch]]:
        """
        Capture any data written via provider.io.write_to_anchor().

        The completed stream is represented by a dictionary, typed like so:
            {
                "<Output Anchor name>": [pa.RecordBatch]
            }

        If provider.io.write_to_anchor is never called, the data_streams
        attribute should be an empty dictionary.

        Returns
        -------
        Dict[str, List[pa.RecordBatch]]
            A mapping between output anchor names and the corresponding list
            of RecordBatches written to that output anchor.
        """
        return self._data_streams

    def run_on_record_batch(self, batch: pa.RecordBatch, anchor: Anchor) -> None:
        """Prepare and run the plugin's on_record_batch method, and capture the I/O and data streams."""
        # Pre-flush drains anything queued by earlier calls so the post-flush
        # captures only output produced by this callback.
        self._flush_and_save_streams()
        self.service.driver.record_batch_received(batch, anchor)
        self._flush_and_save_streams()

    def run_on_incoming_connection_complete(self, anchor: Anchor) -> None:
        """Prepare and run the plugin's on_incoming_connection_complete method, and capture the I/O and data streams."""
        self._flush_and_save_streams()
        self.service.driver.incoming_connection_complete_callback(anchor)
        self._flush_and_save_streams()

    def run_on_complete(self) -> None:
        """Prepare and run the plugin's on_complete method, and capture the I/O and data streams."""
        self._flush_and_save_streams()
        self.service.driver.on_complete_callback()
        self._flush_and_save_streams()

    def _flush_and_save_streams(self) -> None:
        """Drain the driver's outgoing queues into io_stream / data_streams.

        Both captures are *replaced*, not appended to, on every call.
        """
        outputs: Dict[str, List[pa.Table]] = {}
        io_outputs: List[str] = []
        # NOTE(review): the driver's queues are drained with get_nowait()
        # rather than get_event_loop().run_until_complete(...): each item is
        # known to be present (empty() was just checked), the queues appear to
        # be asyncio.Queues, and this avoids touching the event-loop policy —
        # get_event_loop() outside a running loop is deprecated in 3.10+.
        ctrl_out = self.service.driver.ctrl_io.ctrl_out
        while not ctrl_out.empty():
            item = ctrl_out.get_nowait()
            io_outputs.append(
                f"{self.MAPPING[item.output_message.message_type]}:{item.output_message.message}"
            )
        pending_writes = self.service.driver.record_io.pending_writes
        while not pending_writes.empty():
            item = pending_writes.get_nowait()
            # Group written tables by anchor; round-trip through to_pydict to
            # decouple the captured table from the writer's buffers.
            outputs.setdefault(item["anchor_name"], []).append(
                pa.Table.from_pydict(item["data"].to_pydict())
            )
        self._data_streams = {
            anchor_name: pa.concat_tables(tables).to_batches()
            for anchor_name, tables in outputs.items()
        }
        self._io_stream = io_outputs