# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for running a plugin out of process."""
import asyncio
import os
import sys
import time
from enum import Enum
from pathlib import Path
from typing import AsyncIterable, Callable, List, Optional, TYPE_CHECKING, Tuple
from ayx_python_sdk.providers.amp_provider import (
AMPInputAnchor,
AMPInputConnection,
AMPOutputAnchor,
)
from ayx_python_sdk.providers.amp_provider.builders import OutputAnchorBuilder
from ayx_python_sdk.providers.amp_provider.builders.input_anchor_builder import (
InputAnchorBuilder,
)
from ayx_python_sdk.providers.amp_provider.data_transport.amp_transport import (
AmpStreamTransport,
)
from ayx_python_sdk.providers.amp_provider.resources.generated.plugin_initialization_data_pb2 import (
UpdateMode,
)
from ayx_python_sdk.providers.amp_provider.resources.generated.sdk_tool_service_pb2_grpc import (
SdkToolStub,
)
from ayx_python_sdk.providers.amp_provider.resources.generated.sdk_tool_service_v2_pb2 import (
ControlIn,
RecordTransferIn,
RecordTransferOut,
)
from ayx_python_sdk.providers.amp_provider.resources.generated.sdk_tool_service_v2_pb2_grpc import (
SdkToolV2Stub,
)
from ayx_python_sdk.test_harness.process_lifecycle_manager import (
ProcessLifecycleManager,
)
import grpc
import pyarrow as pa
import typer
import xmltodict
# These names are only needed by static type checkers (they appear in quoted
# annotations below), so they are imported under the TYPE_CHECKING guard and
# never loaded at runtime.
if TYPE_CHECKING:
    from ayx_python_sdk.core import Metadata  # noqa: F401
    import pandas as pd  # noqa: F401
class RunMode(str, Enum):
    """Run mode types.

    Members are also ``str`` instances, so they compare equal to their raw
    string values (for example when supplied from a command line).
    """

    # Handshake only: PluginRunner.run_plugin_v2 starts the plugin process and
    # connects to it, but does not start the record/control runtime.
    update_only = "update_only"
    # Handshake followed by the full record/control runtime.
    full_run = "full"
def _handle_sdk_tool_service_exception(method: Callable) -> Callable:
def _log_grpc_errors(*args, **kwargs): # type: ignore
try:
return method(*args, **kwargs)
except grpc.RpcError as e:
typer.echo("SDK Tool Service failed.")
typer.echo(f"gRPC Status Code: {e.code()}")
typer.echo(f"gRPC Details (Stacktrace):\n{e.details()}")
raise typer.Exit(code=1)
return _log_grpc_errors
def _parse_config_xml(
    xml_path: Path,
) -> Tuple[List[AMPInputAnchor], List[AMPOutputAnchor]]:
    """Build the input and output anchors declared in a tool's config XML.

    The file is parsed with ``xmltodict`` and the ``Connection`` entries under
    ``AlteryxJavaScriptPlugin/GuiSettings/InputConnections`` and
    ``.../OutputConnections`` are turned into anchors.

    Parameters
    ----------
    xml_path
        Path of the tool configuration XML file.

    Returns
    -------
    Tuple[List[AMPInputAnchor], List[AMPOutputAnchor]]
        Input anchors and output anchors, each in document order. A section
        that is missing or has no ``Connection`` children yields an empty list.

    Raises
    ------
    KeyError
        If the ``AlteryxJavaScriptPlugin``/``GuiSettings`` elements or a
        connection's ``Name`` attribute are missing.
    """

    def _connection_entries(section):  # type: ignore
        """Return the section's ``Connection`` entries as a list."""
        # An empty element (e.g. ``<InputConnections />``) parses to None.
        if not section:
            return []
        entries = section.get("Connection")
        if entries is None:
            return []
        # xmltodict yields a single mapping for one child and a list for many.
        return entries if isinstance(entries, list) else [entries]

    def _flag(entry, attribute):  # type: ignore
        """Interpret an XML attribute such as ``"True"``/``"False"`` as a bool."""
        # Attribute values arrive as strings, and the string "False" is truthy,
        # so they are converted explicitly. A missing attribute counts as False.
        # NOTE(review): assumes the anchor constructors expect real booleans
        # for allow_multiple/optional - confirm against their signatures.
        return str(entry.get(attribute, "False")).strip().lower() == "true"

    def _build(anchor_type, section):  # type: ignore
        """Create one ``anchor_type`` instance per connection entry."""
        return [
            anchor_type(
                name=entry["@Name"],
                allow_multiple=_flag(entry, "@AllowMultiple"),
                optional=_flag(entry, "@Optional"),
            )
            for entry in _connection_entries(section)
        ]

    typer.echo(f"{xml_path}")
    with open(xml_path) as config_xml:
        config_dict = xmltodict.parse(config_xml.read())

    gui_settings = config_dict["AlteryxJavaScriptPlugin"]["GuiSettings"]
    input_anchors = _build(AMPInputAnchor, gui_settings.get("InputConnections"))
    output_anchors = _build(AMPOutputAnchor, gui_settings.get("OutputConnections"))
    return input_anchors, output_anchors
