wallaroo.deployment

import asyncio
import json
import os
import pathlib
import sys
import time
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

import aiohttp
import gql  # type: ignore
import orjson
import pandas as pd
import pyarrow as pa  # type: ignore
import requests
from IPython.display import HTML, display  # type: ignore
from requests.adapters import HTTPAdapter, Retry

from .auth import _PlatformAuth
from .inference_result import InferenceResult
from .logs import LogEntries
from .model_config import ModelConfig
from .object import *
from .wallaroo_ml_ops_api_client.api.pipeline import pipelines_undeploy
from .wallaroo_ml_ops_api_client.models import pipelines_undeploy_json_body
from .wallaroo_ml_ops_api_client.types import UNSET

if TYPE_CHECKING:
    from .client import Client
    from .model import Model
    from .pipeline_variant import PipelineVariant

ARROW_ENABLED = "ARROW_ENABLED"
ARROW_HEADER = "application/vnd.apache.arrow.file"
PANDAS_RECORDS_HEADER = "application/json; format=pandas-records"
JSON_HEADER = "application/json"


class WaitForError(Exception):
    def __init__(self, message: str, status: Optional[Dict[str, Any]]):
        super().__init__(message)
        self.status = status


class WaitForDeployError(RuntimeError):
    def __init__(self, message: str):
        super().__init__(message)

    def _render_traceback_(self):
        display(
            HTML(
                "<strong>*** An error occurred while deploying your pipeline.</strong>"
            )
        )
        return [str(self)]


class Deployment(Object):
    def __init__(self, client: Optional["Client"], data: Dict[str, Any]) -> None:
        self.client = client
        assert client is not None
        # TODO: revisit session initialization during connection pooling work
        self.session = self._initialize_session()
        super().__init__(gql_client=client._gql_client, data=data)

    def _fill(self, data: Dict[str, Any]) -> None:
        """Fills an object given a response dictionary from the GraphQL API.

        Only the primary key member must be present; other members will be
        filled in via rehydration if their corresponding member function is
        called.
        """
        from .pipeline_variant import PipelineVariant  # avoids circular imports

        for required_attribute in ["id"]:
            if required_attribute not in data:
                raise RequiredAttributeMissing(
                    self.__class__.__name__, required_attribute
                )
        self._id = data["id"]

        self._name = value_if_present(data, "deploy_id")
        self._deployed = value_if_present(data, "deployed")
        self._model_configs = (
            [
                ModelConfig(self.client, elem["model_config"])
                for elem in data["deployment_model_configs"]
            ]
            if "deployment_model_configs" in data
            else DehydratedValue()
        )
        self._pipeline_variants = (
            [
                PipelineVariant(self.client, elem["pipeline_version"])
                for elem in data["deployment_pipeline_versions"]
            ]
            if "deployment_pipeline_versions" in data
            else DehydratedValue()
        )

        self._pipeline_id = value_if_present(data, "pipeline_id")

    def _fetch_attributes(self) -> Dict[str, Any]:
        """Fetches all member data from the GraphQL API."""
        return self._gql_client.execute(
            gql.gql(
                """
            query DeploymentById($deployment_id: bigint!) {
                deployment_by_pk(id: $deployment_id) {
                    id
                    deploy_id
                    deployed
                    deployment_model_configs {
                        model_config {
                            id
                        }
                    }
                    deployment_pipeline_versions(order_by: {pipeline_version: {id: desc}}) {
                        pipeline_version {
                            id
                        }
                    }
                }
            }
            """
            ),
            variable_values={
                "deployment_id": self._id,
            },
        )["deployment_by_pk"]

    def _initialize_session(self) -> requests.Session:
        # TODO: make session initialization configurable
        #  to be informed by connection polling reqs.
        #  includes sane defaults to match current retry time (~45s)
        session = requests.Session()
        retries = Retry(
            total=10,
            backoff_factor=0.1,
            status_forcelist=[503],
            allowed_methods=["GET", "POST"],
        )
        session.mount("http://", HTTPAdapter(max_retries=retries))
        return session

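    # Sketch: the retry policy above is equivalent to configuring a plain
    # requests session like this (values mirror the defaults above and are
    # not a public contract):
    #
    #     session = requests.Session()
    #     session.mount("http://", HTTPAdapter(max_retries=Retry(
    #         total=10, backoff_factor=0.1, status_forcelist=[503],
    #         allowed_methods=["GET", "POST"],
    #     )))
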
    def id(self) -> int:
        return self._id

    @rehydrate("_name")
    def name(self) -> str:
        return cast(str, self._name)

    @rehydrate("_deployed")
    def deployed(self) -> bool:
        return cast(bool, self._deployed)

    @rehydrate("_model_configs")
    def model_configs(self) -> List[ModelConfig]:
        return cast(List[ModelConfig], self._model_configs)

    @rehydrate("_pipeline_variants")
    def pipeline_variants(self) -> List["PipelineVariant"]:
        from .pipeline_variant import PipelineVariant  # avoids circular imports

        return cast(List[PipelineVariant], self._pipeline_variants)

    def deploy(self) -> "Deployment":
        """Deploys this deployment, if it is not already deployed.

        If the deployment is already deployed, this is a no-op.
        """
        q = gql.gql(
            """
        mutation Deploy($id: bigint!) {
            update_deployment_by_pk(pk_columns: {id: $id} _set: { deployed: true }) {
                id
                deploy_id
                deployed
            }
        }
        """
        )
        variables = {"id": self.id()}
        assert self.client is not None
        self.client._gql_client.execute(q, variable_values=variables)
        self._rehydrate()
        return self

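    # Usage sketch (illustrative; `deployment` is any Deployment handle):
    #
    #     deployment.deploy().wait_for_running()
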
    def undeploy(self) -> "Deployment":
        """Shuts down this deployment, if it is deployed.

        If the deployment is already undeployed, this is a no-op.
        """
        # TODO: Error handling.
        assert self.client is not None

        data = pipelines_undeploy.sync_detailed(
            client=self.client.mlops(),
            json_body=pipelines_undeploy_json_body.PipelinesUndeployJsonBody(
                UNSET, self.id()
            ),
        )

        if data.status_code != 200:
            err = data.content.decode("utf-8")
            raise Exception(f"Failed to undeploy. {err}")

        self._rehydrate()
        return self.wait_for_undeployed()

    def _get_auth(self):
        # TODO: Digging the auth object out of the gql_client is cheating
        return self.client._gql_client.transport.auth

    def status(self) -> Dict[str, Any]:
        """Returns a dict of deployment status useful for determining if a deployment has succeeded.

        :return: Dict of deployment internal state information.
        :rtype: Dict[str, Any]
        """

        assert self.client is not None

        status_url = f"{self.client.api_endpoint}/v1/api/status/get_deployment"
        params = {"name": f"{self.name()}-{self.id()}"}

        kind = ""
        resp = None
        try:
            resp = requests.post(
                status_url,
                timeout=5,
                auth=self._get_auth(),
                json=params,
            )
        except requests.ReadTimeout as ex:
            raise CommunicationError(f"rest-api connection to {status_url}") from ex
        except Exception:
            kind = "comm"

        if resp is not None and resp.status_code == 200:
            res = resp.json()
            if res is not None and res["status"] == "Running":
                # deployment is up; return the running status immediately
                return res

        details = ""
        if resp is not None:
            if resp.status_code == 200:
                return resp.json()

            if resp.status_code == 404:
                raise EntityNotFoundError(
                    "Deployment not found", {"name": self.name()}
                )

            details = f"\nStatus code: {resp.status_code}\nMessage: {resp.text}"

        if kind == "comm":
            raise CommunicationError(f"rest-api connection to {status_url}")

        raise RuntimeError(f"Unable to query deployment status {status_url}{details}")

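    # Status sketch: the fields this module itself relies on are "status"
    # (e.g. "Running" or "Error") and, on error, the "engines" and
    # "engine_lbs" lists (see wait_for_running below):
    #
    #     if deployment.status()["status"] == "Running":
    #         ...  # safe to send inference requests
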
    def check_limit_status(self):
        q = gql.gql(
            """
            query QueryLimitStatus($id: bigint!) {
                deployment(where: {id: {_eq: $id}}) {
                    id
                    deployed
                    limit_status
                }
            }
            """
        )

        variables = {"id": self.id()}
        assert self.client is not None
        res = self.client._gql_client.execute(q, variable_values=variables)[
            "deployment"
        ]
        if len(res) > 0:
            status = res[0]
            if "limit_status" in status:
                limit_status = status["limit_status"]
                if limit_status is not None:
                    raise LimitError(limit_status)

    def _wait_for(
        self,
        poll_fn: Callable[[], Tuple[bool, str, str]],
        task_name: str,
        on_iter: Callable[[int], None] = lambda ix: None,
        timeout: Optional[int] = None,
    ) -> "Deployment":
        """Wait for a generic task to finish before completing.

        Polls `poll_fn` once per second for up to `timeout` seconds; if no
        timeout is given, the "Client" object's configured timeout is used.

        :return: The deployment, for chaining.
        :rtype: Deployment
        """
        assert self.client is not None
        warning = False
        duration = timeout if timeout is not None else self.client.timeout
        message = "(none)"
        kind = "unset"

        start = time.monotonic()
        ix = 0
        while ix == 0 or time.monotonic() - start < duration:
            on_iter(ix)
            ix += 1

            res, message, kind = poll_fn()
            if res:
                if self.client._interactive:
                    sys.stdout.write(" ok\n")
                return self

            if self.client._interactive:
                if not warning:
                    sys.stdout.write(
                        f"Waiting for {task_name} - this will take up to {duration}s "
                    )
                    warning = True
                time.sleep(1)
                sys.stdout.write(".")
            else:
                time.sleep(1)

        if kind == "comm":
            raise CommunicationError(message)
        else:
            try:
                status: Optional[Dict[str, Any]] = self.status()
                message = f"{task_name.capitalize()} failed. See status for details."
            except Exception:
                message = f"{task_name.capitalize()} did not finish within {duration}s."
                status = None
            raise WaitForError(message, status)

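    # poll_fn contract sketch: return (done, message, kind); kind "comm"
    # marks a retryable connection problem rather than a task failure:
    #
    #     def poll() -> Tuple[bool, str, str]:
    #         try:
    #             return self.status()["status"] == "Running", "", ""
    #         except CommunicationError as ex:
    #             return False, str(ex), "comm"
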
    def wait_for_running(self, timeout: Optional[int] = None) -> "Deployment":
        """Waits for the deployment status to enter the "Running" state.

        Will wait up to "timeout_request" seconds for the deployment to enter that state.
        This is set in the "Client" object constructor. Will raise various exceptions on failures.

        :return: The deployment, for chaining.
        :rtype: Deployment
        """

        def check_limit(ix: int) -> None:
            # If this checks immediately, it will happen too soon for the deployment manager to
            # have cleared the limit_status column on the deployment and this will fail erroneously
            if ix > 5:
                self.check_limit_status()

        def check_for_running() -> Tuple[bool, str, str]:
            try:
                res = self.status()
                if res is not None and res["status"] == "Running":
                    return True, "", ""
                return False, "not running", "runtime"
            except CommunicationError as ex:
                # Connection may be coming up, try again
                return False, str(ex), "comm"
            except (EntityNotFoundError, RuntimeError) as ex:
                # Not found may switch to found, after a while. Retry it.
                return False, f"not found {ex}", "runtime"

        try:
            return self._wait_for(check_for_running, "deployment", check_limit, timeout)
        except WaitForError as ex:
            status = ex.status
            message = f"{str(ex)}\nStatus: {str(status)}"
            if status is not None and status.get("status") == "Error":
                quantity, resource = None, None
                engines = status.get("engines", [])
                engine_lbs = status.get("engine_lbs", [])
                required_cpu = next(
                    filter(
                        lambda item: item.get("status") == "Pending"
                        and item.get("required_cpu"),
                        engines + engine_lbs,
                    ),
                    cast(Dict[str, Any], {}),
                ).get("required_cpu")
                if required_cpu:
                    resource = "CPU"
                    quantity = (
                        "one CPU"
                        if required_cpu == "1"
                        else f"{required_cpu} units of CPU"
                    )
                else:
                    required_memory = next(
                        filter(
                            lambda item: item.get("status") == "Pending"
                            and item.get("required_memory"),
                            engines + engine_lbs,
                        ),
                        cast(Dict[str, Any], {}),
                    ).get("required_memory")
                    if required_memory:
                        resource = "memory"
                        quantity = f"{required_memory} of memory"

                if quantity is not None and resource is not None:
                    message = (
                        "Cannot deploy pipeline due to insufficient resources. "
                        f"Your pipeline needs {quantity} to run but there is not enough {resource} currently available. "
                        "Please try again or un-deploy pipelines not in use to adjust the resources that are available for your Wallaroo instance. "
                        "Contact your Wallaroo platform administrator for additional support."
                    )

            raise WaitForDeployError(message)

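    # Usage sketch: WaitForDeployError carries either the raw status or a
    # resource-pressure hint when engine pods are stuck Pending:
    #
    #     try:
    #         deployment.wait_for_running(timeout=90)
    #     except WaitForDeployError as err:
    #         print(err)
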
    def wait_for_undeployed(self) -> "Deployment":
        """Waits for the deployment to end.

        Will wait up to "timeout_request" seconds for the deployment to shut down.
        This is set in the "Client" object constructor. Will raise various exceptions on failures.

        :return: The deployment, for chaining.
        :rtype: Deployment
        """

        def check_for_undeployed() -> Tuple[bool, str, str]:
            try:
                self.status()
                return False, "still running", "runtime"
            except CommunicationError as ex:
                # Connection may be coming up, try again
                return False, str(ex), "comm"
            except RuntimeError as ex:
                # Not found may switch to found, after a while. Retry it.
                return False, f"not found {ex}", "runtime"
            except EntityNotFoundError:
                return True, "", ""

        try:
            return self._wait_for(check_for_undeployed, "undeployment")
        except WaitForError as ex:
            message = f"{str(ex)}\nStatus: {str(ex.status)}"
            raise RuntimeError(message)

    @staticmethod
    def _write_table_to_arrow_file(table: pa.Table, schema: pa.Schema):
        sink = pa.BufferOutputStream()
        with pa.ipc.new_file(sink, schema) as arrow_ipc:
            arrow_ipc.write(table)
            arrow_ipc.close()
        return sink.getvalue()

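    # Round-trip sketch of the Arrow IPC file format used on the wire
    # (plain pyarrow, independent of Wallaroo):
    #
    #     table = pa.table({"tensor": [[1.0, 2.0]]})
    #     buf = Deployment._write_table_to_arrow_file(table, table.schema)
    #     with pa.ipc.open_file(buf) as reader:
    #         assert reader.read_all().equals(table)
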
    def _infer_with_pandas(
        self,
        tensor: pd.DataFrame,
        params: Mapping[str, Union[str, Sequence[str]]],
        timeout: Optional[Union[int, float]],
    ) -> pd.DataFrame:
        input_records = tensor.to_json(orient="records")
        headers = {
            "Content-Type": PANDAS_RECORDS_HEADER,
            "Accept": PANDAS_RECORDS_HEADER,
        }
        res = self._make_infer_request(
            data=input_records,
            headers=headers,
            params=params,
            timeout=timeout,
        )
        try:
            data = res.json()
        except (json.JSONDecodeError, ValueError) as err:
            raise RuntimeError("Inference unable to complete.") from err
        return pd.DataFrame.from_records(data)

    def _infer_with_arrow(
        self,
        tensor: pa.Table,
        params: Mapping[str, Union[str, Sequence[str]]],
        timeout: Optional[Union[int, float]],
    ) -> pa.Table:
        input_arrow = self._write_table_to_arrow_file(tensor, tensor.schema)
        headers = {"Content-Type": ARROW_HEADER, "Accept": ARROW_HEADER}
        res = self._make_infer_request(
            data=input_arrow,
            headers=headers,
            params=params,
            timeout=timeout,
        )
        with pa.ipc.open_file(res.content) as reader:
            data_table = reader.read_all()
        return data_table

    def _make_infer_request(
        self,
        data: Union[Dict[str, Any], pd.DataFrame],
        headers: Dict[str, str],
        params: Mapping[str, Union[str, Sequence[str]]],
        timeout: Optional[Union[int, float]] = None,
    ) -> requests.Response:

        try:
            res = self.session.post(
                self._url(),
                data=data,
                timeout=timeout,
                # TODO: Digging the auth object out of the gql_client is cheating
                auth=self._get_auth(),
                params=params,
                headers=headers,
            )
        except (
            requests.exceptions.Timeout,
            requests.exceptions.ReadTimeout,
            requests.exceptions.RequestException,
        ):
            raise RuntimeError(
                f"Inference did not return within {timeout}s, adjust if necessary"
            )
        return res

    def infer(
        self,
        tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table],
        timeout: Optional[Union[int, float]] = None,
        dataset: Optional[Union[Sequence[str], str]] = None,
        dataset_exclude: Optional[Union[Sequence[str], str]] = None,
        dataset_separator: Optional[str] = None,
    ) -> Union[List[InferenceResult], pd.DataFrame, pa.Table]:
        """
        Returns an inference result on this deployment, given a tensor.

        :param tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table] Inference data.
            Must be a dictionary unless Arrow mode is enabled, in which case a
            pandas DataFrame or a pyarrow Table is required.
        :param timeout: Optional[Union[int, float]] infer requests will time out
            after the provided number of seconds is exceeded. Defaults to 15 seconds.
        :param dataset: Optional[Union[Sequence[str], str]] By default this is set to
            return ["time", "out"]. Other available options: "check_failures",
            "metadata".
        :param dataset_exclude: Optional[Union[Sequence[str], str]] If set, allows the
            user to exclude parts of the dataset.
        :param dataset_separator: Optional[str] If set to ".", the returned dataset
            will be flattened.
        :return: InferenceResult in dictionary, DataFrame, or Arrow format.
        """

        if timeout is None:
            timeout = 15
        if not isinstance(timeout, (int, float)):
            raise TypeError(
                f"timeout is {type(timeout)} but 'int' or 'float' is required"
            )
        assert self.client is not None

        if os.getenv(ARROW_ENABLED, "false").lower() == "true":
            params = dict()
            params["dataset[]"] = dataset or "*"
            params["dataset.exclude[]"] = dataset_exclude or "in"
            params["dataset.separator"] = dataset_separator or "."
            if not isinstance(tensor, (pd.DataFrame, pa.Table)):
                raise TypeError(
                    f"tensor is of type {type(tensor)} but 'pandas.DataFrame' or 'pyarrow.Table' is required"
                )
            if isinstance(tensor, pd.DataFrame):
                return self._infer_with_pandas(tensor, params, timeout)
            else:
                return self._infer_with_arrow(tensor, params, timeout)
        else:
            if not isinstance(tensor, dict):
                raise TypeError(f"tensor is {type(tensor)} but 'dict' is required")
            headers = {"Content-Type": JSON_HEADER}
            try:
                res = self.session.post(
                    self._url(),
                    json=tensor,
                    timeout=timeout,
                    # TODO: Digging the auth object out of the gql_client is cheating
                    auth=self._get_auth(),
                    headers=headers,
                )
            except (
                requests.exceptions.Timeout,
                requests.exceptions.ReadTimeout,
                requests.exceptions.RequestException,
            ):
                raise RuntimeError(
                    f"Inference did not return within {timeout}s, adjust if necessary"
                )
            try:
                data = res.json()
            except (json.JSONDecodeError, ValueError) as err:
                raise RuntimeError("Inference unable to complete.") from err
            return [InferenceResult(self._gql_client, d) for d in data]

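    # Usage sketch: the ARROW_ENABLED environment variable selects the wire
    # format (values illustrative):
    #
    #     os.environ[ARROW_ENABLED] = "true"
    #     df = deployment.infer(pd.DataFrame({"tensor": [[1.0, 2.0]]}))
    #
    #     os.environ[ARROW_ENABLED] = "false"
    #     results = deployment.infer({"tensor": [1.0, 2.0]})  # List[InferenceResult]
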
    def infer_from_file(
        self,
        filename: Union[str, pathlib.Path],
        timeout: Optional[Union[int, float]] = None,
        dataset: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
        dataset_separator: Optional[str] = None,
    ) -> Union[List[InferenceResult], pd.DataFrame, pa.Table]:
        if not isinstance(filename, pathlib.Path):
            filename = pathlib.Path(filename)
        if os.getenv(ARROW_ENABLED, "false").lower() == "true":
            if filename.suffix.lower() == ".arrow":
                with pa.ipc.open_file(filename) as source:
                    tensor = source.read_all()
            elif filename.suffix.lower() == ".json":
                tensor = pd.read_json(filename)
            else:
                raise TypeError(
                    f"File is of type {filename.suffix.lower()}, but only '.arrow' or '.json' are accepted"
                )
        else:
            with filename.open("rb") as f:
                tensor = json.load(f)
        return self.infer(tensor, timeout, dataset, exclude, dataset_separator)

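    # Usage sketch: the file suffix picks the input format in Arrow mode
    # (".arrow" -> pyarrow Table, ".json" -> pandas DataFrame); otherwise the
    # file must hold a JSON dict:
    #
    #     results = deployment.infer_from_file("tensors.json", timeout=30)
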
    async def batch_infer_from_file(
        self,
        filename: Union[str, pathlib.Path],
        data_key: str = "tensor",
        batch_size: int = 1000,
        connector_limit: int = 4,
    ) -> List[InferenceResult]:
        """Async method to run batched inference on a data file for a given deployment.

        :param str filename: path to an existing file with tensor data in JSON format.
        :param str data_key: key under which the tensor data is stored within the JSON. Defaults to "tensor".
        :param int batch_size: batch size to use when sending requests to the engine. Defaults to 1000.
        :param int connector_limit: limit on the number of TCP connections. Defaults to 4.
        :return: List of InferenceResults.
        :rtype: List[InferenceResult]
        """
        if not isinstance(filename, pathlib.Path):
            filename = pathlib.Path(filename)
        with filename.open("rb") as f:
            json_data = orjson.loads(f.read())

        input_data = json_data[data_key]
        chunked_data = self._generate_chunk_data(input_data, batch_size, data_key)

        assert self.client is not None
        url = self._url()
        auth = self._get_auth()
        connector = aiohttp.TCPConnector(limit=connector_limit)

        headers = {}
        if isinstance(auth, _PlatformAuth):
            headers = auth.auth_header()

        async with aiohttp.ClientSession(
            connector=connector,
            headers=headers,
            json_serialize=lambda x: orjson.dumps(x).decode(),
        ) as session:
            requests = []
            for chunk in chunked_data:
                requests.append(
                    asyncio.ensure_future(self._batch_infer(session, url, chunk))
                )

            resps = await asyncio.gather(*requests)
            return [InferenceResult(self._gql_client, resp) for resp in resps]

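    # Usage sketch: this is a coroutine, so drive it with asyncio (file layout
    # assumed to be {"tensor": [...]} per the docstring):
    #
    #     results = asyncio.run(
    #         deployment.batch_infer_from_file("tensors.json", batch_size=500)
    #     )
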
    def _chunk_data(self, data: Any, batch_size: int) -> Iterator[List[Any]]:
        for i in range(0, len(data), batch_size):
            yield data[i : i + batch_size]

    def _key_data(self, key: str, data: Any) -> Dict[str, Any]:
        return {key: data}

    def _generate_chunk_data(
        self, data: Dict[str, Any], batch_size: int, key: str
    ) -> Iterator[Dict[str, Any]]:
        chunked_data = self._chunk_data(data, batch_size)
        return (self._key_data(key, chunk) for chunk in chunked_data)

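    # Illustrative behavior of the chunk helpers:
    #
    #     list(self._generate_chunk_data([1, 2, 3], 2, "tensor"))
    #     # -> [{"tensor": [1, 2]}, {"tensor": [3]}]
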
    async def _batch_infer(
        self, session: aiohttp.ClientSession, url: str, batch_data: Dict[str, Any]
    ):
        if not isinstance(batch_data, dict):
            raise TypeError(f"tensor is {type(batch_data)} but 'dict' is required")
        async with session.post(url, json=batch_data) as resp:
            resp_data = await resp.json(content_type=None)
            # NOTE: resp data comes back in a list, returning first elem for parity with requests resp
            return resp_data[0]

    def replace_model(self, model: "Model") -> "Deployment":
        """Replaces the current model with a default-configured Model.

        :param Model model: Model variant to replace current model with
        """
        return self.replace_configured_model(model.config())

    def replace_configured_model(self, model_config: ModelConfig) -> "Deployment":
        """Replaces the current model with a configured variant.

        :param ModelConfig model_config: Configured model to replace current model with
        """
        data = self._gql_client.execute(
            gql.gql(
                """
            mutation ReplaceModel($deployment_id: bigint!, $model_config_id: bigint!) {
                insert_deployment_model_configs(objects: {deployment_id: $deployment_id, model_config_id: $model_config_id}) {
                    returning {
                        id
                        deployment_id
                        model_config_id
                    }
                }
            }
        """
            ),
            variable_values={
                "deployment_id": self.id(),
                "model_config_id": model_config.id(),
            },
        )
        self._rehydrate()
        return self

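    # Usage sketch (model handle obtained from the client; names illustrative):
    #
    #     deployment.replace_model(new_model)
    #     # or, with an explicit configuration:
    #     deployment.replace_configured_model(new_model.config())
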
    def internal_url(self) -> str:
        """Returns the internal inference URL that is only reachable from inside of the Wallaroo cluster by SDK instances deployed in the cluster.

        If both pipelines and models are configured on the Deployment, this
        gives preference to pipelines. The returned URL is always for the first
        configured pipeline or model.
        """
        return self._internal_url()

    def _internal_url(self) -> str:
        return (
            f"http://engine-lb.{self.name()}-{self.id()}:29502/pipelines/{self.name()}"
        )

    def url(self) -> str:
        """Returns the inference URL.

        If both pipelines and models are configured on the Deployment, this
        gives preference to pipelines. The returned URL is always for the first
        configured pipeline or model.
        """
        return self._url()

    def _url(self) -> str:
        if self.client is None:
            raise RuntimeError("api_endpoint must be provided to wallaroo.Client()")
        if "api-lb" in self.client.api_endpoint:
            return self._internal_url()

        return f"{self.client.api_endpoint}/v1/api/pipelines/infer/{self.name()}-{self.id()}"

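    # Shape of the two URL forms (name/id illustrative):
    #   internal: http://engine-lb.mypipeline-42:29502/pipelines/mypipeline
    #   external: {api_endpoint}/v1/api/pipelines/infer/mypipeline-42
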
    def logs(self, limit: int = 100, valid: Optional[bool] = None) -> LogEntries:
        """Deployment.logs() has been removed. Please use pipeline.logs() instead."""
        raise RuntimeError(
            "Deployment.logs() has been removed. Please use pipeline.logs() instead."
        )
class WaitForError(builtins.Exception):
49class WaitForError(Exception):
50    def __init__(self, message: str, status: Optional[Dict[str, Any]]):
51        super().__init__(message)
52        self.status = status

Common base class for all non-exit exceptions.

WaitForError(message: str, status: Optional[Dict[str, Any]])
50    def __init__(self, message: str, status: Optional[Dict[str, Any]]):
51        super().__init__(message)
52        self.status = status
Inherited Members
builtins.BaseException
with_traceback
args
class WaitForDeployError(builtins.RuntimeError):
55class WaitForDeployError(RuntimeError):
56    def __init__(self, message: str):
57        super().__init__(message)
58
59    def _render_traceback_(self):
60        display(
61            HTML(
62                "<strong>*** An error occurred while deploying your pipeline.</strong>"
63            )
64        )
65        return [str(self)]

Unspecified run-time error.

WaitForDeployError(message: str)
56    def __init__(self, message: str):
57        super().__init__(message)
Inherited Members
builtins.BaseException
with_traceback
args
class Deployment(wallaroo.object.Object):
 68class Deployment(Object):
 69    def __init__(self, client: Optional["Client"], data: Dict[str, Any]) -> None:
 70        self.client = client
 71        assert client is not None
 72        # TODO: revisit session initialization during connection pooling work
 73        self.session = self._initialize_session()
 74        super().__init__(gql_client=client._gql_client, data=data)
 75
 76    def _fill(self, data: Dict[str, Any]) -> None:
 77        """Fills an object given a response dictionary from the GraphQL API.
 78
 79        Only the primary key member must be present; other members will be
 80        filled in via rehydration if their corresponding member function is
 81        called.
 82        """
 83        from .pipeline_variant import PipelineVariant  # avoids circular imports
 84
 85        for required_attribute in ["id"]:
 86            if required_attribute not in data:
 87                raise RequiredAttributeMissing(
 88                    self.__class__.__name__, required_attribute
 89                )
 90        self._id = data["id"]
 91
 92        self._name = value_if_present(data, "deploy_id")
 93        self._deployed = value_if_present(data, "deployed")
 94        self._model_configs = (
 95            [
 96                ModelConfig(self.client, elem["model_config"])
 97                for elem in data["deployment_model_configs"]
 98            ]
 99            if "deployment_model_configs" in data
100            else DehydratedValue()
101        )
102        self._pipeline_variants = (
103            [
104                PipelineVariant(self.client, elem["pipeline_version"])
105                for elem in data["deployment_pipeline_versions"]
106            ]
107            if "deployment_pipeline_versions" in data
108            else DehydratedValue()
109        )
110
111        self._pipeline_id = value_if_present(data, "pipeline_id")
112
113    def _fetch_attributes(self) -> Dict[str, Any]:
114        """Fetches all member data from the GraphQL API."""
115        return self._gql_client.execute(
116            gql.gql(
117                """
118            query DeploymentById($deployment_id: bigint!) {
119                deployment_by_pk(id: $deployment_id) {
120                    id
121                    deploy_id
122                    deployed
123                    deployment_model_configs {
124                        model_config {
125                            id
126                        }
127                    }
128                    deployment_pipeline_versions(order_by: {pipeline_version: {id: desc}}) {
129                        pipeline_version {
130                            id
131                        }
132                    }
133                }
134            }
135            """
136            ),
137            variable_values={
138                "deployment_id": self._id,
139            },
140        )["deployment_by_pk"]
141
142    def _initialize_session(self) -> requests.Session:
143        # TODO: make session initialization configurable
144        #  to be informed by connection polling reqs.
145        #  includes sane defaults to match current retry time (~45s)
146        session = requests.Session()
147        retries = Retry(
148            total=10,
149            backoff_factor=0.1,
150            status_forcelist=[503],
151            allowed_methods=["GET", "POST"],
152        )
153        session.mount("http://", HTTPAdapter(max_retries=retries))
154        return session
155
156    def id(self) -> int:
157        return self._id
158
159    @rehydrate("_name")
160    def name(self) -> str:
161        return cast(str, self._name)
162
163    @rehydrate("_deployed")
164    def deployed(self) -> bool:
165        return cast(bool, self._deployed)
166
167    @rehydrate("_model_configs")
168    def model_configs(self) -> List[ModelConfig]:
169        return cast(List[ModelConfig], self._model_configs)
170
171    @rehydrate("_pipeline_variants")
172    def pipeline_variants(self) -> List["PipelineVariant"]:
173        from .pipeline_variant import PipelineVariant  # avoids circular imports
174
175        return cast(List[PipelineVariant], self._pipeline_variants)
176
177    def deploy(self) -> "Deployment":
178        """Deploys this deployment, if it is not already deployed.
179
180        If the deployment is already deployed, this is a no-op.
181        """
182        q = gql.gql(
183            """
184        mutation Deploy($id: bigint!) {
185            update_deployment_by_pk(pk_columns: {id: $id} _set: { deployed: true }) {
186                id
187                deploy_id
188                deployed
189            }
190        }
191        """
192        )
193        variables = {"id": self.id()}
194        assert self.client is not None
195        self.client._gql_client.execute(q, variable_values=variables)
196        self._rehydrate()
197        return self
198
199    def undeploy(self) -> "Deployment":
200        """Shuts down this deployment, if it is deployed.
201
202        If the deployment is already undeployed, this is a no-op.
203        """
204        # TODO: Error handling.
205        assert self.client is not None
206
207        data = pipelines_undeploy.sync_detailed(
208            client=self.client.mlops(),
209            json_body=pipelines_undeploy_json_body.PipelinesUndeployJsonBody(
210                UNSET, self.id()
211            ),
212        )
213
214        if data.status_code != 200:
215            err = data.content.decode("utf-8")
216            raise Exception(f"Failed to undeploy. {err}")
217
218        self._rehydrate()
219        return self.wait_for_undeployed()
220
221    def _get_auth(self):
222        # TODO: Digging the auth object out of the gql_client is cheating
223        return self.client._gql_client.transport.auth
224
225    def status(self) -> Dict[str, Any]:
226        """Returns a dict of deployment status useful for determining if a deployment has succeeded.
227
228        :return: Dict of deployment internal state information.
229        :rtype: Dict[str, Any]
230        """
231
232        assert self.client is not None
233
234        status_url = f"{self.client.api_endpoint}/v1/api/status/get_deployment"
235        params = {"name": f"{self.name()}-{self.id()}"}
236
237        kind = ""
238        resp = None
239        try:
240            resp = requests.post(
241                status_url,
242                timeout=5,
243                auth=self._get_auth(),
244                json=params,
245            )
246            kind = ""
247        except requests.ReadTimeout as ex:
248            raise CommunicationError(f"rest-api connection to {status_url}")
249        except Exception:
250            kind = "comm"
251
252        if resp is not None and resp.status_code == 200:
253            res = resp.json()
254            if res is not None and res["status"] == "Running":
255                # retry for a running status
256                return res
257
258        details = ""
259        if resp is not None:
260            if resp.status_code == 200:
261                return resp.json()
262
263            if resp.status_code == 404:
264                raise EntityNotFoundError(
265                    f"Deployment not found", {"name": self.name()}
266                )
267
268            details = f"\nStatus code: {resp.status_code}\nMessage: {resp.text}"
269
270        if kind == "comm":
271            raise CommunicationError(f"rest-api connection to {status_url}")
272
273        raise RuntimeError(f"Unable to query deployment status {status_url}{details}")
274
275    def check_limit_status(self):
276        q = gql.gql(
277            """
278            query QueryLimitStatus($id: bigint!) {
279                deployment(where: {id: {_eq: $id}}) {
280                    id
281                    deployed
282                    limit_status
283                }
284            }
285            """
286        )
287
288        variables = {"id": self.id()}
289        assert self.client is not None
290        res = self.client._gql_client.execute(q, variable_values=variables)[
291            "deployment"
292        ]
293        if len(res) > 0:
294            status = res[0]
295            if "limit_status" in status:
296                limit_status = status["limit_status"]
297                if limit_status is not None:
298                    raise LimitError(limit_status)
299
300    def _wait_for(
301        self,
302        poll_fn: Callable[[], Tuple[bool, str, str]],
303        task_name: str,
304        on_iter: Callable[[int], None] = lambda ix: None,
305        timeout: Optional[int] = None,
306    ) -> "Deployment":
307        """Wait for a generic task to finish before completing.
308
309        Will wait up "timeout_request" seconds for the deployment to enter that state. This is set
310
311
312        :return: The deployment, for chaining.
313        :rtype: Deployment
314        """
315        assert self.client is not None
316        warning = False
317        duration = timeout if timeout is not None else self.client.timeout
318        message = "(none)"
319        kind = "unset"
320
321        start = time.monotonic()
322        ix = 0
323        while ix == 0 or time.monotonic() - start < duration:
324            on_iter(ix)
325            ix += 1
326
327            res, message, kind = poll_fn()
328            if res:
329                if self.client._interactive:
330                    sys.stdout.write(" ok\n")
331                return self
332
333            if self.client._interactive:
334                if not warning:
335                    sys.stdout.write(
336                        f"Waiting for {task_name} - this will take up to {duration}s "
337                    )
338                    warning = True
339                time.sleep(1)
340                sys.stdout.write(".")
341            else:
342                time.sleep(1)
343
344        if kind == "comm":
345            raise CommunicationError(message)
346        else:
347            try:
348                status: Optional[Dict[str, Any]] = self.status()
349                message = f"{task_name.capitalize()} failed. See status for details."
350            except:
351                message = f"{task_name.capitalize()} did not finish within {duration}s."
352                status = None
353            raise WaitForError(message, status)
354
355    def wait_for_running(self, timeout: Optional[int] = None) -> "Deployment":
356        """Waits for the deployment status to enter the "Running" state.
357
358        Will wait up "timeout_request" seconds for the deployment to enter that state. This is set
359        in the "Client" object constructor. Will raise various exceptions on failures.
360
361        :return: The deployment, for chaining.
362        :rtype: Deployment
363        """
364
365        def check_limit(ix: int) -> None:
366            # If this checks immediately, it will happen too soon for the deployment manager to
367            # have cleared the limit_status column on the deployment and this will fail erroneously
368            if ix > 5:
369                self.check_limit_status()
370
371        def check_for_running() -> Tuple[bool, str, str]:
372            try:
373                res = self.status()
374                if res is not None and res["status"] == "Running":
375                    return True, "", ""
376                return False, "not running", "runtime"
377            except CommunicationError as ex:
378                # Connection may be coming up, try again
379                return False, str(ex), "comm"
380            except (EntityNotFoundError, RuntimeError) as ex:
381                # Not found may switch to found, after a while. Retry it.
382                return False, f"not found {ex}", "runtime"
383
384        try:
385            return self._wait_for(check_for_running, "deployment", check_limit, timeout)
386        except WaitForError as ex:
387            status = ex.status
388            message = f"{str(ex)}\nStatus: {str(status)}"
389            if status is not None and status.get("status") == "Error":
390                quantity, resource = None, None
391                engines = status.get("engines", [])
392                engine_lbs = status.get("engine_lbs", [])
393                required_cpu = next(
394                    filter(
395                        lambda item: item.get("status") == "Pending"
396                        and item.get("required_cpu"),
397                        engines + engine_lbs,
398                    ),
399                    cast(Dict[str, Any], {}),
400                ).get("required_cpu")
401                if required_cpu:
402                    resource = "CPU"
403                    quantity = (
404                        "one CPU"
405                        if required_cpu == "1"
406                        else f"{required_cpu} units of CPU"
407                    )
408                else:
409                    required_memory = next(
410                        filter(
411                            lambda item: item.get("status") == "Pending"
412                            and item.get("required_memory"),
413                            engines + engine_lbs,
414                        ),
415                        cast(Dict[str, Any], {}),
416                    ).get("required_memory")
417                    if required_memory:
418                        resource = "memory"
419                        quantity = f"{required_memory} of memory"
420
421                if quantity is not None and resource is not None:
422                    message = (
423                        "Cannot deploy pipeline due to insufficient resources. "
424                        f"Your pipeline needs {quantity} to run but there is not enough {resource} currently available. "
425                        "Please try again or un-deploy pipelines not in use to adjust the resources that are available for your Wallaroo instance. "
426                        "Contact your Wallaroo platform administrator for additional support."
427                    )
428
429            raise WaitForDeployError(message)
430
431    def wait_for_undeployed(self) -> "Deployment":
432        """Waits for the deployment to end.
433
434        Will wait up "timeout_request" seconds for the deployment to enter that state. This is set
435        in the "Client" object constructor. Will raise various exceptions on failures.
436
437        :return: The deployment, for chaining.
438        :rtype: Deployment
439        """
440
441        def check_for_undeployed() -> Tuple[bool, str, str]:
442            try:
443                self.status()
444                return False, "still running", "runtime"
445            except CommunicationError as ex:
446                # Connection may be coming up, try again
447                return False, str(ex), "comm"
448            except RuntimeError as ex:
449                # Not found may switch to found, after a while. Retry it.
450                return False, f"not found {ex}", "runtime"
451            except EntityNotFoundError:
452                return True, "", ""
453
454        try:
455            return self._wait_for(check_for_undeployed, "undeployment")
456        except WaitForError as ex:
457            message = f"{str(ex)}\nStatus: {str(ex.status)}"
458            raise RuntimeError(message)
459
460    @staticmethod
461    def _write_table_to_arrow_file(table: pa.Table, schema: pa.Schema):
462        sink = pa.BufferOutputStream()
463        with pa.ipc.new_file(sink, schema) as arrow_ipc:
464            arrow_ipc.write(table)
465            arrow_ipc.close()
466        return sink.getvalue()
467
468    def _infer_with_pandas(
469        self,
470        tensor: pd.DataFrame,
471        params: Mapping[str, Union[str, Sequence[str]]],
472        timeout: Optional[Union[int, float]],
473    ) -> pd.DataFrame:
474        input_records = tensor.to_json(orient="records")
475        headers = {
476            "Content-Type": PANDAS_RECORDS_HEADER,
477            "Accept": PANDAS_RECORDS_HEADER,
478        }
479        res = self._make_infer_request(
480            data=input_records,
481            headers=headers,
482            params=params,
483            timeout=timeout,
484        )
485        try:
486            data = res.json()
487        except (json.JSONDecodeError, ValueError) as err:
488            raise RuntimeError(f"Inference unable to complete.") from err
489        return pd.DataFrame.from_records(data)
490
491    def _infer_with_arrow(
492        self,
493        tensor: pa.Table,
494        params: Mapping[str, Union[str, Sequence[str]]],
495        timeout: Optional[Union[int, float]],
496    ) -> pa.Table:
497        input_arrow = self._write_table_to_arrow_file(tensor, tensor.schema)
498        headers = {"Content-Type": ARROW_HEADER, "Accept": ARROW_HEADER}
499        res = self._make_infer_request(
500            data=input_arrow,
501            headers=headers,
502            params=params,
503            timeout=timeout,
504        )
505        with pa.ipc.open_file(res.content) as reader:
506            data_table = reader.read_all()
507        return data_table
508
509    def _make_infer_request(
510        self,
511        data: Union[Dict[str, Any], pd.DataFrame],
512        headers: Dict[str, str],
513        params: Mapping[str, Union[str, Sequence[str]]],
514        timeout: Optional[Union[int, float]] = None,
515    ) -> requests.Response:
516
517        try:
518            res = self.session.post(
519                self._url(),
520                data=data,
521                timeout=timeout,
522                # TODO: Digging the auth object out of the gql_client is cheating
523                auth=self._get_auth(),
524                params=params,
525                headers=headers,
526            )
527        except (
528            requests.exceptions.Timeout,
529            requests.exceptions.ReadTimeout,
530            requests.exceptions.RequestException,
531        ):
532            raise RuntimeError(
533                f"Inference did not return within {timeout}s, adjust if necessary"
534            )
535        return res
536
537    def infer(
538        self,
539        tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table],
540        timeout: Optional[Union[int, float]] = None,
541        dataset: Optional[Union[Sequence[str], str]] = None,
542        dataset_exclude: Optional[Union[Sequence[str], str]] = None,
543        dataset_separator: Optional[str] = None,
544    ) -> Union[List[InferenceResult], pd.DataFrame, pa.Table]:
545        """
546        Returns an inference result on this deployment, given a tensor.
547        :param: tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table] Inference data. Should be a dictionary.
548        Future improvement: will be a pandas dataframe or arrow table
549        :param: timeout: Optional[Union[int, float]] infer requests will time out after
550            the amount of seconds provided are exceeded. timeout defaults
551            to 15 secs.
552        :param: dataset: Optional[Union[Sequence[str], str]] By default this is set to return, ["time", "out"].
553            Other available options "check_failures", "metadata"
554        :param: dataset_exclude: Optional[Union[Sequence[str], str]] If set, allows user to exclude parts of dataset.
555        :param: dataset_separator: Optional[str] If set to ".", return dataset will be flattened.
556        :return: InferenceResult in dictionary, dataframe or arrow format.
557        """
558
559        if timeout is None:
560            timeout = 15
561        if not isinstance(timeout, (int, float)):
562            raise TypeError(
563                f"timeout is {type(timeout)} but 'int' or 'float' is required"
564            )
565        assert self.client is not None
566
567        if os.getenv(ARROW_ENABLED, "false").lower() == "true":
568            params = dict()
569            params["dataset[]"] = dataset or "*"
570            params["dataset.exclude[]"] = dataset_exclude or "in"
571            params["dataset.separator"] = dataset_separator or "."
572            if not isinstance(tensor, (pd.DataFrame, pa.Table)):
573                raise TypeError(
574                    f"tensor is of type {type(tensor)} but 'pandas.DataFrame' or 'pyarrow.Table' is required"
575                )
576            if isinstance(tensor, pd.DataFrame):
577                return self._infer_with_pandas(tensor, params, timeout)
578            else:
579                return self._infer_with_arrow(tensor, params, timeout)
580        else:
581            if not isinstance(tensor, dict):
582                raise TypeError(f"tensor is {type(tensor)} but 'dict' is required")
583            headers = {"Content-Type": JSON_HEADER}
584            try:
585                res = self.session.post(
586                    self._url(),
587                    json=tensor,
588                    timeout=timeout,
589                    # TODO: Digging the auth object out of the gql_client is cheating
590                    auth=self._get_auth(),
591                    headers=headers,
592                )
593            except (
594                requests.exceptions.Timeout,
595                requests.exceptions.ReadTimeout,
596                requests.exceptions.RequestException,
597            ):
598                raise RuntimeError(
599                    f"Inference did not return within {timeout}s, adjust if necessary"
600                )
601            try:
602                data = res.json()
603            except (json.JSONDecodeError, ValueError) as err:
604                raise RuntimeError(f"Inference unable to complete.") from err
605            return [InferenceResult(self._gql_client, d) for d in data]
606
607    def infer_from_file(
608        self,
609        filename: Union[str, pathlib.Path],
610        timeout: Optional[Union[int, float]] = None,
611        dataset: Optional[Sequence[str]] = None,
612        exclude: Optional[Sequence[str]] = None,
613        dataset_separator: Optional[str] = None,
614    ) -> Union[List[InferenceResult], pd.DataFrame, pa.Table]:
615        if not isinstance(filename, pathlib.Path):
616            filename = pathlib.Path(filename)
617        if os.getenv(ARROW_ENABLED, "false").lower() == "true":
618            if filename.suffix.lower() == ".arrow":
619                with pa.ipc.open_file(filename) as source:
620                    tensor = source.read_all()
621            elif filename.suffix.lower() == ".json":
622                tensor = pd.read_json(filename)
623            else:
624                raise TypeError(
625                    f" File is of type {filename.suffix.lower()}, but only '.arrow' or '.json' are accepted"
626                )
627        else:
628            with filename.open("rb") as f:
629                tensor = json.load(f)
630        return self.infer(tensor, timeout, dataset, exclude, dataset_separator)
631
632    async def batch_infer_from_file(
633        self,
634        filename: Union[str, pathlib.Path],
635        data_key: str = "tensor",
636        batch_size: int = 1000,
637        connector_limit: int = 4,
638    ) -> List[InferenceResult]:
639        """Async method to run batched inference on a data file for a given deployment.
640
641        :param str filename: path to an existing file with tensor data in JSON format.
642        :param str data_key: key which the tensor data is under within the JSON. defaults to "tensor".
643        :param int batch_size: batch size to use when sending requests to the engine. defaults to 1000.
644        :param int connector_limit: limit for the amount of TCP connections. defaults to 4.
645        :return: List of InferenceResult's.
646        :rtype: List[InferenceResult]
647        """
648        if not isinstance(filename, pathlib.Path):
649            filename = pathlib.Path(filename)
650        with filename.open("rb") as f:
651            json_data = orjson.loads(f.read())
652
653        input_data = json_data[data_key]
654        chunked_data = self._generate_chunk_data(input_data, batch_size, data_key)
655
656        assert self.client is not None
657        url = self._url()
658        auth = self._get_auth()
659        connector = aiohttp.TCPConnector(limit=connector_limit)
660
661        headers = {}
662        if isinstance(auth, _PlatformAuth):
663            headers = auth.auth_header()
664
665        async with aiohttp.ClientSession(
666            connector=connector,
667            headers=headers,
668            json_serialize=lambda x: orjson.dumps(x).decode(),
669        ) as session:
670            tasks = []  # renamed from "requests" to avoid shadowing the requests module
671            for chunk in chunked_data:
672                tasks.append(
673                    asyncio.ensure_future(self._batch_infer(session, url, chunk))
674                )
675
676            resps = await asyncio.gather(*tasks)
677            return [InferenceResult(self._gql_client, resp) for resp in resps]
678
679    def _chunk_data(self, data: Any, batch_size: int) -> Iterator[List[Any]]:
680        for i in range(0, len(data), batch_size):
681            yield data[i : i + batch_size]
682
683    def _key_data(self, key: str, data: Any) -> Dict[str, Any]:
684        return {key: data}
685
686    def _generate_chunk_data(
687        self, data: Dict[str, Any], batch_size: int, key: str
688    ) -> Iterator[Dict[str, Any]]:
689        chunked_data = self._chunk_data(data, batch_size)
690        return (self._key_data(key, chunk) for chunk in chunked_data)
691
692    async def _batch_infer(
693        self, session: aiohttp.ClientSession, url: str, batch_data: Dict[str, Any]
694    ):
695        if not isinstance(batch_data, dict):
696            raise TypeError(f"tensor is {type(batch_data)} but 'dict' is required")
697        async with session.post(url, json=batch_data) as resp:
698            resp_data = await resp.json(content_type=None)
699            # NOTE: resp data comes back in a list, returning first elem for parity with requests resp
700            return resp_data[0]
701
702    def replace_model(self, model: "Model") -> "Deployment":
703        """Replaces the current model with a default-configured Model.
704
705        :param Model model: Model variant to replace current model with
706        """
707        return self.replace_configured_model(model.config())
708
709    def replace_configured_model(self, model_config: ModelConfig) -> "Deployment":
710        """Replaces the current model with a configured variant.
711
712        :param ModelConfig model_config: Configured model to replace current model with
713        """
714        data = self._gql_client.execute(
715            gql.gql(
716                """
717            mutation ReplaceModel($deployment_id: bigint!, $model_config_id: bigint!) {
718                insert_deployment_model_configs(objects: {deployment_id: $deployment_id, model_config_id: $model_config_id}) {
719                    returning {
720                        id
721                        deployment_id
722                        model_config_id
723                    }
724                }
725            }
726        """
727            ),
728            variable_values={
729                "deployment_id": self.id(),
730                "model_config_id": model_config.id(),
731            },
732        )
733        self._rehydrate()
734        return self
735
736    def internal_url(self) -> str:
737        """Returns the internal inference URL, which is reachable only from inside the Wallaroo cluster by SDK instances deployed in the cluster.
738
739        If both pipelines and models are configured on the Deployment, this
740        gives preference to pipelines. The returned URL is always for the first
741        configured pipeline or model.
742        """
743        return self._internal_url()
744
745    def _internal_url(self) -> str:
746        return (
747            f"http://engine-lb.{self.name()}-{self.id()}:29502/pipelines/{self.name()}"
748        )
749
750    def url(self) -> str:
751        """Returns the inference URL.
752
753        If both pipelines and models are configured on the Deployment, this
754        gives preference to pipelines. The returned URL is always for the first
755        configured pipeline or model.
756        """
757        return self._url()
758
759    def _url(self) -> str:
760        if self.client is None:
761            raise RuntimeError("api_endpoint must be provided to wallaroo.Client()")
762        if "api-lb" in self.client.api_endpoint:
763            return self._internal_url()
764
765        return f"{self.client.api_endpoint}/v1/api/pipelines/infer/{self.name()}-{self.id()}"
766
767    def logs(self, limit: int = 100, valid: Optional[bool] = None) -> LogEntries:
768        """Deployment.logs() has been removed. Please use pipeline.logs() instead."""
769        raise RuntimeError(
770            "Deployment.logs() has been removed. Please use pipeline.logs() instead."
771        )

Base class for all backend GraphQL API objects.

This class serves as a framework for API objects to be constructed based on a partially-complete JSON response, and to fill in their remaining members dynamically if needed.

Deployment(client: Optional[wallaroo.client.Client], data: Dict[str, Any])
69    def __init__(self, client: Optional["Client"], data: Dict[str, Any]) -> None:
70        self.client = client
71        assert client is not None
72        # TODO: revisit session initialization during connection pooling work
73        self.session = self._initialize_session()
74        super().__init__(gql_client=client._gql_client, data=data)

Base constructor.

Each object requires:

  • a GraphQL client - in order to fill its missing members dynamically
  • an initial data blob - typically from deserialized JSON, containing at least the data for required members (typically the object's primary key) and optionally other data members.
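
A minimal construction sketch (assuming `wl` is an authenticated wallaroo.Client(); the id value is illustrative):

    deployment = Deployment(wl, {"id": 42})
    deployment.id()    # present immediately -> 42
    deployment.name()  # missing member; filled in lazily via the GraphQL client
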
def id(self) -> int:
156    def id(self) -> int:
157        return self._id
def name(*args, **kwargs):
41        def wrapper(*args, **kwargs):
42            obj = args[0]
43            if not getattr(obj, "_standalone", None):
44                present = getattr(obj, attr) != DehydratedValue()
45                # Uncomment to debug while testing
46                # print(
47                #    "rehydrate: {} -> {}".format(
48                #        attr, "present" if present else "not present"
49                #    )
50                # )
51                if not present:
52                    obj._rehydrate()
53            result = fn(*args, **kwargs)
54            return result
def deployed(*args, **kwargs):
def model_configs(*args, **kwargs):
def pipeline_variants(*args, **kwargs):
def deploy(self) -> wallaroo.deployment.Deployment:
177    def deploy(self) -> "Deployment":
178        """Deploys this deployment, if it is not already deployed.
179
180        If the deployment is already deployed, this is a no-op.
181        """
182        q = gql.gql(
183            """
184        mutation Deploy($id: bigint!) {
185            update_deployment_by_pk(pk_columns: {id: $id} _set: { deployed: true }) {
186                id
187                deploy_id
188                deployed
189            }
190        }
191        """
192        )
193        variables = {"id": self.id()}
194        assert self.client is not None
195        self.client._gql_client.execute(q, variable_values=variables)
196        self._rehydrate()
197        return self

Deploys this deployment, if it is not already deployed.

If the deployment is already deployed, this is a no-op.
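
A minimal usage sketch; `deployment` is assumed to have been obtained from an earlier client call rather than constructed by hand:

    deployment.deploy()            # no-op if already deployed
    deployment.wait_for_running()  # block until status() reports "Running"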

def undeploy(self) -> wallaroo.deployment.Deployment:
199    def undeploy(self) -> "Deployment":
200        """Shuts down this deployment, if it is deployed.
201
202        If the deployment is already undeployed, this is a no-op.
203        """
204        # TODO: Error handling.
205        assert self.client is not None
206
207        data = pipelines_undeploy.sync_detailed(
208            client=self.client.mlops(),
209            json_body=pipelines_undeploy_json_body.PipelinesUndeployJsonBody(
210                UNSET, self.id()
211            ),
212        )
213
214        if data.status_code != 200:
215            err = data.content.decode("utf-8")
216            raise Exception(f"Failed to undeploy. {err}")
217
218        self._rehydrate()
219        return self.wait_for_undeployed()

Shuts down this deployment, if it is deployed.

If the deployment is already undeployed, this is a no-op.
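
A short usage sketch; undeploy() raises on a non-200 API response and otherwise blocks until teardown completes:

    try:
        deployment = deployment.undeploy()
    except Exception as err:
        print(f"Failed to undeploy: {err}")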

def status(self) -> Dict[str, Any]:
225    def status(self) -> Dict[str, Any]:
226        """Returns a dict of deployment status useful for determining if a deployment has succeeded.
227
228        :return: Dict of deployment internal state information.
229        :rtype: Dict[str, Any]
230        """
231
232        assert self.client is not None
233
234        status_url = f"{self.client.api_endpoint}/v1/api/status/get_deployment"
235        params = {"name": f"{self.name()}-{self.id()}"}
236
237        kind = ""
238        resp = None
239        try:
240            resp = requests.post(
241                status_url,
242                timeout=5,
243                auth=self._get_auth(),
244                json=params,
245            )
246            kind = ""
247        except requests.ReadTimeout as ex:
248            raise CommunicationError(f"rest-api connection to {status_url}") from ex
249        except Exception:
250            kind = "comm"
251
252        if resp is not None and resp.status_code == 200:
253            res = resp.json()
254            if res is not None and res["status"] == "Running":
255                # retry for a running status
256                return res
257
258        details = ""
259        if resp is not None:
260            if resp.status_code == 200:
261                return resp.json()
262
263            if resp.status_code == 404:
264                raise EntityNotFoundError(
265                    "Deployment not found", {"name": self.name()}
266                )
267
268            details = f"\nStatus code: {resp.status_code}\nMessage: {resp.text}"
269
270        if kind == "comm":
271            raise CommunicationError(f"rest-api connection to {status_url}")
272
273        raise RuntimeError(f"Unable to query deployment status {status_url}{details}")

Returns a dict of deployment status useful for determining if a deployment has succeeded.

Returns

Dict of deployment internal state information.
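
A short sketch of inspecting the returned status dict (key names taken from the code above):

    status = deployment.status()
    print(status["status"])                  # e.g. "Running" or "Error"
    for engine in status.get("engines", []):
        print(engine.get("status"))          # per-engine state, e.g. "Pending"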

def check_limit_status(self):
275    def check_limit_status(self):
276        q = gql.gql(
277            """
278            query QueryLimitStatus($id: bigint!) {
279                deployment(where: {id: {_eq: $id}}) {
280                    id
281                    deployed
282                    limit_status
283                }
284            }
285            """
286        )
287
288        variables = {"id": self.id()}
289        assert self.client is not None
290        res = self.client._gql_client.execute(q, variable_values=variables)[
291            "deployment"
292        ]
293        if len(res) > 0:
294            status = res[0]
295            if "limit_status" in status:
296                limit_status = status["limit_status"]
297                if limit_status is not None:
298                    raise LimitError(limit_status)
def wait_for_running(self, timeout: Optional[int] = None) -> wallaroo.deployment.Deployment:
355    def wait_for_running(self, timeout: Optional[int] = None) -> "Deployment":
356        """Waits for the deployment status to enter the "Running" state.
357
358        Will wait up to "timeout_request" seconds for the deployment to enter that state. This is set
359        in the "Client" object constructor. Will raise various exceptions on failures.
360
361        :return: The deployment, for chaining.
362        :rtype: Deployment
363        """
364
365        def check_limit(ix: int) -> None:
366            # If this checks immediately, it will happen too soon for the deployment manager to
367            # have cleared the limit_status column on the deployment and this will fail erroneously
368            if ix > 5:
369                self.check_limit_status()
370
371        def check_for_running() -> Tuple[bool, str, str]:
372            try:
373                res = self.status()
374                if res is not None and res["status"] == "Running":
375                    return True, "", ""
376                return False, "not running", "runtime"
377            except CommunicationError as ex:
378                # Connection may be coming up, try again
379                return False, str(ex), "comm"
380            except (EntityNotFoundError, RuntimeError) as ex:
381                # Not found may switch to found, after a while. Retry it.
382                return False, f"not found {ex}", "runtime"
383
384        try:
385            return self._wait_for(check_for_running, "deployment", check_limit, timeout)
386        except WaitForError as ex:
387            status = ex.status
388            message = f"{str(ex)}\nStatus: {str(status)}"
389            if status is not None and status.get("status") == "Error":
390                quantity, resource = None, None
391                engines = status.get("engines", [])
392                engine_lbs = status.get("engine_lbs", [])
393                required_cpu = next(
394                    filter(
395                        lambda item: item.get("status") == "Pending"
396                        and item.get("required_cpu"),
397                        engines + engine_lbs,
398                    ),
399                    cast(Dict[str, Any], {}),
400                ).get("required_cpu")
401                if required_cpu:
402                    resource = "CPU"
403                    quantity = (
404                        "one CPU"
405                        if required_cpu == "1"
406                        else f"{required_cpu} units of CPU"
407                    )
408                else:
409                    required_memory = next(
410                        filter(
411                            lambda item: item.get("status") == "Pending"
412                            and item.get("required_memory"),
413                            engines + engine_lbs,
414                        ),
415                        cast(Dict[str, Any], {}),
416                    ).get("required_memory")
417                    if required_memory:
418                        resource = "memory"
419                        quantity = f"{required_memory} of memory"
420
421                if quantity is not None and resource is not None:
422                    message = (
423                        "Cannot deploy pipeline due to insufficient resources. "
424                        f"Your pipeline needs {quantity} to run but there is not enough {resource} currently available. "
425                        "Please try again or un-deploy pipelines not in use to adjust the resources that are available for your Wallaroo instance. "
426                        "Contact your Wallaroo platform administrator for additional support."
427                    )
428
429            raise WaitForDeployError(message)

Waits for the deployment status to enter the "Running" state.

Will wait up to "timeout_request" seconds for the deployment to enter that state. This is set in the "Client" object constructor. Will raise various exceptions on failures.

Returns

The deployment, for chaining.
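
A usage sketch with explicit error handling (the 90-second timeout is illustrative):

    try:
        deployment.wait_for_running(timeout=90)
    except WaitForDeployError as err:
        print(err)  # includes a resource-shortage hint when engine pods stay Pending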

def wait_for_undeployed(self) -> wallaroo.deployment.Deployment:
431    def wait_for_undeployed(self) -> "Deployment":
432        """Waits for the deployment to end.
433
434        Will wait up to "timeout_request" seconds for the deployment to shut down. This is set
435        in the "Client" object constructor. Will raise various exceptions on failures.
436
437        :return: The deployment, for chaining.
438        :rtype: Deployment
439        """
440
441        def check_for_undeployed() -> Tuple[bool, str, str]:
442            try:
443                self.status()
444                return False, "still running", "runtime"
445            except CommunicationError as ex:
446                # Connection may be coming up, try again
447                return False, str(ex), "comm"
448            except RuntimeError as ex:
449                # Not found may switch to found, after a while. Retry it.
450                return False, f"not found {ex}", "runtime"
451            except EntityNotFoundError:
452                return True, "", ""
453
454        try:
455            return self._wait_for(check_for_undeployed, "undeployment")
456        except WaitForError as ex:
457            message = f"{str(ex)}\nStatus: {str(ex.status)}"
458            raise RuntimeError(message)

Waits for the deployment to end.

Will wait up to "timeout_request" seconds for the deployment to shut down. This is set in the "Client" object constructor. Will raise various exceptions on failures.

Returns

The deployment, for chaining.
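
Note that undeploy() already calls this internally; a standalone call is only needed if the deployment was shut down out of band:

    deployment.wait_for_undeployed()  # blocks until status() raises EntityNotFoundError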

def infer( self, tensor: Union[Dict[str, Any], pandas.core.frame.DataFrame, pyarrow.lib.Table], timeout: Union[int, float, NoneType] = None, dataset: Union[Sequence[str], str, NoneType] = None, dataset_exclude: Union[Sequence[str], str, NoneType] = None, dataset_separator: Optional[str] = None) -> Union[List[wallaroo.inference_result.InferenceResult], pandas.core.frame.DataFrame, pyarrow.lib.Table]:
537    def infer(
538        self,
539        tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table],
540        timeout: Optional[Union[int, float]] = None,
541        dataset: Optional[Union[Sequence[str], str]] = None,
542        dataset_exclude: Optional[Union[Sequence[str], str]] = None,
543        dataset_separator: Optional[str] = None,
544    ) -> Union[List[InferenceResult], pd.DataFrame, pa.Table]:
545        """
546        Returns an inference result on this deployment, given a tensor.
547        :param tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table] Inference data. A dictionary
548            by default; a pandas DataFrame or pyarrow Table when ARROW_ENABLED is set.
549        :param timeout: Optional[Union[int, float]] Number of seconds after which the inference
550            request times out.
551            Defaults to 15 seconds.
552        :param dataset: Optional[Union[Sequence[str], str]] By default this is set to return ["time", "out"].
553            Other available options: "check_failures", "metadata".
554        :param dataset_exclude: Optional[Union[Sequence[str], str]] If set, allows the user to exclude parts of the dataset.
555        :param dataset_separator: Optional[str] If set to ".", the returned dataset will be flattened.
556        :return: A list of InferenceResults, a DataFrame, or an Arrow Table, depending on the input format.
557        """
558
559        if timeout is None:
560            timeout = 15
561        if not isinstance(timeout, (int, float)):
562            raise TypeError(
563                f"timeout is {type(timeout)} but 'int' or 'float' is required"
564            )
565        assert self.client is not None
566
567        if os.getenv(ARROW_ENABLED, "false").lower() == "true":
568            params = dict()
569            params["dataset[]"] = dataset or "*"
570            params["dataset.exclude[]"] = dataset_exclude or "in"
571            params["dataset.separator"] = dataset_separator or "."
572            if not isinstance(tensor, (pd.DataFrame, pa.Table)):
573                raise TypeError(
574                    f"tensor is of type {type(tensor)} but 'pandas.DataFrame' or 'pyarrow.Table' is required"
575                )
576            if isinstance(tensor, pd.DataFrame):
577                return self._infer_with_pandas(tensor, params, timeout)
578            else:
579                return self._infer_with_arrow(tensor, params, timeout)
580        else:
581            if not isinstance(tensor, dict):
582                raise TypeError(f"tensor is {type(tensor)} but 'dict' is required")
583            headers = {"Content-Type": JSON_HEADER}
584            try:
585                res = self.session.post(
586                    self._url(),
587                    json=tensor,
588                    timeout=timeout,
589                    # TODO: Digging the auth object out of the gql_client is cheating
590                    auth=self._get_auth(),
591                    headers=headers,
592                )
593            except (
594                requests.exceptions.Timeout,
595                requests.exceptions.ReadTimeout,
596                requests.exceptions.RequestException,
597            ):
598                raise RuntimeError(
599                    f"Inference request failed or did not return within {timeout}s; adjust the timeout if necessary"
600                )
601            try:
602                data = res.json()
603            except (json.JSONDecodeError, ValueError) as err:
604                raise RuntimeError("Inference unable to complete.") from err
605            return [InferenceResult(self._gql_client, d) for d in data]

Returns an inference result on this deployment, given a tensor.

Parameters
  • tensor: Union[Dict[str, Any], pd.DataFrame, pa.Table] Inference data. A dictionary by default; a pandas DataFrame or pyarrow Table when ARROW_ENABLED is set.
  • timeout: Optional[Union[int, float]] Number of seconds after which the inference request times out. Defaults to 15 seconds.
  • dataset: Optional[Union[Sequence[str], str]] By default this is set to return ["time", "out"]. Other available options: "check_failures", "metadata".
  • dataset_exclude: Optional[Union[Sequence[str], str]] If set, allows the user to exclude parts of the dataset.
  • dataset_separator: Optional[str] If set to ".", the returned dataset will be flattened.
Returns

A list of InferenceResults, a DataFrame, or an Arrow Table, depending on the input format.
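
A sketch of both input modes (paths and values are illustrative):

    # Default JSON mode: pass a dict, receive a list of InferenceResults.
    results = deployment.infer({"tensor": [[1.0, 2.0, 3.0]]})

    # With ARROW_ENABLED=true in the environment, pass a DataFrame or Table:
    import pandas as pd
    df = pd.DataFrame({"tensor": [[1.0, 2.0, 3.0]]})
    out = deployment.infer(df, timeout=30, dataset=["time", "out"])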

def infer_from_file( self, filename: Union[str, pathlib.Path], timeout: Union[int, float, NoneType] = None, dataset: Optional[Sequence[str]] = None, exclude: Optional[Sequence[str]] = None, dataset_separator: Optional[str] = None) -> Union[List[wallaroo.inference_result.InferenceResult], pandas.core.frame.DataFrame, pyarrow.lib.Table]:
607    def infer_from_file(
608        self,
609        filename: Union[str, pathlib.Path],
610        timeout: Optional[Union[int, float]] = None,
611        dataset: Optional[Sequence[str]] = None,
612        exclude: Optional[Sequence[str]] = None,
613        dataset_separator: Optional[str] = None,
614    ) -> Union[List[InferenceResult], pd.DataFrame, pa.Table]:
615        if not isinstance(filename, pathlib.Path):
616            filename = pathlib.Path(filename)
617        if os.getenv(ARROW_ENABLED, "false").lower() == "true":
618            if filename.suffix.lower() == ".arrow":
619                with pa.ipc.open_file(filename) as source:
620                    tensor = source.read_all()
621            elif filename.suffix.lower() == ".json":
622                tensor = pd.read_json(filename)
623            else:
624                raise TypeError(
625                    f"File is of type {filename.suffix.lower()}, but only '.arrow' or '.json' are accepted"
626                )
627        else:
628            with filename.open("rb") as f:
629                tensor = json.load(f)
630        return self.infer(tensor, timeout, dataset, exclude, dataset_separator)
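
A minimal sketch (hypothetical file path; '.arrow' files are also accepted when ARROW_ENABLED=true):

    results = deployment.infer_from_file("./data/tensor.json", timeout=30)
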
async def batch_infer_from_file( self, filename: Union[str, pathlib.Path], data_key: str = 'tensor', batch_size: int = 1000, connector_limit: int = 4) -> List[wallaroo.inference_result.InferenceResult]:
632    async def batch_infer_from_file(
633        self,
634        filename: Union[str, pathlib.Path],
635        data_key: str = "tensor",
636        batch_size: int = 1000,
637        connector_limit: int = 4,
638    ) -> List[InferenceResult]:
639        """Async method to run batched inference on a data file for a given deployment.
640
641        :param str filename: path to an existing file with tensor data in JSON format.
642        :param str data_key: key under which the tensor data is stored within the JSON. Defaults to "tensor".
643        :param int batch_size: batch size to use when sending requests to the engine. Defaults to 1000.
644        :param int connector_limit: limit on the number of TCP connections. Defaults to 4.
645        :return: List of InferenceResults.
646        :rtype: List[InferenceResult]
647        """
648        if not isinstance(filename, pathlib.Path):
649            filename = pathlib.Path(filename)
650        with filename.open("rb") as f:
651            json_data = orjson.loads(f.read())
652
653        input_data = json_data[data_key]
654        chunked_data = self._generate_chunk_data(input_data, batch_size, data_key)
655
656        assert self.client is not None
657        url = self._url()
658        auth = self._get_auth()
659        connector = aiohttp.TCPConnector(limit=connector_limit)
660
661        headers = {}
662        if isinstance(auth, _PlatformAuth):
663            headers = auth.auth_header()
664
665        async with aiohttp.ClientSession(
666            connector=connector,
667            headers=headers,
668            json_serialize=lambda x: orjson.dumps(x).decode(),
669        ) as session:
670            tasks = []  # renamed from "requests" to avoid shadowing the requests module
671            for chunk in chunked_data:
672                tasks.append(
673                    asyncio.ensure_future(self._batch_infer(session, url, chunk))
674                )
675
676            resps = await asyncio.gather(*tasks)
677            return [InferenceResult(self._gql_client, resp) for resp in resps]

Async method to run batched inference on a data file for a given deployment.

Parameters
  • str filename: path to an existing file with tensor data in JSON format.
  • str data_key: key under which the tensor data is stored within the JSON. Defaults to "tensor".
  • int batch_size: batch size to use when sending requests to the engine. Defaults to 1000.
  • int connector_limit: limit on the number of TCP connections. Defaults to 4.
Returns

List of InferenceResults.
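
A usage sketch (hypothetical path; the file must be JSON with the tensor data under data_key):

    import asyncio

    results = asyncio.run(
        deployment.batch_infer_from_file("./data/tensors.json", batch_size=500)
    )
    # In a notebook with a running event loop, use
    # `await deployment.batch_infer_from_file(...)` instead.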

def replace_model(self, model: wallaroo.model.Model) -> wallaroo.deployment.Deployment:
702    def replace_model(self, model: "Model") -> "Deployment":
703        """Replaces the current model with a default-configured Model.
704
705        :param Model model: Model variant to replace current model with
706        """
707        return self.replace_configured_model(model.config())

Replaces the current model with a default-configured Model.

Parameters
  • Model model: Model variant to replace current model with
def replace_configured_model( self, model_config: wallaroo.model_config.ModelConfig) -> wallaroo.deployment.Deployment:
709    def replace_configured_model(self, model_config: ModelConfig) -> "Deployment":
710        """Replaces the current model with a configured variant.
711
712        :param ModelConfig model_config: Configured model to replace current model with
713        """
714        data = self._gql_client.execute(
715            gql.gql(
716                """
717            mutation ReplaceModel($deployment_id: bigint!, $model_config_id: bigint!) {
718                insert_deployment_model_configs(objects: {deployment_id: $deployment_id, model_config_id: $model_config_id}) {
719                    returning {
720                        id
721                        deployment_id
722                        model_config_id
723                    }
724                }
725            }
726        """
727            ),
728            variable_values={
729                "deployment_id": self.id(),
730                "model_config_id": model_config.id(),
731            },
732        )
733        self._rehydrate()
734        return self

Replaces the current model with a configured variant.

Parameters
  • ModelConfig model_config: Configured model to replace current model with
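
A short sketch; `new_model` is assumed to be a wallaroo.model.Model obtained elsewhere:

    deployment = deployment.replace_model(new_model)
    # Equivalent, with an explicit configuration step:
    deployment = deployment.replace_configured_model(new_model.config())
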
def internal_url(self) -> str:
736    def internal_url(self) -> str:
737        """Returns the internal inference URL, which is reachable only from inside the Wallaroo cluster by SDK instances deployed in the cluster.
738
739        If both pipelines and models are configured on the Deployment, this
740        gives preference to pipelines. The returned URL is always for the first
741        configured pipeline or model.
742        """
743        return self._internal_url()

Returns the internal inference URL, which is reachable only from inside the Wallaroo cluster by SDK instances deployed in the cluster.

If both pipelines and models are configured on the Deployment, this gives preference to pipelines. The returned URL is always for the first configured pipeline or model.

def url(self) -> str:
750    def url(self) -> str:
751        """Returns the inference URL.
752
753        If both pipelines and models are configured on the Deployment, this
754        gives preference to pipelines. The returned URL is always for the first
755        configured pipeline or model.
756        """
757        return self._url()

Returns the inference URL.

If both pipelines and models are configured on the Deployment, this gives preference to pipelines. The returned URL is always for the first configured pipeline or model.
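
A quick comparison of the two endpoints:

    print(deployment.url())           # external endpoint via the API load balancer
    print(deployment.internal_url())  # http://engine-lb.<name>-<id>:29502/..., in-cluster only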

def logs( self, limit: int = 100, valid: Optional[bool] = None) -> wallaroo.logs.LogEntries:
767    def logs(self, limit: int = 100, valid: Optional[bool] = None) -> LogEntries:
768        """Deployment.logs() has been removed. Please use pipeline.logs() instead."""
769        raise RuntimeError(
770            "Deployment.logs() has been removed. Please use pipeline.logs() instead."
771        )

Deployment.logs() has been removed. Please use pipeline.logs() instead.