wallaroo.model

  1import base64
  2import datetime
  3from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
  4
  5import gql  # type: ignore
  6import pyarrow as pa  # type: ignore
  7import requests
  8from dateutil import parser as dateparse
  9
 10from .checks import Variables, require_dns_compliance
 11from .deployment_config import DeploymentConfig, DeploymentConfigBuilder
 12from .logs import LogEntries
 13from .object import *
 14from .version import _user_agent
 15from .visibility import _Visibility
 16
 17if TYPE_CHECKING:
 18    # Imports that happen below in methods to fix circular import dependency
 19    # issues need to also be specified here to satisfy mypy type checking.
 20    from .client import Client
 21    from .model_config import ModelConfig
 22    from .pipeline import Pipeline
 23    from .tag import Tag
 24
 25# A little misleading: this is actually a ModelVersion or ModelVariant, associated with the `model` table, not the `models` table.
 26class Model(Object):
 27    """Wraps a backend Model object."""
 28
 29    def __init__(
 30        self, client: Optional["Client"], data: Dict[str, Any], standalone=False
 31    ) -> None:
 32        self.client = client
 33        self._config: Optional["ModelConfig"] = None
 34        super().__init__(
 35            gql_client=client._gql_client if client is not None else None,
 36            data=data,
 37            standalone=standalone,
 38        )
 39
 40    def __repr__(self) -> str:
 41        return str(
 42            {
 43                "name": self.name(),
 44                "version": self.version(),
 45                "file_name": self.file_name(),
 46                "image_path": self.image_path(),
 47                "last_update_time": self.last_update_time(),
 48            }
 49        )
 50
 51    @staticmethod
 52    def as_standalone(name: str, version: str, file_name: str) -> "Model":
 53        """Creates a Model intended for use in generating standalone configurations"""
 54        constructor_dict = {
 55            "model_id": name,
 56            "model_version": version,
 57            "file_name": file_name,
 58            "id": -1,
 59        }
 60        return Model(None, constructor_dict, standalone=True)
 61
 62    def _fill(self, data: Dict[str, Any]) -> None:
 63        """Fills an object given a response dictionary from the GraphQL API.
 64
 65        Only the primary key member must be present; other members will be
 66        filled in via rehydration if their corresponding member function is
 67        called.
 68        """
 69        from .tag import Tag
 70
 71        for required_attribute in ["id"]:
 72            if required_attribute not in data:
 73                raise RequiredAttributeMissing(
 74                    self.__class__.__name__, required_attribute
 75                )
 76        # Required
 77        self._id = data["id"]
 78
 79        # Optional
 80        self._name = value_if_present(data, "model_id")
 81        self._version = value_if_present(data, "model_version")
 82        self._models_pk_id = value_if_present(data, "models_pk_id")
 83        self._sha = value_if_present(data, "sha")
 84        self._file_name = value_if_present(data, "file_name")
 85        self._image_path = value_if_present(data, "image_path")
 86        self._last_update_time = (
 87            dateparse.isoparse(data["updated_at"])
 88            if "updated_at" in data
 89            else DehydratedValue()
 90        )
 91        self._visibility = (
 92            _Visibility.from_str(data["visibility"])
 93            if "visibility" in data
 94            else DehydratedValue()
 95        )
 96        self._tags = (
 97            [Tag(self.client, tag["tag"]) for tag in data["model_tags"]]
 98            if "model_tags" in data
 99            else DehydratedValue()
100        )
101
102    def _fetch_attributes(self) -> Dict[str, Any]:
103        """Fetches all member data from the GraphQL API."""
104        return self._gql_client.execute(
105            gql.gql(
106                f"""
107            query ModelById {{
108                model_by_pk(id: {self._id}) {{
109                    id
110                    model_id
111                    model_version
112                    models_pk_id    
113                    sha
114                    file_name
115                    image_path
116                    updated_at
117                    visibility
118                    model_tags {{
119                      tag {{
120                        id
121                        tag
122                      }}
123                    }}
124                }}
125            }}
126            """
127            )
128        )["model_by_pk"]
129
130    def id(self) -> int:
131        return self._id
132
133    def uid(self) -> str:
134        return f"{self.name()}-{self.id()}"
135
136    @rehydrate("_name")
137    def name(self) -> str:
138        return cast(str, self._name)
139
140    @rehydrate("_version")
141    def version(self) -> str:
142        return cast(str, self._version)
143
144    # TODO: Find other models by finding the parent?
145    # @rehydrate("_versions")
146    # def versions(self) -> "Models":
147    #     from .models import Models
148
149    #     return Models(client=self.client, data={"id": self._models_pk_id})
150
151    @rehydrate("_models_pk_id")
152    def models_pk_id(self) -> str:
153        return cast(str, self._models_pk_id)
154
155    @rehydrate("_sha")
156    def sha(self) -> str:
157        return cast(str, self._sha)
158
159    @rehydrate("_file_name")
160    def file_name(self) -> str:
161        return cast(str, self._file_name)
162
163    @rehydrate("_image_path")
164    def image_path(self) -> str:
165        return cast(str, self._image_path)
166
167    @rehydrate("_last_update_time")
168    def last_update_time(self) -> datetime.datetime:
169        return cast(datetime.datetime, self._last_update_time)
170
171    @property
172    def inputs(self):
173        return Variables(self.name(), "input")
174
175    @property
176    def outputs(self):
177        return Variables(self.name(), "output")
178
179    @rehydrate("_tags")
180    def tags(self) -> List["Tag"]:
181        from .tag import Tag  # avoids import cycles
182
183        return cast(List[Tag], self._tags)
184
185    @rehydrate("_config")
186    def rehydrate_config(self) -> "ModelConfig":
187        from .model_config import ModelConfig
188
189        if self._config is not None:
190            return self._config
191        if self.client is None:
192            raise Exception("Cannot retrieve current model config, client is None.")
193        assert self.client is not None
194        base = self.client.api_endpoint + f"/v1/api/models/get_config_by_id"
195        headers = {"User-Agent": _user_agent}
196
197        raw = requests.post(
198            base,
199            auth=self.client.auth,
200            headers=headers,
201            json={"model_id": self.id()},
202        )
203        if raw.status_code > 299:
204            raise Exception(
205                "Failed to retrieve config from api, and failed to automatically configure model."
206            )
207        possible_model_config = raw.json()
208        if possible_model_config["model_config"]:
209            return ModelConfig(
210                client=self.client, data=possible_model_config["model_config"]
211            )
212        else:
213            raise Exception(
214                "Failed to determine model configuration, could not auto configure based on name, and no existing "
215                "model configuration was present. "
216            )
217
218    def _update_visibility(self, visibility: _Visibility):
219        assert self.client is not None
220        return self._fill(
221            self.client._gql_client.execute(
222                gql.gql(
223                    """
224                mutation UpdateModelVisibility(
225                    $model_pk: bigint!,
226                    $visibility: String
227                ) {
228                  update_model(where: {id: {_eq: $model_pk}}, _set: {visibility: $visibility}) {
229                        returning {
230                          id
231                          model_id
232                          model_version
233                          file_name
234                          visibility
235                          updated_at
236                        }
237                    }
238                }
239                """
240                ),
241                variable_values={
242                    "model_pk": self._id,
243                    "visibility": visibility,
244                },
245            )["update_model"]["returning"][0]
246        )
247
248    def config(self) -> "ModelConfig":
249        if self._config is None:
250            try:
251                self.configure()
252            except Exception:
253                self._config = self.rehydrate_config()
254        assert self._config is not None
255        return cast("ModelConfig", self._config)
256
257    def configure(
258        self,
259        runtime: Optional[str] = None,
260        tensor_fields: List[str] = None,
261        filter_threshold: float = None,
262        input_schema: Optional[pa.Schema] = None,
263        output_schema: Optional[pa.Schema] = None,
264    ) -> "Model":
265        from .model_config import ModelConfig  # Avoids circular imports
266
267        if runtime is None:
268            filename_to_runtime = {".onnx": "onnx", ".py": "python"}
269
270            runtimes = [
271                v
272                for (k, v) in filename_to_runtime.items()
273                if self.file_name().endswith(k)
274            ]
275            if not runtimes:
276                raise DeploymentError(
277                    f"runtime cannot be inferred from filename: {self.file_name()}"
278                )
279            if len(runtimes) > 1:
280                raise DeploymentError(
281                    f"Multiple runtimes possible for filename {self.file_name()}: {runtimes}"
282                )
283            runtime = runtimes[0]
284
285        if tensor_fields:
286            if not isinstance(tensor_fields, List) or not all(
287                isinstance(s, str) for s in tensor_fields
288            ):
289                raise DeploymentError(
290                    f"tensor_fields must be a list of strings, received: {tensor_fields}"
291                )
292
293        if input_schema:
294            input_schema = base64.b64encode(bytes(input_schema.serialize())).decode(
295                "utf8"
296            )
297        elif runtime == "mlflow":
298            raise DeploymentError("input_schema is required for mlflow models")
299
300        if output_schema:
301            output_schema = base64.b64encode(bytes(output_schema.serialize())).decode(
302                "utf8"
303            )
304        elif runtime == "mlflow":
305            raise DeploymentError("output_schema is required for mlflow models")
306
307        q = gql.gql(
308            """
309            mutation ConfigureModel($model_id: bigint!, $runtime: String, $tensor_fields: jsonb, $filter_threshold: float8, $input_schema: String, $output_schema: String) {
310                insert_model_config(objects: {filter_threshold: $filter_threshold, model_id: $model_id, runtime: $runtime, tensor_fields: $tensor_fields, input_schema: $input_schema, output_schema: $output_schema}) {
311                returning {
312                    id
313                    model_id
314                    runtime
315                    tensor_fields
316                    filter_threshold
317                    input_schema
318                    output_schema
319                    model {
320                        id
321                    }
322                }
323            }
324            }
325            """
326        )
327        variables = {
328            "model_id": self.id(),
329            "runtime": runtime,
330            "tensor_fields": tensor_fields,
331            "filter_threshold": filter_threshold,
332            "input_schema": input_schema,
333            "output_schema": output_schema,
334        }
335        assert self.client is not None
336        data = self.client._gql_client.execute(q, variable_values=variables)
337
338        self._config = ModelConfig(
339            client=self.client,
340            data=data["insert_model_config"]["returning"][0],
341        )
342        self._config._model = self
343        return self
344
345    def logs(self, limit: int = 100, valid: Optional[bool] = None) -> LogEntries:
346        topic = f"model-{self.name()}-inference"
347        if valid is False:
348            topic += "-failures"
349        assert self.client is not None
350        return self.client.get_logs(topic, limit)[0]
351
352    def deploy(
353        self,
354        pipeline_name: str,
355        deployment_config: Optional[DeploymentConfig] = None,
356    ) -> "Pipeline":
357        """Convenience function to quickly deploy a Model. It will configure the model,
358        create a pipeline with a single model step, deploy it, and return the pipeline.
359
360        Typically, the configure() method is used to configure a model prior to
361        deploying it. However, if a default configuration is sufficient, this
362        function can be used to quickly deploy with said default configuration.
363
364        The filename this Model was generated from needs to have a recognizable
365        file extension so that the runtime can be inferred. Currently, this is:
366
367        * `.onnx` -> ONNX runtime
368
369        :param str pipeline_name: Name of the pipeline to create and deploy. Must be
370            unique across all pipelines. Pipeline names must be ASCII alpha-numeric
371            characters plus dash (-) only.
372
373        """
374
375        assert self.client is not None
376
377        require_dns_compliance(pipeline_name)
378
379        workspace_id = (
380            None if self.client is None else self.client.get_current_workspace().id()
381        )
382
383        if deployment_config is None:
384            deployment_config = DeploymentConfigBuilder(
385                workspace_id=workspace_id
386            ).build()
387        else:
388            deployment_config.guarantee_workspace_id(workspace_id=workspace_id)
389
390        pipeline = self.client.build_pipeline(pipeline_name)
391
392        pipeline.add_model_step(self)
393
394        pipeline.deploy(pipeline_name, deployment_config=deployment_config)
395
396        return pipeline
397
398        # TODO - wait_for_running
399
400
401class ModelVersions(List[Model]):
402    """Wraps a list of Models for display in a display-aware environment like Jupyter."""
403
404    def _repr_html_(self) -> str:
405        def row(model: Model):
406            return f"""
407            <tr>
408                <td>{model.name()}</td>
409                <td>{model.version()}</td>
410                <td>{model.file_name()}</td>
411                <td>{model.image_path()}</td>
412                <td>{model.last_update_time()}</td>
413            </tr>
414          """
415
416        fields = ["name", "version", "file_name", "image_path", "last_update_time"]
417        if not self:
418            return "(no model versions)"
419        else:
420            return (
421                "<table>"
422                + "<tr><th>"
423                + "</th><th>".join(fields)
424                + "</th></tr>"
425                + ("".join([row(m) for m in self]))
426                + "</table>"
427            )
class Model(wallaroo.object.Object):

Wraps a backend Model object.

Model( client: Optional[wallaroo.client.Client], data: Dict[str, Any], standalone=False)
30    def __init__(
31        self, client: Optional["Client"], data: Dict[str, Any], standalone=False
32    ) -> None:
33        self.client = client
34        self._config: Optional["ModelConfig"] = None
35        super().__init__(
36            gql_client=client._gql_client if client is not None else None,
37            data=data,
38            standalone=standalone,
39        )

Base constructor.

Each object requires:

  • a GraphQL client - in order to fill its missing members dynamically
  • an initial data blob - typically from deserialized JSON, containing at least the data for required members (typically the object's primary key) and optionally other data members.
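
In practice a Model is usually obtained from a connected Client rather than constructed directly. A rough sketch, assuming the client exposes an upload helper that lives outside this module (the method name, model name, and file path below are illustrative, not guaranteed API):

import wallaroo

wl = wallaroo.Client()                              # assumes interactive authentication
model = wl.upload_model("ccfraud", "ccfraud.onnx")  # assumed helper returning a Model
print(model)  # __repr__ shows name, version, file_name, image_path, last_update_time
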
@staticmethod
def as_standalone(name: str, version: str, file_name: str) -> wallaroo.model.Model:
52    @staticmethod
53    def as_standalone(name: str, version: str, file_name: str) -> "Model":
54        """Creates a Model intended for use in generating standalone configurations"""
55        constructor_dict = {
56            "model_id": name,
57            "model_version": version,
58            "file_name": file_name,
59            "id": -1,
60        }
61        return Model(None, constructor_dict, standalone=True)

Creates a Model intended for use in generating standalone configurations
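
A minimal sketch of building a standalone handle (no client or GraphQL connection); the name, version, and file name are placeholders:

from wallaroo.model import Model

model = Model.as_standalone(
    name="ccfraud",            # placeholder model name
    version="v1",              # placeholder version label
    file_name="ccfraud.onnx",  # file the model was built from
)
print(model.name(), model.version(), model.file_name())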

def id(self) -> int:
131    def id(self) -> int:
132        return self._id
def uid(self) -> str:
134    def uid(self) -> str:
135        return f"{self.name()}-{self.id()}"
def name(self) -> str:
137    @rehydrate("_name")
138    def name(self) -> str:
139        return cast(str, self._name)
def version(self) -> str:
141    @rehydrate("_version")
142    def version(self) -> str:
143        return cast(str, self._version)
def models_pk_id(self) -> str:
152    @rehydrate("_models_pk_id")
153    def models_pk_id(self) -> str:
154        return cast(str, self._models_pk_id)
def sha(self) -> str:
156    @rehydrate("_sha")
157    def sha(self) -> str:
158        return cast(str, self._sha)
def file_name(self) -> str:
160    @rehydrate("_file_name")
161    def file_name(self) -> str:
162        return cast(str, self._file_name)
def image_path(self) -> str:
164    @rehydrate("_image_path")
165    def image_path(self) -> str:
166        return cast(str, self._image_path)
def last_update_time(self) -> datetime.datetime:
168    @rehydrate("_last_update_time")
169    def last_update_time(self) -> datetime.datetime:
170        return cast(datetime.datetime, self._last_update_time)
inputs
outputs
def tags(self) -> List[wallaroo.tag.Tag]:
180    @rehydrate("_tags")
181    def tags(self) -> List["Tag"]:
182        from .tag import Tag  # avoids import cycles
183
184        return cast(List[Tag], self._tags)
def rehydrate_config(self) -> wallaroo.model_config.ModelConfig:
186    @rehydrate("_config")
187    def rehydrate_config(self) -> "ModelConfig":
188        from .model_config import ModelConfig
189
190        if self._config is not None:
191            return self._config
192        if self.client is None:
193            raise Exception("Cannot retrieve current model config, client is None.")
194        assert self.client is not None
195        base = self.client.api_endpoint + f"/v1/api/models/get_config_by_id"
196        headers = {"User-Agent": _user_agent}
197
198        raw = requests.post(
199            base,
200            auth=self.client.auth,
201            headers=headers,
202            json={"model_id": self.id()},
203        )
204        if raw.status_code > 299:
205            raise Exception(
206                "Failed to retrieve config from api, and failed to automatically configure model."
207            )
208        possible_model_config = raw.json()
209        if possible_model_config["model_config"]:
210            return ModelConfig(
211                client=self.client, data=possible_model_config["model_config"]
212            )
213        else:
214            raise Exception(
215                "Failed to determine model configuration, could not auto configure based on name, and no existing "
216                "model configuration was present. "
217            )
def config(self) -> wallaroo.model_config.ModelConfig:
249    def config(self) -> "ModelConfig":
250        if self._config is None:
251            try:
252                self.configure()
253            except Exception:
254                self._config = self.rehydrate_config()
255        assert self._config is not None
256        return cast("ModelConfig", self._config)
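
A short usage sketch: config() lazily resolves the ModelConfig, first attempting configure() with inferred defaults and falling back to the configuration already stored in the API ("model" is assumed to be an uploaded Model):

model_config = model.config()
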
def configure( self, runtime: Optional[str] = None, tensor_fields: List[str] = None, filter_threshold: float = None, input_schema: Optional[pyarrow.lib.Schema] = None, output_schema: Optional[pyarrow.lib.Schema] = None) -> wallaroo.model.Model:
258    def configure(
259        self,
260        runtime: Optional[str] = None,
261        tensor_fields: List[str] = None,
262        filter_threshold: float = None,
263        input_schema: Optional[pa.Schema] = None,
264        output_schema: Optional[pa.Schema] = None,
265    ) -> "Model":
266        from .model_config import ModelConfig  # Avoids circular imports
267
268        if runtime is None:
269            filename_to_runtime = {".onnx": "onnx", ".py": "python"}
270
271            runtimes = [
272                v
273                for (k, v) in filename_to_runtime.items()
274                if self.file_name().endswith(k)
275            ]
276            if not runtimes:
277                raise DeploymentError(
278                    f"runtime cannot be inferred from filename: {self.file_name()}"
279                )
280            if len(runtimes) > 1:
281                raise DeploymentError(
282                    f"Multiple runtimes possible for filename {self.file_name()}: {runtimes}"
283                )
284            runtime = runtimes[0]
285
286        if tensor_fields:
287            if not isinstance(tensor_fields, List) or not all(
288                isinstance(s, str) for s in tensor_fields
289            ):
290                raise DeploymentError(
291                    f"tensor_fields must be a list of strings, received: {tensor_fields}"
292                )
293
294        if input_schema:
295            input_schema = base64.b64encode(bytes(input_schema.serialize())).decode(
296                "utf8"
297            )
298        elif runtime == "mlflow":
299            raise DeploymentError("input_schema is required for mlflow models")
300
301        if output_schema:
302            output_schema = base64.b64encode(bytes(output_schema.serialize())).decode(
303                "utf8"
304            )
305        elif runtime == "mlflow":
306            raise DeploymentError("output_schema is required for mlflow models")
307
308        q = gql.gql(
309            """
310            mutation ConfigureModel($model_id: bigint!, $runtime: String, $tensor_fields: jsonb, $filter_threshold: float8, $input_schema: String, $output_schema: String) {
311                insert_model_config(objects: {filter_threshold: $filter_threshold, model_id: $model_id, runtime: $runtime, tensor_fields: $tensor_fields, input_schema: $input_schema, output_schema: $output_schema}) {
312                returning {
313                    id
314                    model_id
315                    runtime
316                    tensor_fields
317                    filter_threshold
318                    input_schema
319                    output_schema
320                    model {
321                        id
322                    }
323                }
324            }
325            }
326            """
327        )
328        variables = {
329            "model_id": self.id(),
330            "runtime": runtime,
331            "tensor_fields": tensor_fields,
332            "filter_threshold": filter_threshold,
333            "input_schema": input_schema,
334            "output_schema": output_schema,
335        }
336        assert self.client is not None
337        data = self.client._gql_client.execute(q, variable_values=variables)
338
339        self._config = ModelConfig(
340            client=self.client,
341            data=data["insert_model_config"]["returning"][0],
342        )
343        self._config._model = self
344        return self
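
Usage sketches for configure(); the field names and schemas below are illustrative only. When runtime is omitted it is inferred from the file extension (".onnx" -> "onnx", ".py" -> "python"):

import pyarrow as pa

# ONNX model: tensor_fields must be a list of strings naming the input fields.
model.configure(tensor_fields=["tensor"])

# MLFlow-packaged model: input_schema and output_schema are required and are
# serialized to base64 before being sent to the API.
input_schema = pa.schema([pa.field("inputs", pa.list_(pa.float32()))])
output_schema = pa.schema([pa.field("predictions", pa.list_(pa.float32()))])
model.configure(runtime="mlflow", input_schema=input_schema, output_schema=output_schema)
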
def logs( self, limit: int = 100, valid: Optional[bool] = None) -> wallaroo.logs.LogEntries:
346    def logs(self, limit: int = 100, valid: Optional[bool] = None) -> LogEntries:
347        topic = f"model-{self.name()}-inference"
348        if valid is False:
349            topic += "-failures"
350        assert self.client is not None
351        return self.client.get_logs(topic, limit)[0]
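
A small sketch of reading inference logs; topic names follow the pattern shown in the source:

entries = model.logs(limit=50)                 # topic "model-<name>-inference"
failures = model.logs(limit=50, valid=False)   # topic "model-<name>-inference-failures"
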
def deploy( self, pipeline_name: str, deployment_config: Optional[wallaroo.deployment_config.DeploymentConfig] = None) -> wallaroo.pipeline.Pipeline:
353    def deploy(
354        self,
355        pipeline_name: str,
356        deployment_config: Optional[DeploymentConfig] = None,
357    ) -> "Pipeline":
358        """Convenience function to quickly deploy a Model. It will configure the model,
359        create a pipeline with a single model step, deploy it, and return the pipeline.
360
361        Typically, the configure() method is used to configure a model prior to
362        deploying it. However, if a default configuration is sufficient, this
363        function can be used to quickly deploy with said default configuration.
364
365        The filename this Model was generated from needs to have a recognizable
366        file extension so that the runtime can be inferred. Currently, this is:
367
368        * `.onnx` -> ONNX runtime
369
370        :param str pipeline_name: Name of the pipeline to create and deploy. Must be
371            unique across all pipelines. Pipeline names must be ASCII alpha-numeric
372            characters plus dash (-) only.
373
374        """
375
376        assert self.client is not None
377
378        require_dns_compliance(pipeline_name)
379
380        workspace_id = (
381            None if self.client is None else self.client.get_current_workspace().id()
382        )
383
384        if deployment_config is None:
385            deployment_config = DeploymentConfigBuilder(
386                workspace_id=workspace_id
387            ).build()
388        else:
389            deployment_config.guarantee_workspace_id(workspace_id=workspace_id)
390
391        pipeline = self.client.build_pipeline(pipeline_name)
392
393        pipeline.add_model_step(self)
394
395        pipeline.deploy(pipeline_name, deployment_config=deployment_config)
396
397        return pipeline
398
399        # TODO - wait_for_running

Convenience function to quickly deploy a Model. It will configure the model, create a pipeline with a single model step, deploy it, and return the pipeline.

Typically, the configure() method is used to configure a model prior to deploying it. However, if a default configuration is sufficient, this function can be used to quickly deploy with said default configuration.

The filename this Model was generated from needs to have a recognizable file extension so that the runtime can be inferred. Currently, this is:

  • .onnx -> ONNX runtime
Parameters
  • str pipeline_name: Name of the pipeline to create and deploy. Must be unique across all pipelines. Pipeline names must be ASCII alpha-numeric characters plus dash (-) only.
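
A hedged end-to-end sketch; the pipeline name is illustrative, "model" is assumed to be an uploaded Model bound to a connected client, and DeploymentConfigBuilder is assumed to accept default construction:

# Deploy with the default deployment configuration for the current workspace.
pipeline = model.deploy("ccfraud-pipeline")

# Or supply an explicit configuration built with DeploymentConfigBuilder.
from wallaroo.deployment_config import DeploymentConfigBuilder

deployment_config = DeploymentConfigBuilder().build()
pipeline = model.deploy("ccfraud-pipeline", deployment_config=deployment_config)
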
class ModelVersions(typing.List[wallaroo.model.Model]):
402class ModelVersions(List[Model]):
403    """Wraps a list of Models for display in a display-aware environment like Jupyter."""
404
405    def _repr_html_(self) -> str:
406        def row(model: Model):
407            return f"""
408            <tr>
409                <td>{model.name()}</td>
410                <td>{model.version()}</td>
411                <td>{model.file_name()}</td>
412                <td>{model.image_path()}</td>
413                <td>{model.last_update_time()}</td>
414            </tr>
415          """
416
417        fields = ["name", "version", "file_name", "image_path", "last_update_time"]
418        if not self:
419            return "(no model versions)"
420        else:
421            return (
422                "<table>"
423                + "<tr><th>"
424                + "</th><th>".join(fields)
425                + "</th></tr>"
426                + ("".join([row(m) for m in self]))
427                + "</table>"
428            )

Wraps a list of Models for display in a display-aware environment like Jupyter.
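
A short sketch: ModelVersions subclasses list, so it can wrap any sequence of Model objects; in Jupyter the _repr_html_ hook renders it as a table of name, version, file_name, image_path, and last_update_time:

from wallaroo.model import ModelVersions

versions = ModelVersions([model])  # "model" is an existing Model object
versions                           # rendered as an HTML table in a notebook cell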

Inherited Members
builtins.list
list
clear
copy
append
insert
extend
pop
remove
index
count
reverse
sort