class PluginRunner:
    """Class for running a plugin out of process with test data.

    The runner launches the plugin's SDK tool service as a child process,
    connects to it over an insecure gRPC channel and, in full-run mode, drives
    the ``Control`` and ``RecordTransfer`` streams with the supplied inputs.
    """

    # Address passed to the plugin process via --sdk-engine-server-address and
    # dialled by the gRPC channel; both must agree.
    _SDK_ENGINE_SERVER_ADDRESS = "localhost:6500"

    def __init__(
        self,
        plugin_entrypoint: Path,
        plugins_package: str,
        tool_name: str,
        input_metadata: List["Metadata"],
        input_data: List["pd.DataFrame"],
        config_xml: Path,
        transport_type: str,
    ) -> None:
        """Construct the plugin runner.

        Parameters
        ----------
        plugin_entrypoint
            Script that is launched with the current Python interpreter.
        plugins_package
            Forwarded as an argument to ``start-sdk-tool-service``.
        tool_name
            Forwarded as an argument to ``start-sdk-tool-service``.
        input_metadata
            Metadata for each input anchor, by anchor position.
        input_data
            Data frame for each input anchor, by anchor position.
        config_xml
            Tool configuration XML from which the anchors are read.
        transport_type
            Stored and echoed; not otherwise used by this class.

        Raises
        ------
        ValueError
            If the number of datasets and metadata entries differ, or there
            are more input anchors than metadata entries.
        """
        # Kept for backward compatibility; the rest of the class works with
        # ``channel`` and ``sdk_client_stub``.
        self._channel = None
        self._sdk_tool_client: Optional[SdkToolStub] = None
        # Initialised here so the ``is None`` guard in _handle_plugin_runtime
        # works even when run_plugin_v2 has not set them.
        self.channel: Optional[grpc.aio.Channel] = None
        self.sdk_client_stub: Optional[SdkToolV2Stub] = None

        self._plugin_entrypoint = plugin_entrypoint
        self._plugins_package = plugins_package
        self._tool_name = tool_name
        self._input_metadata = input_metadata
        self._input_data = input_data
        self._transport_type = transport_type
        self._input_anchors, self._output_anchors = _parse_config_xml(config_xml)
        self._control_write_queue: asyncio.Queue[ControlIn] = asyncio.Queue()
        self._control_read_queue: list = []
        self._record_write_queue: asyncio.Queue[RecordTransferIn] = asyncio.Queue()
        self._received_records: List[RecordTransferOut] = []
        self.stream_transport = AmpStreamTransport()
        typer.echo(f"{transport_type}")

        # Validate before indexing into the metadata list below, so a mismatch
        # produces this message rather than an IndexError. Surplus datasets
        # beyond the number of anchors are tolerated because they are never
        # read.
        if len(self._input_data) != len(self._input_metadata) or len(
            self._input_anchors
        ) > len(self._input_metadata):
            typer.echo(
                f"{len(self._input_data)}\n{self._input_anchors}\n{len(self._input_metadata)}"
            )
            raise ValueError(
                "Input datasets, input metadata, and input anchors must match."
            )

        # Attach one connection, carrying that anchor's metadata, per anchor.
        for index, anchor in enumerate(self._input_anchors):
            anchor.connections.append(
                AMPInputConnection(
                    self._connection_name(index),
                    metadata=self._input_metadata[index],
                    anchor=anchor,
                )
            )

    @staticmethod
    def _connection_name(anchor_index: int) -> str:
        """Return the name of the connection created for the given anchor index."""
        return "Connection" + str(anchor_index)

    async def run_plugin_v2(self, mode: RunMode) -> None:
        """Run the plugin out of process.

        Starts the plugin's SDK tool service, waits for the gRPC channel to
        become ready and, when *mode* is ``RunMode.full_run``, streams the
        input data through the plugin. The channel is closed before the plugin
        process context is exited, whether or not the run succeeded.

        Parameters
        ----------
        mode
            ``RunMode.full_run`` to run the record/control runtime after the
            handshake; any other mode stops after the handshake.
        """
        with ProcessLifecycleManager(
            [
                sys.executable,
                str(self._plugin_entrypoint.resolve()),
                "start-sdk-tool-service",
                self._plugins_package,
                self._tool_name,
                "--sdk-engine-server-address",
                self._SDK_ENGINE_SERVER_ADDRESS,
            ]
        ) as plugin_process:
            channel = None
            try:
                try:
                    channel = grpc.aio.insecure_channel(
                        self._SDK_ENGINE_SERVER_ADDRESS,
                        options=[
                            # -1 removes the message size limits.
                            ("grpc.max_send_message_length", -1),
                            ("grpc.max_receive_message_length", -1),
                        ],
                    )
                    self.channel = channel
                    await self._wait_for_handshake_v2(plugin_process)
                    self.sdk_client_stub = SdkToolV2Stub(self.channel)
                except Exception:
                    typer.echo("Error while running handshake")
                    raise
                # TODO: should return a status and raise if failed
                if mode == RunMode.full_run:
                    await self._handle_plugin_runtime()
            finally:
                # Release the channel's resources even when the run failed.
                if channel is not None:
                    await channel.close()

    async def _wait_for_handshake_v2(
        self, plugin_process: ProcessLifecycleManager, timeout: float = 30.0
    ) -> None:
        """Wait for the initialization handshake to complete.

        Parameters
        ----------
        plugin_process
            Lifecycle manager of the launched plugin process.
        timeout
            Maximum number of seconds to wait for the channel to become ready.

        Raises
        ------
        typer.Exit
            With code 1 if the channel is not ready within *timeout* or the
            plugin process is no longer alive.
        """
        try:
            await asyncio.wait_for(self.channel.channel_ready(), timeout)
        except asyncio.TimeoutError as err:
            # wait_for raises on timeout, so the failure is reported here
            # rather than by comparing elapsed time afterwards.
            if not plugin_process.process_alive():
                typer.echo(
                    "ERROR: Plugin process died before handshake completed with error."
                )
            else:
                typer.echo("ERROR: Handshake didn't complete within timeout.")
            raise typer.Exit(code=1) from err
        typer.echo("Channel ready")
        if not plugin_process.process_alive():
            typer.echo(
                "ERROR: Plugin process died before handshake completed with error."
            )
            raise typer.Exit(code=1)

    @_handle_sdk_tool_service_exception
    def _initialize_plugin_v2_msg(self) -> ControlIn:
        """Build the control message that initializes the plugin.

        Returns
        -------
        ControlIn
            Message whose ``plugin_initialization_data`` carries an empty
            configuration, the input and output anchors, placeholder engine
            constants and ``UpdateMode.UM_Run``.
        """
        dummy_plugin_data = ControlIn()
        init_data = dummy_plugin_data.plugin_initialization_data
        init_data.configXml = "<Configuration />"
        init_data.incomingAnchors.extend(
            [
                InputAnchorBuilder.to_protobuf(input_anchor)
                for input_anchor in self._input_anchors
            ]
        )
        init_data.outgoingAnchors.extend(
            [
                OutputAnchorBuilder.to_protobuf(output_anchor)
                for output_anchor in self._output_anchors
            ]
        )
        # Placeholder values: every path-like constant is the current working
        # directory.
        working_directory = os.getcwd()
        engine_constants = {
            "Engine.TempFilePath": working_directory,
            "Engine.WorkflowDirectory": working_directory,
            "Engine.Version": "0.0.0.0",
            "AlteryxExecutable": working_directory,
            "ProxyConfiguration": "ProxyRequiresCredentials=false\nProxyCommonUserName=\nProxyCommonPassword=\n",
        }
        for k, v in engine_constants.items():
            init_data.engineConstants[k] = v
        init_data.updateMode = UpdateMode.UM_Run
        return dummy_plugin_data

    def _push_input_data_to_queue(self) -> None:
        """Queue every input dataset as ``RecordTransferIn`` messages.

        Each anchor's data frame is converted to an Arrow record batch and
        split into chunks by the stream transport; one message per chunk is
        put on the record write queue.

        Raises
        ------
        Exception
            Any failure is echoed and re-raised so the caller can abort the
            run instead of continuing with partial input.
        """
        try:
            for num, anchor in enumerate(self._input_anchors):
                # Use the name given to this anchor's connection in __init__.
                # Indexing ``anchor.connections`` by the anchor position fails
                # for every anchor after the first, since each gets a single
                # connection.
                connection_name = self._connection_name(num)
                batch = pa.RecordBatch.from_pandas(self._input_data[num])
                stream = self.stream_transport.get_rec_stream(batch)
                try:
                    for chunk, end_of_chunk in self.stream_transport.iter_stream(
                        stream
                    ):
                        msg = RecordTransferIn()
                        msg.incoming_records.anchor_name = anchor.name
                        msg.incoming_records.data = chunk
                        msg.incoming_records.end_of_chunk = end_of_chunk
                        msg.incoming_records.connection_name = connection_name
                        self._record_write_queue.put_nowait(msg)
                except StopIteration:
                    pass
        except Exception as e:
            typer.echo(f"Exception during push input data {repr(e)}")
            raise

    async def record_transfer_producer(
        self, event_plugin_ready: "asyncio.Event",
    ) -> AsyncIterable[RecordTransferIn]:
        """Send records to the plugin, notify client is done sending record data.

        Waits for *event_plugin_ready*, yields every queued record message,
        then puts a notify-complete message on the control write queue.
        """
        # wait for plugin init event
        await event_plugin_ready.wait()
        typer.echo("Sending initial recordtransfers")
        while not self._record_write_queue.empty():
            yield await self._record_write_queue.get()
            self._record_write_queue.task_done()
        typer.echo("Sent initial transfers.")
        typer.echo("Notifying plugin to complete")
        self._control_write_queue.put_nowait(self._get_notify_complete_msg())

    async def control_producer(
        self,
        event_notify_complete: "asyncio.Event",
        event_plugin_ready: "asyncio.Event",
    ) -> AsyncIterable[ControlIn]:
        """Generate initial control messages, then wait for any new msgs to send from the queue.

        The initialization message is yielded first; afterwards queued control
        messages are yielded until *event_notify_complete* is set.
        *event_plugin_ready* is accepted but not currently awaited.
        """
        typer.echo("Sending init message")
        yield self._initialize_plugin_v2_msg()
        typer.echo("Init successful, starting control in loop")
        # Start regular producer loop.
        # NOTE(review): this polls the queue, yielding to the event loop with
        # sleep(0) on each pass, so it keeps the loop busy while idle.
        while not event_notify_complete.is_set():
            if not self._control_write_queue.empty():
                yield await self._control_write_queue.get()
                # send back any handled control_out messages
                self._control_write_queue.task_done()
            # let the client check on other streams
            await asyncio.sleep(0)

    @_handle_sdk_tool_service_exception
    async def record_transfer_consumer(
        self, resp_iterator: AsyncIterable[RecordTransferOut]
    ) -> None:
        """Receive and process RecordTransferOut messages from the server.

        Every ``outgoing_records`` payload is handed to the stream transport;
        whenever that returns a completed stream it is stored in the received
        records list. A failure while handling one chunk is echoed and the
        loop continues with the next response.
        """
        async for resp in resp_iterator:
            typer.echo("Response received on record consumer stream")
            if resp.WhichOneof("payload") == "outgoing_records":
                typer.echo("Consumer received record msg.")
                record_chunk = resp.outgoing_records
                try:
                    completed_stream = self.stream_transport.receive_record(
                        record_chunk
                    )
                    if completed_stream:
                        self._received_records.append(completed_stream)
                except Exception as e:
                    typer.echo(f"{repr(e)}")

    async def control_consumer(
        self,
        event_notify_complete: "asyncio.Event",
        event_plugin_ready: "asyncio.Event",
        resp_iterator: AsyncIterable["RecordTransferOut"],
    ) -> None:
        """Receives and handles Control msg responses from the server.

        Handles three payload kinds: ``translate_message`` (answered with a
        fake translation on the control write queue), ``confirm_complete``
        (sets *event_notify_complete*) and ``output_message`` (echoed; the
        first one containing "Plugin Initialized" sets *event_plugin_ready*).
        """
        # NOTE(review): the annotation says RecordTransferOut, but the payloads
        # handled here are control messages - confirm the intended type.
        async for resp in resp_iterator:
            oneof = resp.WhichOneof("payload")
            typer.echo(f"control consumer received oneof {oneof}")
            if oneof == "translate_message":
                # Return translated message, for now just flip it back
                msg = ControlIn()
                fake_translated = (
                    f"'Translated msg': {resp.translate_message.unlocalized_string}"
                )
                msg.translated_message.translated_message = fake_translated
                msg.msg_id = resp.msg_id
                typer.echo(
                    "Sending translated message response to control write queue..."
                )
                self._control_write_queue.put_nowait(msg)
            elif oneof == "confirm_complete":
                typer.echo("Received plugin complete confirmation from SDK tool")
                # Set plugin complete, asyncio.gather should let data streams finish work.
                event_notify_complete.set()
            elif oneof == "output_message":
                if (
                    not event_plugin_ready.is_set()
                    and "Plugin Initialized" in resp.output_message.message
                ):
                    typer.echo("Recieved SDK PLUGIN init confirmation")
                    event_plugin_ready.set()
                typer.echo(
                    f"Received control out from server. \n {resp.output_message.message}"
                )

    def _get_notify_complete_msg(self) -> "ControlIn":
        """Create notify complete msg."""
        ctrl_msg = ControlIn()
        # SetInParent marks the empty notify_complete payload as present.
        ctrl_msg.notify_complete.SetInParent()
        return ctrl_msg

    @_handle_sdk_tool_service_exception
    async def _handle_plugin_runtime(self) -> None:
        """Send/receive all of the record packets to/from the plugin.

        Queues the input data, opens the ``RecordTransfer`` and ``Control``
        streams with their producers, runs both consumers concurrently until
        they finish, then echoes the received records.

        Raises
        ------
        ValueError
            If the SDK client stub has not been set.
        """
        typer.echo("Starting plugin runtime")
        if self.sdk_client_stub is None:
            typer.echo("Stub not set")
            raise ValueError("SDK Tool Client must be set.")
        event_notify_plugin_complete = asyncio.Event()
        event_plugin_ready = asyncio.Event()

        typer.echo("Setting input data...")
        try:
            self._push_input_data_to_queue()
        except Exception:
            typer.echo("Exception during setting input data")
            raise

        typer.echo("Setting Producers")
        record_transfer_responses = self.sdk_client_stub.RecordTransfer(
            self.record_transfer_producer(event_plugin_ready)
        )
        control_responses = self.sdk_client_stub.Control(
            self.control_producer(event_notify_plugin_complete, event_plugin_ready)
        )

        # The consumers run concurrently so control messages from the server
        # are handled while records are still being transferred.
        typer.echo("Setting Consumers")
        _record_consumer = self.record_transfer_consumer(record_transfer_responses)
        _ctrl_consumer = self.control_consumer(
            event_notify_plugin_complete, event_plugin_ready, control_responses,
        )
        # TODO: Add task status returns
        typer.echo("Gathering...")
        tasks = await asyncio.gather(_record_consumer, _ctrl_consumer)
        typer.echo(tasks)
        typer.echo("Received records:")
        for rec in self._received_records:
            typer.echo(rec)