wallaroo.client

   1import base64
   2import json
   3import math
   4import os
   5import pathlib
   6import posixpath
   7from datetime import datetime, timezone
   8from functools import partial
   9from typing import Any, Dict, List, NewType, Optional, Tuple, Union, cast
  10from urllib.parse import quote_plus
  11
  12import gql  # type: ignore
  13import pandas as pd
  14import requests
  15from gql.transport.requests import RequestsHTTPTransport
  16
  17from wallaroo.models import Models, ModelsList
  18
  19from . import auth
  20from .assay import Assay, AssayAnalysis, AssayAnalysisList, Assays
  21from .assay_config import AssayBuilder, AssayConfig
  22from .checks import require_dns_compliance
  23from .deployment import Deployment
  24from .inference_decode import inference_logs_to_dataframe
  25from .logs import LogEntries, LogEntry
  26from .model import Model, ModelVersions
  27from .model_config import ModelConfig
  28from .ModelConversion import (
  29    ConvertKerasArguments,
  30    ConvertSKLearnArguments,
  31    ConvertXGBoostArgs,
  32    ModelConversionArguments,
  33    ModelConversionFailure,
  34    ModelConversionGenericException,
  35    ModelConversionSource,
  36    ModelConversionSourceFileNotPresent,
  37    ModelConversionUnsupportedType,
  38)
  39from .object import EntityNotFoundError, ModelUploadError
  40from .pipeline import Pipeline, Pipelines
  41from .pipeline_config import PipelineConfig
  42from .pipeline_variant import PipelineVariant, PipelineVariants
  43from .tag import Tag, Tags
  44from .user import User
  45from .version import _user_agent
  46from .visibility import _Visibility
  47from .wallaroo_ml_ops_api_client.api.assay import (
  48    assays_create,
  49    assays_get_assay_results,
  50    assays_list,
  51)
  52from .wallaroo_ml_ops_api_client.api.model import models_list
  53from .wallaroo_ml_ops_api_client.api.pipeline import pipelines_create
  54from .wallaroo_ml_ops_api_client.api.workspace import workspaces_list
  55from .wallaroo_ml_ops_api_client.client import AuthenticatedClient
  56from .wallaroo_ml_ops_api_client.models import (
  57    assays_get_assay_results_json_body,
  58    models_list_json_body,
  59    pipelines_create_json_body,
  60    pipelines_create_json_body_definition,
  61    workspaces_list_json_body,
  62)
  63from .wallaroo_ml_ops_api_client.models.assays_create_json_body import (
  64    AssaysCreateJsonBody,
  65)
  66from .wallaroo_ml_ops_api_client.models.assays_create_response_200 import (
  67    AssaysCreateResponse200,
  68)
  69from .wallaroo_ml_ops_api_client.models.assays_get_assay_results_response_200_item import (
  70    AssaysGetAssayResultsResponse200Item,
  71)
  72from .wallaroo_ml_ops_api_client.models.assays_list_json_body import AssaysListJsonBody
  73from .wallaroo_ml_ops_api_client.models.models_list_response_200 import (
  74    ModelsListResponse200,
  75)
  76from .wallaroo_ml_ops_api_client.models.pipelines_create_response_200 import (
  77    PipelinesCreateResponse200,
  78)
  79from .wallaroo_ml_ops_api_client.models.workspaces_list_response_200 import (
  80    WorkspacesListResponse200,
  81)
  82from .wallaroo_ml_ops_api_client.types import UNSET
  83from .workspace import Workspace, Workspaces
  84
# Distinct NewType over `datetime`, used to annotate the timestamp filter
# arguments of the search methods below.
Datetime = NewType("Datetime", datetime)

# Names of environment variables consulted by `Client.get_urls`:
# WALLAROO_SDK_AUTH_TYPE overrides an unset auth type; WALLAROO_AUTH_URL (checked
# first) or WALLAROO_SDK_AUTH_ENDPOINT supply a blank auth endpoint; WALLAROO_URL
# overrides the API endpoint.
WALLAROO_SDK_AUTH_TYPE = "WALLAROO_SDK_AUTH_TYPE"
WALLAROO_SDK_AUTH_ENDPOINT = "WALLAROO_SDK_AUTH_ENDPOINT"
WALLAROO_URL = "WALLAROO_URL"
WALLAROO_AUTH_URL = "WALLAROO_AUTH_URL"
  91
  92
  93class Client(object):
  94    """Client handle to a Wallaroo platform instance.
  95
  96    Objects of this class serve as the entrypoint to Wallaroo platform
  97    functionality.
  98    """
  99
 100    @staticmethod
 101    def get_urls(
 102        auth_type: Optional[str], api_endpoint: str, auth_endpoint: str
 103    ) -> Tuple[Optional[str], str, str]:
 104        """Method to calculate the auth values specified as defaults,
 105        as params or in ENV vars.
 106        Made static to be testable without reaching out to SSO, etc."""
 107
 108        if auth_type is None:
 109            auth_type = os.environ.get(WALLAROO_SDK_AUTH_TYPE, None)
 110
 111        # ideally we'd set auth_endpoint to None default value but that would
 112        # make the auth_endpoint type to be Optiona[str] which messes up
 113        # a lot| of type hinting and I wanted to make minimal changes without a
 114        # lot of 'warnings'.
 115        if len(auth_endpoint.strip()) == 0:
 116            auth_endpoint = (
 117                os.environ.get(WALLAROO_AUTH_URL)
 118                or os.environ.get(WALLAROO_SDK_AUTH_ENDPOINT)
 119                or "http://api-lb:8080"
 120            )
 121
 122        api_endpoint = os.environ.get(WALLAROO_URL, api_endpoint)
 123
 124        return auth_type, api_endpoint, auth_endpoint
 125
 126    def __init__(
 127        self,
 128        api_endpoint: str = "http://api-lb:8080",
 129        auth_endpoint: str = "",
 130        request_timeout: int = 45,
 131        auth_type: Optional[str] = None,
 132        gql_client: Optional[gql.Client] = None,
 133        pg_connection_string: str = "dbname=postgres user=postgres password=password host=postgres port=5432",
 134        interactive: Optional[bool] = None,
 135        time_format: str = "%Y-%d-%b %H:%M:%S",
 136    ):
 137        """Create a Client handle.
 138
 139        :param str api_endpoint: Host/port of the platform API endpoint
 140        :param str auth_endpoint: Host/port of the platform Keycloak instance
 141        :param int timeout: Max timeout of web requests, in seconds
 142        :param str auth_type: Authentication type to use. Can be one of: "none",
 143            "sso", "user_password".
 144        :param str pg_connection_string: Postgres connection string
 145        :param bool interactive: If provided and True, some calls will print additional human information, or won't when False. If not provided, interactive defaults to True if running inside Jupyter and False otherwise.
 146        :param str time_format: Preferred `strftime` format string for displaying timestamps in a human context.
 147        """
 148
 149        auth_type, api_endpoint, auth_endpoint = Client.get_urls(
 150            auth_type, api_endpoint, auth_endpoint
 151        )
 152
 153        self.auth = auth.create(auth_endpoint, auth_type)
 154
 155        if gql_client:
 156            self._gql_client = gql_client
 157        else:
 158            gql_transport = RequestsHTTPTransport(
 159                url=posixpath.join(api_endpoint, "v1/graphql"),
 160                auth=self.auth,
 161                timeout=request_timeout,
 162            )
 163            self._gql_client = gql.Client(
 164                transport=gql_transport, fetch_schema_from_transport=True
 165            )
 166
 167        self.api_endpoint = api_endpoint
 168
 169        self.auth_endpoint = auth_endpoint
 170
 171        self.timeout = request_timeout
 172
 173        self._setup_mlops_client()
 174
 175        self.pg_connection_string = pg_connection_string
 176
 177        self._current_workspace: Optional[Workspace] = None
 178
 179        # TODO: debate the names of these things
 180        self._default_ws_name: Optional[str] = None
 181
 182        user_email = self.auth.user_email()
 183        if user_email is not None:
 184            self._default_ws_name = user_email + "_ws"
 185
 186        if interactive is not None:
 187            self._interactive = interactive
 188        elif (
 189            "JUPYTER_SVC_SERVICE_HOST" in os.environ or "JUPYTERHUB_HOST" in os.environ
 190        ):
 191            self._interactive = True
 192        else:
 193            self._interactive = False
 194
 195        self._time_format = time_format
 196
 197    def _get_rest_api(self, path: str, params: dict):
 198        headers = {
 199            "authorization": self.auth._bearer_token_str(),
 200            "user-agent": _user_agent,
 201        }
 202
 203        url = f"{self.api_endpoint}/{path}"
 204
 205        return requests.get(url, headers=headers, params=params)
 206
 207    def _post_rest_api(self, path: str, body: dict):
 208        headers = {
 209            "authorization": self.auth._bearer_token_str(),
 210            "user-agent": _user_agent,
 211        }
 212
 213        url = f"{self.api_endpoint}/{path}"
 214        return requests.post(url, headers=headers, json=body)
 215
 216    def list_tags(self) -> Tags:
 217        """List all tags on the platform.
 218
 219        :return: A list of all tags on the platform.
 220        :rtype: List[Tag]
 221        """
 222        res = self._gql_client.execute(
 223            gql.gql(
 224                """
 225            query ListTags {
 226              tag(order_by: {id: desc}) {
 227                id
 228                tag
 229                model_tags {
 230                  model {
 231                    id
 232                    model_id
 233                    models_pk_id
 234                    model_version
 235                    
 236                  }
 237                }
 238                pipeline_tags {
 239                  pipeline {
 240                    id
 241                    pipeline_id
 242                    pipeline_versions {
 243                        id
 244                        version
 245                    }
 246                  }
 247                }
 248              }
 249            }
 250
 251
 252            """
 253            )
 254        )
 255        return Tags([Tag(client=self, data={"id": p["id"]}) for p in res["tag"]])
 256
 257    def list_models(self) -> ModelsList:
 258        """List all models on the platform.
 259
 260        :return: A list of all models on the platform.
 261        :rtype: List[Model]
 262        """
 263        id = self.get_current_workspace().id()
 264        res = models_list.sync(
 265            client=self.mlops(),
 266            json_body=models_list_json_body.ModelsListJsonBody(id),
 267        )
 268
 269        if res is None:
 270            raise Exception("Failed to list models")
 271
 272        if not isinstance(res, ModelsListResponse200):
 273            raise Exception(res.msg)
 274
 275        return ModelsList([Models(client=self, data=v.to_dict()) for v in res.models])
 276
 277    def list_deployments(self) -> List[Deployment]:
 278        """List all deployments (active or not) on the platform.
 279
 280        :return: A list of all deployments on the platform.
 281        :rtype: List[Deployment]
 282        """
 283        res = self._gql_client.execute(
 284            gql.gql(
 285                """
 286            query ListDeployments {
 287              deployment {
 288                id
 289                deploy_id
 290                deployed
 291                deployment_model_configs {
 292                  model_config {
 293                    id
 294                  }
 295                }
 296              }
 297            }
 298            """
 299            )
 300        )
 301        return [Deployment(client=self, data=d) for d in res["deployment"]]
 302
    # NOTE: the triple-quoted string below is a bare expression statement with
    # no runtime effect; it only keeps a disabled `search_my_pipelines`
    # implementation around for reference.
    """
        # Removed until we figure out what pipeline ownership means
        #
        # def search_my_pipelines(
        #     self,
        #     search_term: Optional[str] = None,
        #     deployed: Optional[bool] = None,
        #     created_start: Optional["Datetime"] = None,
        #     created_end: Optional["Datetime"] = None,
        #     updated_start: Optional["Datetime"] = None,
        #     updated_end: Optional["Datetime"] = None,
        # ) -> List[Pipeline]:
        #     user_id = self.auth.user_id()
        #     return Pipelines(
        #         self._search_pipelines(
        #             search_term,
        #             deployed,
        #             user_id,
        #             created_start,
        #             created_end,
        #             updated_start,
        #             updated_end,
        #         )
        #     )
    """
 328
 329    def search_pipelines(
 330        self,
 331        search_term: Optional[str] = None,
 332        deployed: Optional[bool] = None,
 333        created_start: Optional["Datetime"] = None,
 334        created_end: Optional["Datetime"] = None,
 335        updated_start: Optional["Datetime"] = None,
 336        updated_end: Optional["Datetime"] = None,
 337    ) -> PipelineVariants:
 338        """Search for pipelines. All parameters are optional, in which case the result is the same as
 339        `list_pipelines()`. All times are strings to be parsed by `datetime.isoformat`. Example:
 340
 341             myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo")
 342
 343        :param str search_term: Will be matched against tags and model names. Example: "footag123".
 344        :param bool deployed: Pipeline was deployed or not
 345        :param str created_start: Pipeline was created at or after this time
 346        :param str created_end: Pipeline was created at or before this time
 347        :param str updated_start: Pipeline was updated at or before this time
 348        :param str updated_end: Pipeline was updated at or before this time
 349
 350        :return: A list of all pipelines on the platform.
 351        :rtype: List[Pipeline]
 352        """
 353        return PipelineVariants(
 354            self._search_pipelines(
 355                search_term,
 356                deployed,
 357                None,
 358                created_start,
 359                created_end,
 360                updated_start,
 361                updated_end,
 362            )
 363        )
 364
 365    def _search_pipelines(
 366        self,
 367        search_term: Optional[str] = None,
 368        deployed: Optional[bool] = None,
 369        user_id: Optional[str] = None,
 370        created_start: Optional["Datetime"] = None,
 371        created_end: Optional["Datetime"] = None,
 372        updated_start: Optional["Datetime"] = None,
 373        updated_end: Optional["Datetime"] = None,
 374    ) -> List[PipelineVariant]:
 375        (query, params) = self._generate_search_pipeline_query(
 376            search_term=search_term,
 377            deployed=deployed,
 378            user_id=user_id,
 379            created_start=created_start,
 380            created_end=created_end,
 381            updated_start=updated_start,
 382            updated_end=updated_end,
 383        )
 384        q = gql.gql(query)
 385        data = self._gql_client.execute(q, variable_values=params)
 386        pipelines = []
 387        if data["search_pipelines"]:
 388            for p in data["search_pipelines"]:
 389                pipelines.append(PipelineVariant(self, p))
 390        return pipelines
 391
 392    def _generate_search_pipeline_query(
 393        self,
 394        search_term: Optional[str] = None,
 395        deployed: Optional[bool] = None,
 396        user_id: Optional[str] = None,
 397        created_start: Optional["Datetime"] = None,
 398        created_end: Optional["Datetime"] = None,
 399        updated_start: Optional["Datetime"] = None,
 400        updated_end: Optional["Datetime"] = None,
 401    ):
 402        filters = []
 403        query_params = []
 404        params: Dict[str, Any] = {}
 405        search = ""
 406        if search_term:
 407            search = search_term
 408        params["search_term"] = search
 409        query_params.append("$search_term: String!")
 410
 411        if user_id:
 412            filters.append("owner_id: {_eq: $user_id}")
 413            params["user_id"] = user_id
 414            query_params.append("$user_id: String!")
 415
 416        if deployed is not None:
 417            filters.append("pipeline: {deployment: {deployed: {_eq: $deployed}}}")
 418            params["deployed"] = deployed
 419            query_params.append("$deployed: Boolean")
 420
 421        self._generate_time_range_graphql(
 422            "created_at",
 423            start=created_start,
 424            end=created_end,
 425            filters=filters,
 426            query_params=query_params,
 427            params=params,
 428        )
 429        self._generate_time_range_graphql(
 430            "updated_at",
 431            start=updated_start,
 432            end=updated_end,
 433            filters=filters,
 434            query_params=query_params,
 435            params=params,
 436        )
 437
 438        where_clause_str = self._generate_where_clause_str(filters)
 439        query_param_str = self._generate_query_param_str(query_params)
 440        query = f"""
 441            query GetPipelines({query_param_str}) {{
 442                search_pipelines(args: {{search: $search_term}}, distinct_on: id{where_clause_str}, order_by: {{id: desc}}) {{
 443                    id
 444                    created_at
 445                    pipeline_pk_id
 446                    updated_at
 447                    version
 448                    pipeline {{
 449                        id
 450                        pipeline_id
 451                        pipeline_tags {{
 452                            id
 453                            tag {{
 454                                id
 455                                tag
 456                            }}
 457                        }}
 458                    }}
 459                }}
 460            }}
 461        """
 462        return (query, params)
 463
 464    def _generate_where_clause_str(self, filters: List[str]) -> str:
 465        where_clause_str = ""
 466        filters_len = len(filters)
 467        if filters_len > 0:
 468            if filters_len > 1:
 469                where_clause_str = f""", where: {{_and: {{ {", ".join(filters)} }}}}"""
 470            else:
 471                where_clause_str = f", where: {{{filters[0]}}}"
 472        return where_clause_str
 473
 474    def _generate_query_param_str(self, query_params: List[str]):
 475        return ", ".join(query_params)
 476
 477    def _generate_time_range_graphql(
 478        self,
 479        field: str,
 480        start: Optional["Datetime"],
 481        end: Optional["Datetime"],
 482        filters: List[str],
 483        query_params: List[str],
 484        params: Dict[str, Any],
 485    ):
 486        if start and not end:
 487            filters.append(f"{field}: {{_gte: $start_{field}}}")
 488            params[f"start_{field}"] = start
 489            query_params.append(f"$start_{field}: timestamptz!")
 490        elif end and not start:
 491            filters.append(f"{field}: {{_lte: $end_{field}}}")
 492            params[f"end_{field}"] = end
 493            query_params.append(f"$end_{field}: timestamptz!")
 494        elif start and end:
 495            filters.append(f"{field}: {{_gte: $start_{field}, _lte: $end_{field}}}")
 496            params[f"start_{field}"] = start
 497            params[f"end_{field}"] = start
 498            query_params.append(f"$start_{field}: timestamptz!")
 499            query_params.append(f"$end_{field}: timestamptz!")
 500
 501    def search_my_models(
 502        self,
 503        search_term: Optional[str] = None,
 504        uploaded_time_start: Optional["Datetime"] = None,
 505        uploaded_time_end: Optional["Datetime"] = None,
 506    ) -> ModelVersions:
 507        """Search models owned by you
 508        params:
 509             search_term: Searches the following metadata: names, shas, versions, file names, and tags
 510             uploaded_time_start: Inclusive time of upload
 511             uploaded_time_end: Inclusive time of upload
 512        """
 513        user_id = self.auth.user_id()
 514        return ModelVersions(
 515            self._search_models(
 516                search_term=search_term,
 517                user_id=user_id,
 518                start=uploaded_time_start,
 519                end=uploaded_time_end,
 520            )
 521        )
 522
 523    def search_models(
 524        self,
 525        search_term: Optional[str] = None,
 526        uploaded_time_start: Optional["Datetime"] = None,
 527        uploaded_time_end: Optional["Datetime"] = None,
 528    ) -> ModelVersions:
 529        """Search all models you have access to.
 530        params:
 531             search_term: Searches the following metadata: names, shas, versions, file names, and tags
 532             uploaded_time_start: Inclusive time of upload
 533             uploaded_time_end: Inclusive time of upload
 534        """
 535        return ModelVersions(
 536            self._search_models(
 537                search_term=search_term,
 538                start=uploaded_time_start,
 539                end=uploaded_time_end,
 540            )
 541        )
 542
 543    def _search_models(
 544        self, search_term=None, user_id=None, start=None, end=None
 545    ) -> List[Model]:
 546        (query, params) = self._generate_model_query(
 547            search_term=search_term,
 548            user_id=user_id,
 549            start=start,
 550            end=end,
 551        )
 552
 553        q = gql.gql(query)
 554
 555        data = self._gql_client.execute(q, variable_values=params)
 556        models = []
 557        if data["search_models"]:
 558            for m in data["search_models"]:
 559                models.append(Model(self, m))
 560        return models
 561
 562    def _generate_model_query(
 563        self,
 564        search_term=None,
 565        user_id=None,
 566        start=None,
 567        end=None,
 568    ):
 569        filters = []
 570        query_params = []
 571        params = {}
 572        search = ""
 573        if search_term:
 574            search = search_term
 575        params["search_term"] = search
 576        query_params.append("$search_term: String!")
 577        if user_id:
 578            filters.append("owner_id: {_eq: $user_id}")
 579            params["user_id"] = user_id
 580            query_params.append("$user_id: String!")
 581
 582        self._generate_time_range_graphql(
 583            "created_at",
 584            start=start,
 585            end=end,
 586            filters=filters,
 587            params=params,
 588            query_params=query_params,
 589        )
 590
 591        where_clause_str = self._generate_where_clause_str(filters)
 592        query_param_str = self._generate_query_param_str(query_params)
 593        query = f"""
 594            query GetModels({query_param_str}) {{
 595              search_models(args: {{search: $search_term}}{where_clause_str}, order_by: {{created_at: desc}}) {{
 596                id
 597              }}
 598            }}
 599        """
 600        return (query, params)
 601
 602    def get_user_by_email(self, email: str) -> Optional[User]:
 603        """Find a user by email"""
 604        assert email is not None
 605        escaped_email = quote_plus(email)
 606        url = (
 607            f"{self.auth_endpoint}/auth/admin/realms/master/users?email={escaped_email}"
 608        )
 609        headers = {
 610            "Content-Type": "application/json",
 611            "Authorization": self.auth._bearer_token_str(),
 612            "User-Agent": _user_agent,
 613        }
 614        resp = requests.get(url, headers=headers, data={})
 615        jresp = resp.json()
 616        return None if jresp == [] else User(client=self, data=jresp[0])
 617
 618    def deactivate_user(self, email: str) -> None:
 619        """Deactivates an existing user of the platform
 620
 621        Deactivated users cannot log into the platform.
 622        Deactivated users do not count towards the number of allotted user seats from the license.
 623
 624        The Models and Pipelines owned by the deactivated user are not removed from the platform.
 625
 626        :param str email: The email address of the user to deactivate.
 627
 628        :return: None
 629        :rtype: None
 630        """
 631
 632        if self.auth.user_email() == email:
 633            raise Exception("A user may not deactive themselves.")
 634
 635        user = self.get_user_by_email(email)
 636
 637        if user is None:
 638            raise EntityNotFoundError("User", {"email": email})
 639
 640        if user.username() == "admin":
 641            raise Exception("Admin user may not be deactivated.")
 642
 643        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"
 644        headers = {
 645            "Content-Type": "application/json",
 646            "Authorization": self.auth._bearer_token_str(),
 647            "User-Agent": _user_agent,
 648        }
 649
 650        # Get the current full user representation to return in the mutation due to keycloak bug
 651        get_user_response = requests.get(url, headers=headers, data={})
 652
 653        cur_user_rep = get_user_response.json()
 654        cur_user_rep["enabled"] = False
 655
 656        resp = requests.put(url, headers=headers, json=cur_user_rep)
 657
 658        if resp.status_code != 204:
 659            raise EntityNotFoundError("User", {"email": email})
 660        return None
 661
 662    def activate_user(self, email: str) -> None:
 663        """Activates an existing user of the platform that had been previously deactivated.
 664
 665        Activated users can log into the platform.
 666
 667        :param str email: The email address of the user to activate.
 668
 669        :return: None
 670        :rtype: None
 671        """
 672        user = self.get_user_by_email(email)
 673
 674        if user is None:
 675            raise EntityNotFoundError("User", {"email": email})
 676
 677        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"
 678
 679        headers = {
 680            "Content-Type": "application/json",
 681            "Authorization": self.auth._bearer_token_str(),
 682            "User-Agent": _user_agent,
 683        }
 684
 685        # Get the current full user representation to return in the mutation due to keycloak bug
 686        get_user_response = requests.get(url, headers=headers, data={})
 687
 688        cur_user_rep = get_user_response.json()
 689        cur_user_rep["enabled"] = True
 690
 691        resp = requests.put(url, headers=headers, json=cur_user_rep)
 692
 693        if resp.status_code != 204:
 694            raise EntityNotFoundError("User", {"email": email})
 695        return None
 696
 697    def _get_user_by_id(self, id: str) -> Optional[User]:
 698        assert id is not None
 699        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{id}"
 700        headers = {
 701            "Content-Type": "application/json",
 702            "Authorization": self.auth._bearer_token_str(),
 703            "User-Agent": _user_agent,
 704        }
 705        resp = requests.get(url, headers=headers, data={})
 706        jresp = resp.json()
 707        return None if jresp == [] else User(client=self, data=jresp)
 708
 709    def list_users(self) -> List[User]:
 710        """List of all Users on the platform
 711
 712        :return: A list of all Users on the platform.
 713        :rtype: List[User]
 714        """
 715        resp = User.list_users(auth=self.auth)
 716        return [User(client=self, data=u) for u in resp]
 717
 718    def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> Model:
 719        """Upload a model defined by a file as a new model variant.
 720
 721        :param str model_name: The name of the model of which this is a variant.
 722            Names must be ASCII alpha-numeric characters or dash (-) only.
 723        :param Union[str, pathlib.Path] path: Path of the model file to upload.
 724        :return: The created Model.
 725        :rtype: Model
 726        """
 727
 728        _Visibility.PRIVATE
 729        if isinstance(path, str):
 730            path = pathlib.Path(path)
 731        with path.open("rb") as f:
 732            return self._upload_model_stream(name, {"filename": path.name}, f)
 733
 734    def _upload_model_stream(self, name: str, data: Dict[str, Any], file: Any):
 735        require_dns_compliance(name)
 736        endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload_stream")
 737        data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()}
 738        headers = {"User-Agent": _user_agent}
 739
 740        res = requests.post(
 741            endpoint, auth=self.auth, params=data, data=file, headers=headers
 742        )
 743        if res.status_code != 200:
 744            raise ModelUploadError(res.text)
 745
 746        res_dict = json.loads(res.text)
 747        return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0])
 748
 749    def register_model_image(self, name: str, image: str) -> Model:
 750        """Registers an MLFlow model as a new model.
 751
 752        :param str model_name: The name of the model of which this is a variant.
 753            Names must be ASCII alpha-numeric characters or dash (-) only.
 754        :param str image: Image name of the MLFlow model to register.
 755        :return: The created Model.
 756        :rtype: Model
 757        """
 758        data = {
 759            "image_path": image,
 760        }
 761        return self._upload_model(name, data)
 762
 763    def _upload_model(
 764        self, name: str, data: Dict[str, Any], files: Dict[str, Tuple[str, bytes]] = {}
 765    ):
 766        require_dns_compliance(name)
 767        endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload")
 768        data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()}
 769        if len(files) == 0:
 770            files = {"dummy": ("none", b"")}
 771
 772        headers = {"User-Agent": _user_agent}
 773
 774        res = requests.post(
 775            endpoint, files=files, auth=self.auth, data=data, headers=headers
 776        )
 777        if res.status_code != 200:
 778            raise ModelUploadError(res.text)
 779
 780        res_dict = json.loads(res.text)
 781        return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0])
 782
 783    def model_by_name(self, model_class: str, model_name: str) -> Model:
 784        """Fetch a Model by name.
 785
 786        :param str model_class: Name of the model class.
 787        :param str model_name: Name of the variant within the specified model class.
 788        :return: The Model with the corresponding model and variant name.
 789        :rtype: Model
 790        """
 791        res = self._gql_client.execute(
 792            gql.gql(
 793                """
 794                query ModelByName($model_id: String!, $model_version: String!) {
 795                  model(where: {_and: [{model_id: {_eq: $model_id}}, {model_version: {_eq: $model_version}}]}) {
 796                    id
 797                    model_id
 798                    model_version
 799                  }
 800                }
 801                """
 802            ),
 803            variable_values={
 804                "model_id": model_class,
 805                "model_version": model_name,
 806            },
 807        )
 808        if not res["model"]:
 809            raise EntityNotFoundError(
 810                "Model", {"model_class": model_class, "model_name": model_name}
 811            )
 812        return Model(client=self, data={"id": res["model"][0]["id"]})
 813
 814    def deployment_by_name(self, deployment_name: str) -> Deployment:
 815        """Fetch a Deployment by name.
 816
 817        :param str deployment_name: Name of the deployment.
 818        :return: The Deployment with the corresponding name.
 819        :rtype: Deployment
 820        """
 821        res = self._gql_client.execute(
 822            gql.gql(
 823                """
 824                query DeploymentByName($deployment_name: String!) {
 825                  deployment(where: {deploy_id: {_eq: $deployment_name}}) {
 826                    id
 827                  }
 828                }
 829                """
 830            ),
 831            variable_values={
 832                "deployment_name": deployment_name,
 833            },
 834        )
 835        if not res["deployment"]:
 836            raise EntityNotFoundError(
 837                "Deployment", {"deployment_name": deployment_name}
 838            )
 839        return Deployment(client=self, data={"id": res["deployment"][0]["id"]})
 840
 841    def pipelines_by_name(self, pipeline_name: str) -> List[Pipeline]:
 842        """Fetch Pipelines by name.
 843
 844        :param str pipeline_name: Name of the pipeline.
 845        :return: The Pipeline with the corresponding name.
 846        :rtype: Pipeline
 847        """
 848        res = self._gql_client.execute(
 849            gql.gql(
 850                """
 851                query PipelineByName($pipeline_name: String!) {
 852                  pipeline(where: {pipeline_id: {_eq: $pipeline_name}}, order_by: {created_at: desc}) {
 853                    id
 854                  }
 855                }
 856                """
 857            ),
 858            variable_values={
 859                "pipeline_name": pipeline_name,
 860            },
 861        )
 862        assert "pipeline" in res
 863        length = len(res["pipeline"])
 864        if length < 1:
 865            raise EntityNotFoundError("Pipeline", {"pipeline_name": pipeline_name})
 866        return [Pipeline(client=self, data={"id": p["id"]}) for p in res["pipeline"]]
 867
 868    def list_pipelines(self) -> List[Pipeline]:
 869        """List all pipelines on the platform.
 870
 871        :return: A list of all pipelines on the platform.
 872        :rtype: List[Pipeline]
 873        """
 874        res = self._gql_client.execute(
 875            gql.gql(
 876                """
 877            query ListPipelines {
 878              pipeline(order_by: {id: desc}) {
 879                id
 880                pipeline_tags {
 881                  tag {
 882                    id
 883                    tag
 884                  }
 885                }
 886              }
 887            }
 888            """
 889            )
 890        )
 891        return Pipelines([Pipeline(client=self, data=d) for d in res["pipeline"]])
 892
 893    def build_pipeline(self, pipeline_name: str) -> "Pipeline":
 894        """Starts building a pipeline with the given `pipeline_name`,
 895        returning a :py:PipelineConfigBuilder:
 896
 897        When completed, the pipeline can be uploaded with `.upload()`
 898
 899        :param pipeline_name string: Name of the pipeline, must be composed of ASCII
 900          alpha-numeric characters plus dash (-).
 901        """
 902
 903        require_dns_compliance(pipeline_name)
 904
 905        _Visibility.PRIVATE
 906
 907        # TODO: Needs to handle visibility?
 908        data = pipelines_create.sync(
 909            client=self.mlops(),
 910            json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
 911                pipeline_name,
 912                self.get_current_workspace().id(),
 913                pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
 914                    {}
 915                ),
 916            ),
 917        )
 918
 919        if data is None:
 920            raise Exception("Failed to create pipeline")
 921
 922        if not isinstance(data, PipelinesCreateResponse200):
 923            raise Exception(data.msg)
 924
 925        return Pipeline(client=self, data={"id": data.pipeline_pk_id})
 926
 927    def _upload_pipeline_variant(
 928        self,
 929        name: str,
 930        config: PipelineConfig,
 931    ) -> Pipeline:
 932        """Creates a new PipelineVariant with the specified configuration.
 933
 934        :param str name: Name of the Pipeline. Must be unique across all Pipelines.
 935        :param config PipelineConfig: Pipeline configuration.
 936        """
 937        definition = config.to_json()
 938        _Visibility.PRIVATE
 939
 940        data = pipelines_create.sync(
 941            client=self.mlops(),
 942            json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
 943                name,
 944                self.get_current_workspace().id(),
 945                pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
 946                    definition
 947                ),
 948            ),
 949        )
 950
 951        if data is None:
 952            # TODO: Generalize
 953            raise Exception("Failed to create pipeline")
 954
 955        if not isinstance(data, PipelinesCreateResponse200):
 956            raise Exception(data.msg)
 957
 958        for alert_config in config.alert_configurations:
 959            self._gql_client.execute(
 960                gql.gql(
 961                    """
 962                mutation CreateAlertConfiguration(
 963                    $pipeline_version_id: bigint,
 964                    $name: String,
 965                    $expression: String,
 966                    $notifications: jsonb
 967                ) {
 968                    insert_alert_configuration(objects: {
 969                        name: $name,
 970                        expression: $expression,
 971                        notifications: $notifications,
 972                        pipeline_version_pk_id: $pipeline_version_id
 973                    }) {
 974                        returning { id }
 975                    }
 976                }
 977                """
 978                ),
 979                variable_values={
 980                    **alert_config.to_json(),
 981                    "pipeline_version_id": data.pipeline_pk_id,
 982                },
 983            )
 984
 985        pipeline_data = data.to_dict()
 986        pipeline_data["id"] = data.pipeline_pk_id
 987
 988        return Pipeline(
 989            client=self,
 990            data=pipeline_data,
 991        )
 992
 993    def create_value_split_experiment(
 994        self,
 995        name: str,
 996        meta_key: str,
 997        default_model: ModelConfig,
 998        challenger_models: List[Tuple[Any, ModelConfig]],
 999    ) -> Pipeline:
1000        """Creates a new PipelineVariant of a "value-split experiment" type.
1001        :param str name: Name of the Pipeline
1002        :param meta_key str: Inference input key on which to redirect inputs to
1003            experiment models.
1004        :param default_model ModelConfig: Model to send inferences by default.
1005        :param challenger_models List[Tuple[Any, ModelConfig]]: A list of
1006            meta_key values -> Models to send inferences. If the inference data
1007            referred to by meta_key is equal to one of the keys in this tuple,
1008            that inference is redirected to the corresponding model instead of
1009            the default model.
1010        """
1011        args = [meta_key, default_model.model().name()]
1012        for v, m in challenger_models:
1013            args.append(v)
1014            args.append(m.model().name())
1015        step = {
1016            "id": "metavalue_split",
1017            "operation": "map",
1018            "args": args,
1019        }
1020        definition = {"id": name, "steps": [step]}
1021        # TODO: This seems like a one-to-one replace, find appropriate test.
1022        data = self._gql_client.execute(
1023            gql.gql(
1024                """
1025            mutation CreatePipeline(
1026                $pipeline_id: String,
1027                $version: String,
1028                $definition: jsonb,
1029                $workspace_id: bigint
1030            ) {
1031                insert_pipeline(
1032                    objects: {
1033                    pipeline_versions: {
1034                        data: { definition: $definition }
1035                    }
1036                    pipeline_id: $pipeline_id
1037                    }
1038                ) {
1039                    returning {
1040                        id
1041                    }
1042                }
1043            }
1044            """
1045            ),
1046            variable_values={
1047                "pipeline_id": name,
1048                "definition": definition,
1049                "workspace_id": self.get_current_workspace().id(),
1050            },
1051        )
1052        return Pipeline(
1053            client=self,
1054            data=data["insert_pipeline"]["returning"][0],
1055        )
1056
1057    def get_logs(self, topic: str, limit: int = 100) -> Tuple[LogEntries, str]:
1058        base = self.api_endpoint + f"/v1/logs/topic/" + topic
1059        headers = {"User-Agent": _user_agent}
1060        partitions = requests.get(base, auth=self.auth, headers=headers).json()[
1061            "partitions"
1062        ]
1063
1064        iterator = {
1065            k: max(0, span["end"] - math.floor(limit / len(partitions)))
1066            for k, span in partitions.items()
1067        }
1068
1069        response = requests.post(
1070            base + "/records",
1071            params={"limit": limit},
1072            json=iterator,
1073            auth=self.auth,
1074            headers=headers,
1075        ).json()
1076
1077        return (
1078            LogEntries([LogEntry(json.loads(l)) for l in response["records"]]),
1079            response.get("status", "None"),
1080        )
1081
1082    def security_logs(self, limit: int) -> List[dict]:
1083        """This function is not available in this release"""
1084        raise RuntimeError("security_log() is not available in this release.")
1085
1086    def get_raw_logs(
1087        self,
1088        topic: str,
1089        start: Optional[datetime] = None,
1090        end: Optional[datetime] = None,
1091        limit: int = 100_000,
1092        parse: bool = False,
1093        verbose: bool = False,
1094    ) -> List[Dict[str, Any]]:
1095        """Gets logs from Plateau for a particular time window without attempting
1096        to convert them to Inference LogEntries. Logs can be returned as strings
1097        or the json parsed into lists and dicts.
1098        :param topic str: The name of the topic to query
1099        :param start Optional[datetime]: The start of the time window
1100        :param end Optional[datetime]: The end of the time window
1101        :param limit int: The number of records to retrieve. Note retrieving many
1102            records may be a performance bottleneck.
1103        :param parse bool: Wether to attempt to parse the string as a json object.
1104        :param verbose bool: Prints out info to help diagnose issues.
1105        """
1106
1107        assert limit <= 1_000_000
1108
1109        base = self.api_endpoint + f"/v1/logs/topic/" + topic
1110        headers = {"User-Agent": _user_agent}
1111        resp = requests.get(base, auth=self.auth, headers=headers)
1112        if resp.status_code != 200:
1113            raise EntityNotFoundError(
1114                f"Could not get partitions {resp.text}", {"url": base}
1115            )
1116        data = resp.json()
1117        partitions = data["partitions"]
1118
1119        if verbose:
1120            print(f"Got partitions {partitions}")
1121
1122        params: Dict[str, Any] = {"limit": limit}
1123        if start is not None:
1124            start_str = start.astimezone(tz=timezone.utc).isoformat()
1125            params["time.start"] = start_str
1126        if end is not None:
1127            end_str = end.astimezone(tz=timezone.utc).isoformat()
1128            params["time.end"] = end_str
1129
1130        next: Union[Any, None] = {
1131            k: max(0, span["end"] - math.floor(limit / len(partitions)))
1132            for k, span in partitions.items()
1133        }
1134
1135        if verbose:
1136            print("Using params: ", params)
1137
1138        records = []
1139        while next is not None:
1140            response = requests.post(
1141                base + "/records",
1142                params=params,
1143                json=next,
1144                auth=self.auth,
1145                headers=headers,
1146            )
1147            if response.status_code != 200:
1148                raise EntityNotFoundError(
1149                    f"Could not get records {resp.text}",
1150                    {"url": base, "params": str(params), "iterator": str(next)},
1151                )
1152
1153            if verbose:
1154                print("response: ", response)
1155
1156            result = response.json()
1157            result_records = result["records"]
1158            if len(result_records) > 0:
1159                records.extend(result_records)
1160                next = result["next"]
1161            else:
1162                next = None
1163        if parse:
1164            return [json.loads(r) for r in records]
1165        return records
1166
1167    def get_raw_pipeline_inference_logs(
1168        self,
1169        topic: str,
1170        start: datetime,
1171        end: datetime,
1172        model_name: Optional[str] = None,
1173        limit: int = 100_000,
1174        verbose: bool = False,
1175    ) -> List[Dict[str, Any]]:
1176        """Gets logs from Plateau for a particular time window and filters them for
1177        the model specified.
1178        :param pipeline_name str: The name/pipeline_id of the pipeline to query
1179        :param topic str: The name of the topic to query
1180        :param start Optional[datetime]: The start of the time window
1181        :param end Optional[datetime]: The end of the time window
1182        :param model_id: The name of the specific model to filter if any
1183        :param limit int: The number of records to retrieve. Note retrieving many
1184            records may be a performance bottleneck.
1185        :param verbose bool: Prints out info to help diagnose issues.
1186        """
1187        logs = self.get_raw_logs(
1188            topic,
1189            start=start,
1190            end=end,
1191            limit=limit,
1192            parse=True,
1193            verbose=verbose,
1194        )
1195
1196        if verbose:
1197            print(f"Got {len(logs)} initial logs")
1198
1199        if model_name:
1200            logs = [l for l in logs if l["model_name"] == model_name]
1201
1202        # inference results are a unix timestamp in millis - filter by that
1203        start_ts = int(start.timestamp() * 1000)
1204        end_ts = int(end.timestamp() * 1000)
1205        logs = [l for l in logs if start_ts <= l["time"] < end_ts]
1206
1207        return logs
1208
1209    def get_pipeline_inference_dataframe(
1210        self,
1211        topic: str,
1212        start: datetime,
1213        end: datetime,
1214        model_name: Optional[str] = None,
1215        limit: int = 100_000,
1216        verbose=False,
1217    ) -> pd.DataFrame:
1218        logs = self.get_raw_pipeline_inference_logs(
1219            topic, start, end, model_name, limit, verbose
1220        )
1221        return inference_logs_to_dataframe(logs)
1222
1223    def get_assay_results(
1224        self,
1225        assay_id: int,
1226        start: datetime,
1227        end: datetime,
1228    ) -> AssayAnalysisList:
1229        """Gets the assay results for a particular time window, parses them, and returns an
1230        AssayAnalysisList of AssayAnalysis.
1231        :param assay_id int: The id of the assay we are looking for.
1232        :param start datetime: The start of the time window
1233        :param end datetime: The end of the time window
1234        """
1235        res = assays_get_assay_results.sync(
1236            client=self.mlops(),
1237            json_body=assays_get_assay_results_json_body.AssaysGetAssayResultsJsonBody(
1238                assay_id, start, end
1239            ),
1240        )
1241
1242        if res is None:
1243            raise Exception("Failed to list models")
1244
1245        if not isinstance(res, List):
1246            raise Exception(res.msg)
1247
1248        if len(res) != 0 and not isinstance(
1249            res[0], AssaysGetAssayResultsResponse200Item
1250        ):
1251            raise Exception("invalid response")
1252
1253        return AssayAnalysisList([AssayAnalysis(v.to_dict()) for v in res])
1254
1255    def build_assay(
1256        self,
1257        assay_name: str,
1258        pipeline: Pipeline,
1259        model_name: str,
1260        baseline_start: datetime,
1261        baseline_end: datetime,
1262    ) -> AssayBuilder:
1263        """Creates an AssayBuilder that can be used to configure and create
1264        Assays.
1265        :param assay_name str: Human friendly name for the assay
1266        :param pipeline Pipeline: The pipeline this assay will work on
1267        :param model_name str: The model that this assay will monitor
1268        :param baseline_start datetime: The start time for the inferences to
1269            use as the baseline
1270        :param baseline_end datetime: The end time of the baseline window.
1271        the baseline. Windows start immediately after the baseline window and
1272        are run at regular intervals continously until the assay is deactivated
1273        or deleted.
1274        """
1275        assay_builder = AssayBuilder(
1276            self,
1277            assay_name,
1278            pipeline.id(),
1279            pipeline.name(),
1280            model_name,
1281            baseline_start,
1282            baseline_end,
1283        )
1284
1285        return assay_builder
1286
1287    def upload_assay(self, config: AssayConfig) -> int:
1288        """Creates an assay in the database.
1289        :param config AssayConfig: The configuration for the assay to create.
1290        :return assay_id: The identifier for the assay that was created.
1291        :rtype int
1292        """
1293        data = assays_create.sync(
1294            client=self.mlops(),
1295            json_body=AssaysCreateJsonBody.from_dict(
1296                {
1297                    **json.loads(config.to_json()),
1298                    "created_at": datetime.now(timezone.utc).isoformat(),
1299                }
1300            ),
1301        )
1302
1303        if data is None:
1304            raise Exception("Failed to create assay")
1305
1306        if not isinstance(data, AssaysCreateResponse200):
1307            raise Exception(data.msg)
1308
1309        return data.assay_id
1310
1311    def list_assays(self) -> List[Assay]:
1312        """List all assays on the platform.
1313
1314        :return: A list of all assays on the platform.
1315        :rtype: List[Assay]
1316        """
1317        res = assays_list.sync(client=self.mlops(), json_body=AssaysListJsonBody(UNSET))
1318
1319        if res is None:
1320            raise Exception("Failed to get assays")
1321
1322        if not isinstance(res, List):
1323            raise Exception(res.msg)
1324
1325        return Assays([Assay(client=self, data=v.to_dict()) for v in res])
1326
1327    def create_tag(self, tag_text: str) -> Tag:
1328        """Create a new tag with the given text."""
1329        assert tag_text is not None
1330        return Tag._create_tag(client=self, tag_text=tag_text)
1331
1332    def create_workspace(self, workspace_name: str) -> Workspace:
1333        """Create a new workspace with the current user as its first owner.
1334
1335        :param str workspace_name: Name of the workspace, must be composed of ASCII
1336           alpha-numeric characters plus dash (-)"""
1337        assert workspace_name is not None
1338        require_dns_compliance(workspace_name)
1339        return Workspace._create_workspace(client=self, name=workspace_name)
1340
1341    def list_workspaces(self) -> List[Workspace]:
1342        """List all workspaces on the platform which this user has permission see.
1343
1344        :return: A list of all workspaces on the platform.
1345        :rtype: List[Workspace]
1346        """
1347        res = workspaces_list.sync(
1348            client=self.mlops(),
1349            json_body=workspaces_list_json_body.WorkspacesListJsonBody(UNSET),
1350        )
1351
1352        if res is None:
1353            raise Exception("Failed to get workspaces")
1354
1355        if not isinstance(res, WorkspacesListResponse200):
1356            raise Exception(res.msg)
1357
1358        return Workspaces(
1359            [Workspace(client=self, data=d.to_dict()) for d in res.workspaces]
1360        )
1361
1362    def set_current_workspace(self, workspace: Workspace) -> Workspace:
1363        """Any calls involving pipelines or models will use the given workspace from then on."""
1364        assert workspace is not None
1365        if not isinstance(workspace, Workspace):
1366            raise TypeError("Workspace type was expected")
1367
1368        self._current_workspace = workspace
1369        return cast("Workspace", self._current_workspace)
1370
1371    def get_current_workspace(self) -> Workspace:
1372        """Return the current workspace.  See also `set_current_workspace`."""
1373        if self._current_workspace is None:
1374            # Is there a default? Use that or make one.
1375            default_ws = Workspace._get_user_default_workspace(self)
1376            if default_ws is not None:
1377                self._current_workspace = default_ws
1378            else:
1379                self._current_workspace = Workspace._create_user_default_workspace(self)
1380
1381        return cast("Workspace", self._current_workspace)
1382
1383    def invite_user(self, email, password=None):
1384        return User.invite_user(
1385            email, password, self.auth, self.api_endpoint, self.auth_endpoint
1386        )
1387
1388    def get_topic_name(self, pipeline_pk_id: int) -> str:
1389        return self._post_rest_api(
1390            "v1/api/plateau/get_topic_name",
1391            {
1392                "pipeline_pk_id": pipeline_pk_id,
1393            },
1394        ).json()["topic_name"]
1395
1396    def shim_token(self, token_data: auth.TokenData):
1397        fetcher = auth._RawTokenFetcher(token_data)
1398        self.auth = auth._PlatformAuth(fetcher)
1399
1400    def convert_model(
1401        self,
1402        path: Union[str, pathlib.Path],
1403        source_type: ModelConversionSource,
1404        conversion_arguments: ModelConversionArguments,
1405    ) -> Model:
1406        """
1407        Given an inbound source model, a model type (xgboost, keras, sklearn), and conversion arguments.
1408        Convert the model to onnx, and add to available models for a pipeline.
1409
1410        :param Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model.
1411        :param ModelConversionSource source: The origin model type i.e. keras, sklearn or xgboost.
1412        :param ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type.
1413        :return: An instance of the Model being converted to Onnx.
1414        :raises ModelConversionGenericException: On a generic failure, please contact our support for further assistance.
1415        :raises ModelConversionFailure: Failure in converting the model type.
1416        :raises ModelConversionUnsupportedType: Raised when the source type passed is not supported.
1417        :raises ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist.
1418        """
1419        if isinstance(path, str):
1420            path = pathlib.Path(path)
1421        if not os.path.exists(path):
1422            raise ModelConversionSourceFileNotPresent(
1423                f"The provided source file: {path} can not be found."
1424            )
1425        file_handle = open(path, "rb")
1426        files = [("files", file_handle)]
1427        base_url = self.api_endpoint
1428        workspace_id = self.get_current_workspace().id()
1429        common_headers = {
1430            "user_id": self.auth.user_id(),
1431            "user_email": self.auth.user_email(),
1432            "User-Agent": _user_agent,
1433        }
1434        curry_post = partial(
1435            requests.post,
1436            auth=self.auth,
1437            headers=common_headers,
1438            files=files,
1439        )
1440        model_id: Union[int, None] = None
1441
1442        def _handle_response(http_response) -> int:
1443            http_response.raise_for_status()
1444            response_record = http_response.json()
1445            if response_record is not None and "model_id" in response_record:
1446                return int(response_record["model_id"])
1447            else:
1448                raise ModelConversionFailure("Failed to convert keras model")
1449
1450        try:
1451            if source_type == ModelConversionSource.KERAS:
1452                assert (
1453                    type(conversion_arguments).__name__
1454                    == ConvertKerasArguments.__name__
1455                )
1456                ## This is not optimal but api-lb(envoy) -> python-api
1457                ## Gave a 50(3|4|2|0) on several other options
1458                ## dimensions being of type List[Any]
1459                ## dimensions being a base64 encoded json array
1460                ## dimensions being a comma seperated string that was proccessed server side.
1461                ## This casts the ConvertKerasArguments dict structure to a json string.
1462                ## Base64 encoding via a utf-8 binary conversion, passed as the url
1463                ## parameter config. Which is handlded server side.
1464                data = {
1465                    **conversion_arguments.to_dict(),
1466                    "workspace_id": workspace_id,
1467                }
1468                params = {
1469                    "config": base64.b64encode(json.dumps(data).encode("utf-8")).decode(
1470                        "utf-8"
1471                    )
1472                }
1473                response = curry_post(url=f"{base_url}/v1/convert/keras", params=params)
1474                model_id = _handle_response(response)
1475            elif source_type == ModelConversionSource.SKLEARN:
1476                assert (
1477                    type(conversion_arguments).__name__
1478                    == ConvertSKLearnArguments.__name__
1479                )
1480                response = curry_post(
1481                    url=f"{base_url}/v1/convert/sklearn",
1482                    params={
1483                        **conversion_arguments.to_dict(),
1484                        "workspace_id": workspace_id,
1485                    },
1486                )
1487                model_id = _handle_response(response)
1488
1489            elif source_type == ModelConversionSource.XGBOOST:
1490                assert (
1491                    type(conversion_arguments).__name__ == ConvertXGBoostArgs.__name__
1492                )
1493                response = curry_post(
1494                    url=f"{base_url}/v1/convert/xgboost",
1495                    params={
1496                        **conversion_arguments.to_dict(),
1497                        "workspace_id": workspace_id,
1498                    },
1499                )
1500                model_id = _handle_response(response)
1501            else:
1502                raise ModelConversionUnsupportedType(
1503                    f"Unsupported model source type of {source_type} passed."
1504                )
1505            if model_id is None:
1506                raise ModelConversionFailure("Failed to retrieve final model id")
1507            return Model(self, {"id": model_id})
1508        except Exception:
1509            raise ModelConversionGenericException(
1510                "This model type could not be deployed successfully. Please contact your Wallaroo support team at community@wallaroo.ai"
1511            )
1512        finally:
1513            file_handle.close()
1514
1515    def _post_rest_api_json(self, uri: str, body: dict):
1516        result = self._post_rest_api(uri, body)
1517        if result.status_code == 200:
1518            return result.json()
1519        else:
1520            raise Exception(f"{result.status_code}: {result.text}")
1521
1522    def _setup_mlops_client(self) -> "AuthenticatedClient":
1523        self._mlops = AuthenticatedClient(
1524            base_url=self.api_endpoint, token=self.auth._access_token().token
1525        )
1526        return self._mlops
1527
1528    def mlops(self) -> "AuthenticatedClient":
1529        return self._setup_mlops_client()
class Client:
  94class Client(object):
  95    """Client handle to a Wallaroo platform instance.
  96
  97    Objects of this class serve as the entrypoint to Wallaroo platform
  98    functionality.
  99    """
 100
 101    @staticmethod
 102    def get_urls(
 103        auth_type: Optional[str], api_endpoint: str, auth_endpoint: str
 104    ) -> Tuple[Optional[str], str, str]:
 105        """Method to calculate the auth values specified as defaults,
 106        as params or in ENV vars.
 107        Made static to be testable without reaching out to SSO, etc."""
 108
 109        if auth_type is None:
 110            auth_type = os.environ.get(WALLAROO_SDK_AUTH_TYPE, None)
 111
 112        # ideally we'd set auth_endpoint to None default value but that would
 113        # make the auth_endpoint type to be Optiona[str] which messes up
 114        # a lot| of type hinting and I wanted to make minimal changes without a
 115        # lot of 'warnings'.
 116        if len(auth_endpoint.strip()) == 0:
 117            auth_endpoint = (
 118                os.environ.get(WALLAROO_AUTH_URL)
 119                or os.environ.get(WALLAROO_SDK_AUTH_ENDPOINT)
 120                or "http://api-lb:8080"
 121            )
 122
 123        api_endpoint = os.environ.get(WALLAROO_URL, api_endpoint)
 124
 125        return auth_type, api_endpoint, auth_endpoint
 126
 127    def __init__(
 128        self,
 129        api_endpoint: str = "http://api-lb:8080",
 130        auth_endpoint: str = "",
 131        request_timeout: int = 45,
 132        auth_type: Optional[str] = None,
 133        gql_client: Optional[gql.Client] = None,
 134        pg_connection_string: str = "dbname=postgres user=postgres password=password host=postgres port=5432",
 135        interactive: Optional[bool] = None,
 136        time_format: str = "%Y-%d-%b %H:%M:%S",
 137    ):
 138        """Create a Client handle.
 139
 140        :param str api_endpoint: Host/port of the platform API endpoint
 141        :param str auth_endpoint: Host/port of the platform Keycloak instance
 142        :param int timeout: Max timeout of web requests, in seconds
 143        :param str auth_type: Authentication type to use. Can be one of: "none",
 144            "sso", "user_password".
 145        :param str pg_connection_string: Postgres connection string
 146        :param bool interactive: If provided and True, some calls will print additional human information, or won't when False. If not provided, interactive defaults to True if running inside Jupyter and False otherwise.
 147        :param str time_format: Preferred `strftime` format string for displaying timestamps in a human context.
 148        """
 149
 150        auth_type, api_endpoint, auth_endpoint = Client.get_urls(
 151            auth_type, api_endpoint, auth_endpoint
 152        )
 153
 154        self.auth = auth.create(auth_endpoint, auth_type)
 155
 156        if gql_client:
 157            self._gql_client = gql_client
 158        else:
 159            gql_transport = RequestsHTTPTransport(
 160                url=posixpath.join(api_endpoint, "v1/graphql"),
 161                auth=self.auth,
 162                timeout=request_timeout,
 163            )
 164            self._gql_client = gql.Client(
 165                transport=gql_transport, fetch_schema_from_transport=True
 166            )
 167
 168        self.api_endpoint = api_endpoint
 169
 170        self.auth_endpoint = auth_endpoint
 171
 172        self.timeout = request_timeout
 173
 174        self._setup_mlops_client()
 175
 176        self.pg_connection_string = pg_connection_string
 177
 178        self._current_workspace: Optional[Workspace] = None
 179
 180        # TODO: debate the names of these things
 181        self._default_ws_name: Optional[str] = None
 182
 183        user_email = self.auth.user_email()
 184        if user_email is not None:
 185            self._default_ws_name = user_email + "_ws"
 186
 187        if interactive is not None:
 188            self._interactive = interactive
 189        elif (
 190            "JUPYTER_SVC_SERVICE_HOST" in os.environ or "JUPYTERHUB_HOST" in os.environ
 191        ):
 192            self._interactive = True
 193        else:
 194            self._interactive = False
 195
 196        self._time_format = time_format
 197
 198    def _get_rest_api(self, path: str, params: dict):
 199        headers = {
 200            "authorization": self.auth._bearer_token_str(),
 201            "user-agent": _user_agent,
 202        }
 203
 204        url = f"{self.api_endpoint}/{path}"
 205
 206        return requests.get(url, headers=headers, params=params)
 207
 208    def _post_rest_api(self, path: str, body: dict):
 209        headers = {
 210            "authorization": self.auth._bearer_token_str(),
 211            "user-agent": _user_agent,
 212        }
 213
 214        url = f"{self.api_endpoint}/{path}"
 215        return requests.post(url, headers=headers, json=body)
 216
 217    def list_tags(self) -> Tags:
 218        """List all tags on the platform.
 219
 220        :return: A list of all tags on the platform.
 221        :rtype: List[Tag]
 222        """
 223        res = self._gql_client.execute(
 224            gql.gql(
 225                """
 226            query ListTags {
 227              tag(order_by: {id: desc}) {
 228                id
 229                tag
 230                model_tags {
 231                  model {
 232                    id
 233                    model_id
 234                    models_pk_id
 235                    model_version
 236                    
 237                  }
 238                }
 239                pipeline_tags {
 240                  pipeline {
 241                    id
 242                    pipeline_id
 243                    pipeline_versions {
 244                        id
 245                        version
 246                    }
 247                  }
 248                }
 249              }
 250            }
 251
 252
 253            """
 254            )
 255        )
 256        return Tags([Tag(client=self, data={"id": p["id"]}) for p in res["tag"]])
 257
 258    def list_models(self) -> ModelsList:
 259        """List all models on the platform.
 260
 261        :return: A list of all models on the platform.
 262        :rtype: List[Model]
 263        """
 264        id = self.get_current_workspace().id()
 265        res = models_list.sync(
 266            client=self.mlops(),
 267            json_body=models_list_json_body.ModelsListJsonBody(id),
 268        )
 269
 270        if res is None:
 271            raise Exception("Failed to list models")
 272
 273        if not isinstance(res, ModelsListResponse200):
 274            raise Exception(res.msg)
 275
 276        return ModelsList([Models(client=self, data=v.to_dict()) for v in res.models])
 277
 278    def list_deployments(self) -> List[Deployment]:
 279        """List all deployments (active or not) on the platform.
 280
 281        :return: A list of all deployments on the platform.
 282        :rtype: List[Deployment]
 283        """
 284        res = self._gql_client.execute(
 285            gql.gql(
 286                """
 287            query ListDeployments {
 288              deployment {
 289                id
 290                deploy_id
 291                deployed
 292                deployment_model_configs {
 293                  model_config {
 294                    id
 295                  }
 296                }
 297              }
 298            }
 299            """
 300            )
 301        )
 302        return [Deployment(client=self, data=d) for d in res["deployment"]]
 303
 304    """
 305        # Removed until we figure out what pipeline ownership means
 306        #
 307        # def search_my_pipelines(
 308        #     self,
 309        #     search_term: Optional[str] = None,
 310        #     deployed: Optional[bool] = None,
 311        #     created_start: Optional["Datetime"] = None,
 312        #     created_end: Optional["Datetime"] = None,
 313        #     updated_start: Optional["Datetime"] = None,
 314        #     updated_end: Optional["Datetime"] = None,
 315        # ) -> List[Pipeline]:
 316        #     user_id = self.auth.user_id()
 317        #     return Pipelines(
 318        #         self._search_pipelines(
 319        #             search_term,
 320        #             deployed,
 321        #             user_id,
 322        #             created_start,
 323        #             created_end,
 324        #             updated_start,
 325        #             updated_end,
 326        #         )
 327        #     )
 328    """
 329
 330    def search_pipelines(
 331        self,
 332        search_term: Optional[str] = None,
 333        deployed: Optional[bool] = None,
 334        created_start: Optional["Datetime"] = None,
 335        created_end: Optional["Datetime"] = None,
 336        updated_start: Optional["Datetime"] = None,
 337        updated_end: Optional["Datetime"] = None,
 338    ) -> PipelineVariants:
 339        """Search for pipelines. All parameters are optional, in which case the result is the same as
 340        `list_pipelines()`. All times are strings to be parsed by `datetime.isoformat`. Example:
 341
 342             myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo")
 343
 344        :param str search_term: Will be matched against tags and model names. Example: "footag123".
 345        :param bool deployed: Pipeline was deployed or not
 346        :param str created_start: Pipeline was created at or after this time
 347        :param str created_end: Pipeline was created at or before this time
 348        :param str updated_start: Pipeline was updated at or before this time
 349        :param str updated_end: Pipeline was updated at or before this time
 350
 351        :return: A list of all pipelines on the platform.
 352        :rtype: List[Pipeline]
 353        """
 354        return PipelineVariants(
 355            self._search_pipelines(
 356                search_term,
 357                deployed,
 358                None,
 359                created_start,
 360                created_end,
 361                updated_start,
 362                updated_end,
 363            )
 364        )
 365
 366    def _search_pipelines(
 367        self,
 368        search_term: Optional[str] = None,
 369        deployed: Optional[bool] = None,
 370        user_id: Optional[str] = None,
 371        created_start: Optional["Datetime"] = None,
 372        created_end: Optional["Datetime"] = None,
 373        updated_start: Optional["Datetime"] = None,
 374        updated_end: Optional["Datetime"] = None,
 375    ) -> List[PipelineVariant]:
 376        (query, params) = self._generate_search_pipeline_query(
 377            search_term=search_term,
 378            deployed=deployed,
 379            user_id=user_id,
 380            created_start=created_start,
 381            created_end=created_end,
 382            updated_start=updated_start,
 383            updated_end=updated_end,
 384        )
 385        q = gql.gql(query)
 386        data = self._gql_client.execute(q, variable_values=params)
 387        pipelines = []
 388        if data["search_pipelines"]:
 389            for p in data["search_pipelines"]:
 390                pipelines.append(PipelineVariant(self, p))
 391        return pipelines
 392
 393    def _generate_search_pipeline_query(
 394        self,
 395        search_term: Optional[str] = None,
 396        deployed: Optional[bool] = None,
 397        user_id: Optional[str] = None,
 398        created_start: Optional["Datetime"] = None,
 399        created_end: Optional["Datetime"] = None,
 400        updated_start: Optional["Datetime"] = None,
 401        updated_end: Optional["Datetime"] = None,
 402    ):
 403        filters = []
 404        query_params = []
 405        params: Dict[str, Any] = {}
 406        search = ""
 407        if search_term:
 408            search = search_term
 409        params["search_term"] = search
 410        query_params.append("$search_term: String!")
 411
 412        if user_id:
 413            filters.append("owner_id: {_eq: $user_id}")
 414            params["user_id"] = user_id
 415            query_params.append("$user_id: String!")
 416
 417        if deployed is not None:
 418            filters.append("pipeline: {deployment: {deployed: {_eq: $deployed}}}")
 419            params["deployed"] = deployed
 420            query_params.append("$deployed: Boolean")
 421
 422        self._generate_time_range_graphql(
 423            "created_at",
 424            start=created_start,
 425            end=created_end,
 426            filters=filters,
 427            query_params=query_params,
 428            params=params,
 429        )
 430        self._generate_time_range_graphql(
 431            "updated_at",
 432            start=updated_start,
 433            end=updated_end,
 434            filters=filters,
 435            query_params=query_params,
 436            params=params,
 437        )
 438
 439        where_clause_str = self._generate_where_clause_str(filters)
 440        query_param_str = self._generate_query_param_str(query_params)
 441        query = f"""
 442            query GetPipelines({query_param_str}) {{
 443                search_pipelines(args: {{search: $search_term}}, distinct_on: id{where_clause_str}, order_by: {{id: desc}}) {{
 444                    id
 445                    created_at
 446                    pipeline_pk_id
 447                    updated_at
 448                    version
 449                    pipeline {{
 450                        id
 451                        pipeline_id
 452                        pipeline_tags {{
 453                            id
 454                            tag {{
 455                                id
 456                                tag
 457                            }}
 458                        }}
 459                    }}
 460                }}
 461            }}
 462        """
 463        return (query, params)
 464
 465    def _generate_where_clause_str(self, filters: List[str]) -> str:
 466        where_clause_str = ""
 467        filters_len = len(filters)
 468        if filters_len > 0:
 469            if filters_len > 1:
 470                where_clause_str = f""", where: {{_and: {{ {", ".join(filters)} }}}}"""
 471            else:
 472                where_clause_str = f", where: {{{filters[0]}}}"
 473        return where_clause_str
 474
 475    def _generate_query_param_str(self, query_params: List[str]):
 476        return ", ".join(query_params)
 477
 478    def _generate_time_range_graphql(
 479        self,
 480        field: str,
 481        start: Optional["Datetime"],
 482        end: Optional["Datetime"],
 483        filters: List[str],
 484        query_params: List[str],
 485        params: Dict[str, Any],
 486    ):
 487        if start and not end:
 488            filters.append(f"{field}: {{_gte: $start_{field}}}")
 489            params[f"start_{field}"] = start
 490            query_params.append(f"$start_{field}: timestamptz!")
 491        elif end and not start:
 492            filters.append(f"{field}: {{_lte: $end_{field}}}")
 493            params[f"end_{field}"] = end
 494            query_params.append(f"$end_{field}: timestamptz!")
 495        elif start and end:
 496            filters.append(f"{field}: {{_gte: $start_{field}, _lte: $end_{field}}}")
 497            params[f"start_{field}"] = start
 498            params[f"end_{field}"] = start
 499            query_params.append(f"$start_{field}: timestamptz!")
 500            query_params.append(f"$end_{field}: timestamptz!")
 501
 502    def search_my_models(
 503        self,
 504        search_term: Optional[str] = None,
 505        uploaded_time_start: Optional["Datetime"] = None,
 506        uploaded_time_end: Optional["Datetime"] = None,
 507    ) -> ModelVersions:
 508        """Search models owned by you
 509        params:
 510             search_term: Searches the following metadata: names, shas, versions, file names, and tags
 511             uploaded_time_start: Inclusive time of upload
 512             uploaded_time_end: Inclusive time of upload
 513        """
 514        user_id = self.auth.user_id()
 515        return ModelVersions(
 516            self._search_models(
 517                search_term=search_term,
 518                user_id=user_id,
 519                start=uploaded_time_start,
 520                end=uploaded_time_end,
 521            )
 522        )
 523
 524    def search_models(
 525        self,
 526        search_term: Optional[str] = None,
 527        uploaded_time_start: Optional["Datetime"] = None,
 528        uploaded_time_end: Optional["Datetime"] = None,
 529    ) -> ModelVersions:
 530        """Search all models you have access to.
 531        params:
 532             search_term: Searches the following metadata: names, shas, versions, file names, and tags
 533             uploaded_time_start: Inclusive time of upload
 534             uploaded_time_end: Inclusive time of upload
 535        """
 536        return ModelVersions(
 537            self._search_models(
 538                search_term=search_term,
 539                start=uploaded_time_start,
 540                end=uploaded_time_end,
 541            )
 542        )
 543
 544    def _search_models(
 545        self, search_term=None, user_id=None, start=None, end=None
 546    ) -> List[Model]:
 547        (query, params) = self._generate_model_query(
 548            search_term=search_term,
 549            user_id=user_id,
 550            start=start,
 551            end=end,
 552        )
 553
 554        q = gql.gql(query)
 555
 556        data = self._gql_client.execute(q, variable_values=params)
 557        models = []
 558        if data["search_models"]:
 559            for m in data["search_models"]:
 560                models.append(Model(self, m))
 561        return models
 562
 563    def _generate_model_query(
 564        self,
 565        search_term=None,
 566        user_id=None,
 567        start=None,
 568        end=None,
 569    ):
 570        filters = []
 571        query_params = []
 572        params = {}
 573        search = ""
 574        if search_term:
 575            search = search_term
 576        params["search_term"] = search
 577        query_params.append("$search_term: String!")
 578        if user_id:
 579            filters.append("owner_id: {_eq: $user_id}")
 580            params["user_id"] = user_id
 581            query_params.append("$user_id: String!")
 582
 583        self._generate_time_range_graphql(
 584            "created_at",
 585            start=start,
 586            end=end,
 587            filters=filters,
 588            params=params,
 589            query_params=query_params,
 590        )
 591
 592        where_clause_str = self._generate_where_clause_str(filters)
 593        query_param_str = self._generate_query_param_str(query_params)
 594        query = f"""
 595            query GetModels({query_param_str}) {{
 596              search_models(args: {{search: $search_term}}{where_clause_str}, order_by: {{created_at: desc}}) {{
 597                id
 598              }}
 599            }}
 600        """
 601        return (query, params)
 602
 603    def get_user_by_email(self, email: str) -> Optional[User]:
 604        """Find a user by email"""
 605        assert email is not None
 606        escaped_email = quote_plus(email)
 607        url = (
 608            f"{self.auth_endpoint}/auth/admin/realms/master/users?email={escaped_email}"
 609        )
 610        headers = {
 611            "Content-Type": "application/json",
 612            "Authorization": self.auth._bearer_token_str(),
 613            "User-Agent": _user_agent,
 614        }
 615        resp = requests.get(url, headers=headers, data={})
 616        jresp = resp.json()
 617        return None if jresp == [] else User(client=self, data=jresp[0])
 618
 619    def deactivate_user(self, email: str) -> None:
 620        """Deactivates an existing user of the platform
 621
 622        Deactivated users cannot log into the platform.
 623        Deactivated users do not count towards the number of allotted user seats from the license.
 624
 625        The Models and Pipelines owned by the deactivated user are not removed from the platform.
 626
 627        :param str email: The email address of the user to deactivate.
 628
 629        :return: None
 630        :rtype: None
 631        """
 632
 633        if self.auth.user_email() == email:
 634            raise Exception("A user may not deactive themselves.")
 635
 636        user = self.get_user_by_email(email)
 637
 638        if user is None:
 639            raise EntityNotFoundError("User", {"email": email})
 640
 641        if user.username() == "admin":
 642            raise Exception("Admin user may not be deactivated.")
 643
 644        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"
 645        headers = {
 646            "Content-Type": "application/json",
 647            "Authorization": self.auth._bearer_token_str(),
 648            "User-Agent": _user_agent,
 649        }
 650
 651        # Get the current full user representation to return in the mutation due to keycloak bug
 652        get_user_response = requests.get(url, headers=headers, data={})
 653
 654        cur_user_rep = get_user_response.json()
 655        cur_user_rep["enabled"] = False
 656
 657        resp = requests.put(url, headers=headers, json=cur_user_rep)
 658
 659        if resp.status_code != 204:
 660            raise EntityNotFoundError("User", {"email": email})
 661        return None
 662
 663    def activate_user(self, email: str) -> None:
 664        """Activates an existing user of the platform that had been previously deactivated.
 665
 666        Activated users can log into the platform.
 667
 668        :param str email: The email address of the user to activate.
 669
 670        :return: None
 671        :rtype: None
 672        """
 673        user = self.get_user_by_email(email)
 674
 675        if user is None:
 676            raise EntityNotFoundError("User", {"email": email})
 677
 678        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"
 679
 680        headers = {
 681            "Content-Type": "application/json",
 682            "Authorization": self.auth._bearer_token_str(),
 683            "User-Agent": _user_agent,
 684        }
 685
 686        # Get the current full user representation to return in the mutation due to keycloak bug
 687        get_user_response = requests.get(url, headers=headers, data={})
 688
 689        cur_user_rep = get_user_response.json()
 690        cur_user_rep["enabled"] = True
 691
 692        resp = requests.put(url, headers=headers, json=cur_user_rep)
 693
 694        if resp.status_code != 204:
 695            raise EntityNotFoundError("User", {"email": email})
 696        return None
 697
 698    def _get_user_by_id(self, id: str) -> Optional[User]:
 699        assert id is not None
 700        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{id}"
 701        headers = {
 702            "Content-Type": "application/json",
 703            "Authorization": self.auth._bearer_token_str(),
 704            "User-Agent": _user_agent,
 705        }
 706        resp = requests.get(url, headers=headers, data={})
 707        jresp = resp.json()
 708        return None if jresp == [] else User(client=self, data=jresp)
 709
 710    def list_users(self) -> List[User]:
 711        """List of all Users on the platform
 712
 713        :return: A list of all Users on the platform.
 714        :rtype: List[User]
 715        """
 716        resp = User.list_users(auth=self.auth)
 717        return [User(client=self, data=u) for u in resp]
 718
 719    def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> Model:
 720        """Upload a model defined by a file as a new model variant.
 721
 722        :param str model_name: The name of the model of which this is a variant.
 723            Names must be ASCII alpha-numeric characters or dash (-) only.
 724        :param Union[str, pathlib.Path] path: Path of the model file to upload.
 725        :return: The created Model.
 726        :rtype: Model
 727        """
 728
 729        _Visibility.PRIVATE
 730        if isinstance(path, str):
 731            path = pathlib.Path(path)
 732        with path.open("rb") as f:
 733            return self._upload_model_stream(name, {"filename": path.name}, f)
 734
 735    def _upload_model_stream(self, name: str, data: Dict[str, Any], file: Any):
 736        require_dns_compliance(name)
 737        endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload_stream")
 738        data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()}
 739        headers = {"User-Agent": _user_agent}
 740
 741        res = requests.post(
 742            endpoint, auth=self.auth, params=data, data=file, headers=headers
 743        )
 744        if res.status_code != 200:
 745            raise ModelUploadError(res.text)
 746
 747        res_dict = json.loads(res.text)
 748        return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0])
 749
 750    def register_model_image(self, name: str, image: str) -> Model:
 751        """Registers an MLFlow model as a new model.
 752
 753        :param str model_name: The name of the model of which this is a variant.
 754            Names must be ASCII alpha-numeric characters or dash (-) only.
 755        :param str image: Image name of the MLFlow model to register.
 756        :return: The created Model.
 757        :rtype: Model
 758        """
 759        data = {
 760            "image_path": image,
 761        }
 762        return self._upload_model(name, data)
 763
 764    def _upload_model(
 765        self, name: str, data: Dict[str, Any], files: Dict[str, Tuple[str, bytes]] = {}
 766    ):
 767        require_dns_compliance(name)
 768        endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload")
 769        data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()}
 770        if len(files) == 0:
 771            files = {"dummy": ("none", b"")}
 772
 773        headers = {"User-Agent": _user_agent}
 774
 775        res = requests.post(
 776            endpoint, files=files, auth=self.auth, data=data, headers=headers
 777        )
 778        if res.status_code != 200:
 779            raise ModelUploadError(res.text)
 780
 781        res_dict = json.loads(res.text)
 782        return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0])
 783
 784    def model_by_name(self, model_class: str, model_name: str) -> Model:
 785        """Fetch a Model by name.
 786
 787        :param str model_class: Name of the model class.
 788        :param str model_name: Name of the variant within the specified model class.
 789        :return: The Model with the corresponding model and variant name.
 790        :rtype: Model
 791        """
 792        res = self._gql_client.execute(
 793            gql.gql(
 794                """
 795                query ModelByName($model_id: String!, $model_version: String!) {
 796                  model(where: {_and: [{model_id: {_eq: $model_id}}, {model_version: {_eq: $model_version}}]}) {
 797                    id
 798                    model_id
 799                    model_version
 800                  }
 801                }
 802                """
 803            ),
 804            variable_values={
 805                "model_id": model_class,
 806                "model_version": model_name,
 807            },
 808        )
 809        if not res["model"]:
 810            raise EntityNotFoundError(
 811                "Model", {"model_class": model_class, "model_name": model_name}
 812            )
 813        return Model(client=self, data={"id": res["model"][0]["id"]})
 814
 815    def deployment_by_name(self, deployment_name: str) -> Deployment:
 816        """Fetch a Deployment by name.
 817
 818        :param str deployment_name: Name of the deployment.
 819        :return: The Deployment with the corresponding name.
 820        :rtype: Deployment
 821        """
 822        res = self._gql_client.execute(
 823            gql.gql(
 824                """
 825                query DeploymentByName($deployment_name: String!) {
 826                  deployment(where: {deploy_id: {_eq: $deployment_name}}) {
 827                    id
 828                  }
 829                }
 830                """
 831            ),
 832            variable_values={
 833                "deployment_name": deployment_name,
 834            },
 835        )
 836        if not res["deployment"]:
 837            raise EntityNotFoundError(
 838                "Deployment", {"deployment_name": deployment_name}
 839            )
 840        return Deployment(client=self, data={"id": res["deployment"][0]["id"]})
 841
 842    def pipelines_by_name(self, pipeline_name: str) -> List[Pipeline]:
 843        """Fetch Pipelines by name.
 844
 845        :param str pipeline_name: Name of the pipeline.
 846        :return: The Pipeline with the corresponding name.
 847        :rtype: Pipeline
 848        """
 849        res = self._gql_client.execute(
 850            gql.gql(
 851                """
 852                query PipelineByName($pipeline_name: String!) {
 853                  pipeline(where: {pipeline_id: {_eq: $pipeline_name}}, order_by: {created_at: desc}) {
 854                    id
 855                  }
 856                }
 857                """
 858            ),
 859            variable_values={
 860                "pipeline_name": pipeline_name,
 861            },
 862        )
 863        assert "pipeline" in res
 864        length = len(res["pipeline"])
 865        if length < 1:
 866            raise EntityNotFoundError("Pipeline", {"pipeline_name": pipeline_name})
 867        return [Pipeline(client=self, data={"id": p["id"]}) for p in res["pipeline"]]
 868
 869    def list_pipelines(self) -> List[Pipeline]:
 870        """List all pipelines on the platform.
 871
 872        :return: A list of all pipelines on the platform.
 873        :rtype: List[Pipeline]
 874        """
 875        res = self._gql_client.execute(
 876            gql.gql(
 877                """
 878            query ListPipelines {
 879              pipeline(order_by: {id: desc}) {
 880                id
 881                pipeline_tags {
 882                  tag {
 883                    id
 884                    tag
 885                  }
 886                }
 887              }
 888            }
 889            """
 890            )
 891        )
 892        return Pipelines([Pipeline(client=self, data=d) for d in res["pipeline"]])
 893
 894    def build_pipeline(self, pipeline_name: str) -> "Pipeline":
 895        """Starts building a pipeline with the given `pipeline_name`,
 896        returning a :py:PipelineConfigBuilder:
 897
 898        When completed, the pipeline can be uploaded with `.upload()`
 899
 900        :param pipeline_name string: Name of the pipeline, must be composed of ASCII
 901          alpha-numeric characters plus dash (-).
 902        """
 903
 904        require_dns_compliance(pipeline_name)
 905
 906        _Visibility.PRIVATE
 907
 908        # TODO: Needs to handle visibility?
 909        data = pipelines_create.sync(
 910            client=self.mlops(),
 911            json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
 912                pipeline_name,
 913                self.get_current_workspace().id(),
 914                pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
 915                    {}
 916                ),
 917            ),
 918        )
 919
 920        if data is None:
 921            raise Exception("Failed to create pipeline")
 922
 923        if not isinstance(data, PipelinesCreateResponse200):
 924            raise Exception(data.msg)
 925
 926        return Pipeline(client=self, data={"id": data.pipeline_pk_id})
 927
 928    def _upload_pipeline_variant(
 929        self,
 930        name: str,
 931        config: PipelineConfig,
 932    ) -> Pipeline:
 933        """Creates a new PipelineVariant with the specified configuration.
 934
 935        :param str name: Name of the Pipeline. Must be unique across all Pipelines.
 936        :param config PipelineConfig: Pipeline configuration.
 937        """
 938        definition = config.to_json()
 939        _Visibility.PRIVATE
 940
 941        data = pipelines_create.sync(
 942            client=self.mlops(),
 943            json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
 944                name,
 945                self.get_current_workspace().id(),
 946                pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
 947                    definition
 948                ),
 949            ),
 950        )
 951
 952        if data is None:
 953            # TODO: Generalize
 954            raise Exception("Failed to create pipeline")
 955
 956        if not isinstance(data, PipelinesCreateResponse200):
 957            raise Exception(data.msg)
 958
 959        for alert_config in config.alert_configurations:
 960            self._gql_client.execute(
 961                gql.gql(
 962                    """
 963                mutation CreateAlertConfiguration(
 964                    $pipeline_version_id: bigint,
 965                    $name: String,
 966                    $expression: String,
 967                    $notifications: jsonb
 968                ) {
 969                    insert_alert_configuration(objects: {
 970                        name: $name,
 971                        expression: $expression,
 972                        notifications: $notifications,
 973                        pipeline_version_pk_id: $pipeline_version_id
 974                    }) {
 975                        returning { id }
 976                    }
 977                }
 978                """
 979                ),
 980                variable_values={
 981                    **alert_config.to_json(),
 982                    "pipeline_version_id": data.pipeline_pk_id,
 983                },
 984            )
 985
 986        pipeline_data = data.to_dict()
 987        pipeline_data["id"] = data.pipeline_pk_id
 988
 989        return Pipeline(
 990            client=self,
 991            data=pipeline_data,
 992        )
 993
 994    def create_value_split_experiment(
 995        self,
 996        name: str,
 997        meta_key: str,
 998        default_model: ModelConfig,
 999        challenger_models: List[Tuple[Any, ModelConfig]],
1000    ) -> Pipeline:
1001        """Creates a new PipelineVariant of a "value-split experiment" type.
1002        :param str name: Name of the Pipeline
1003        :param meta_key str: Inference input key on which to redirect inputs to
1004            experiment models.
1005        :param default_model ModelConfig: Model to send inferences by default.
1006        :param challenger_models List[Tuple[Any, ModelConfig]]: A list of
1007            meta_key values -> Models to send inferences. If the inference data
1008            referred to by meta_key is equal to one of the keys in this tuple,
1009            that inference is redirected to the corresponding model instead of
1010            the default model.
1011        """
1012        args = [meta_key, default_model.model().name()]
1013        for v, m in challenger_models:
1014            args.append(v)
1015            args.append(m.model().name())
1016        step = {
1017            "id": "metavalue_split",
1018            "operation": "map",
1019            "args": args,
1020        }
1021        definition = {"id": name, "steps": [step]}
1022        # TODO: This seems like a one-to-one replace, find appropriate test.
1023        data = self._gql_client.execute(
1024            gql.gql(
1025                """
1026            mutation CreatePipeline(
1027                $pipeline_id: String,
1028                $version: String,
1029                $definition: jsonb,
1030                $workspace_id: bigint
1031            ) {
1032                insert_pipeline(
1033                    objects: {
1034                    pipeline_versions: {
1035                        data: { definition: $definition }
1036                    }
1037                    pipeline_id: $pipeline_id
1038                    }
1039                ) {
1040                    returning {
1041                        id
1042                    }
1043                }
1044            }
1045            """
1046            ),
1047            variable_values={
1048                "pipeline_id": name,
1049                "definition": definition,
1050                "workspace_id": self.get_current_workspace().id(),
1051            },
1052        )
1053        return Pipeline(
1054            client=self,
1055            data=data["insert_pipeline"]["returning"][0],
1056        )
1057
1058    def get_logs(self, topic: str, limit: int = 100) -> Tuple[LogEntries, str]:
1059        base = self.api_endpoint + f"/v1/logs/topic/" + topic
1060        headers = {"User-Agent": _user_agent}
1061        partitions = requests.get(base, auth=self.auth, headers=headers).json()[
1062            "partitions"
1063        ]
1064
1065        iterator = {
1066            k: max(0, span["end"] - math.floor(limit / len(partitions)))
1067            for k, span in partitions.items()
1068        }
1069
1070        response = requests.post(
1071            base + "/records",
1072            params={"limit": limit},
1073            json=iterator,
1074            auth=self.auth,
1075            headers=headers,
1076        ).json()
1077
1078        return (
1079            LogEntries([LogEntry(json.loads(l)) for l in response["records"]]),
1080            response.get("status", "None"),
1081        )
1082
1083    def security_logs(self, limit: int) -> List[dict]:
1084        """This function is not available in this release"""
1085        raise RuntimeError("security_log() is not available in this release.")
1086
1087    def get_raw_logs(
1088        self,
1089        topic: str,
1090        start: Optional[datetime] = None,
1091        end: Optional[datetime] = None,
1092        limit: int = 100_000,
1093        parse: bool = False,
1094        verbose: bool = False,
1095    ) -> List[Dict[str, Any]]:
1096        """Gets logs from Plateau for a particular time window without attempting
1097        to convert them to Inference LogEntries. Logs can be returned as strings
1098        or the json parsed into lists and dicts.
1099        :param topic str: The name of the topic to query
1100        :param start Optional[datetime]: The start of the time window
1101        :param end Optional[datetime]: The end of the time window
1102        :param limit int: The number of records to retrieve. Note retrieving many
1103            records may be a performance bottleneck.
1104        :param parse bool: Wether to attempt to parse the string as a json object.
1105        :param verbose bool: Prints out info to help diagnose issues.
1106        """
1107
1108        assert limit <= 1_000_000
1109
1110        base = self.api_endpoint + f"/v1/logs/topic/" + topic
1111        headers = {"User-Agent": _user_agent}
1112        resp = requests.get(base, auth=self.auth, headers=headers)
1113        if resp.status_code != 200:
1114            raise EntityNotFoundError(
1115                f"Could not get partitions {resp.text}", {"url": base}
1116            )
1117        data = resp.json()
1118        partitions = data["partitions"]
1119
1120        if verbose:
1121            print(f"Got partitions {partitions}")
1122
1123        params: Dict[str, Any] = {"limit": limit}
1124        if start is not None:
1125            start_str = start.astimezone(tz=timezone.utc).isoformat()
1126            params["time.start"] = start_str
1127        if end is not None:
1128            end_str = end.astimezone(tz=timezone.utc).isoformat()
1129            params["time.end"] = end_str
1130
1131        next: Union[Any, None] = {
1132            k: max(0, span["end"] - math.floor(limit / len(partitions)))
1133            for k, span in partitions.items()
1134        }
1135
1136        if verbose:
1137            print("Using params: ", params)
1138
1139        records = []
1140        while next is not None:
1141            response = requests.post(
1142                base + "/records",
1143                params=params,
1144                json=next,
1145                auth=self.auth,
1146                headers=headers,
1147            )
1148            if response.status_code != 200:
1149                raise EntityNotFoundError(
1150                    f"Could not get records {resp.text}",
1151                    {"url": base, "params": str(params), "iterator": str(next)},
1152                )
1153
1154            if verbose:
1155                print("response: ", response)
1156
1157            result = response.json()
1158            result_records = result["records"]
1159            if len(result_records) > 0:
1160                records.extend(result_records)
1161                next = result["next"]
1162            else:
1163                next = None
1164        if parse:
1165            return [json.loads(r) for r in records]
1166        return records
1167
1168    def get_raw_pipeline_inference_logs(
1169        self,
1170        topic: str,
1171        start: datetime,
1172        end: datetime,
1173        model_name: Optional[str] = None,
1174        limit: int = 100_000,
1175        verbose: bool = False,
1176    ) -> List[Dict[str, Any]]:
1177        """Gets logs from Plateau for a particular time window and filters them for
1178        the model specified.
1179        :param pipeline_name str: The name/pipeline_id of the pipeline to query
1180        :param topic str: The name of the topic to query
1181        :param start Optional[datetime]: The start of the time window
1182        :param end Optional[datetime]: The end of the time window
1183        :param model_id: The name of the specific model to filter if any
1184        :param limit int: The number of records to retrieve. Note retrieving many
1185            records may be a performance bottleneck.
1186        :param verbose bool: Prints out info to help diagnose issues.
1187        """
1188        logs = self.get_raw_logs(
1189            topic,
1190            start=start,
1191            end=end,
1192            limit=limit,
1193            parse=True,
1194            verbose=verbose,
1195        )
1196
1197        if verbose:
1198            print(f"Got {len(logs)} initial logs")
1199
1200        if model_name:
1201            logs = [l for l in logs if l["model_name"] == model_name]
1202
1203        # inference results are a unix timestamp in millis - filter by that
1204        start_ts = int(start.timestamp() * 1000)
1205        end_ts = int(end.timestamp() * 1000)
1206        logs = [l for l in logs if start_ts <= l["time"] < end_ts]
1207
1208        return logs
1209
1210    def get_pipeline_inference_dataframe(
1211        self,
1212        topic: str,
1213        start: datetime,
1214        end: datetime,
1215        model_name: Optional[str] = None,
1216        limit: int = 100_000,
1217        verbose=False,
1218    ) -> pd.DataFrame:
1219        logs = self.get_raw_pipeline_inference_logs(
1220            topic, start, end, model_name, limit, verbose
1221        )
1222        return inference_logs_to_dataframe(logs)
1223
1224    def get_assay_results(
1225        self,
1226        assay_id: int,
1227        start: datetime,
1228        end: datetime,
1229    ) -> AssayAnalysisList:
1230        """Gets the assay results for a particular time window, parses them, and returns an
1231        AssayAnalysisList of AssayAnalysis.
1232        :param assay_id int: The id of the assay we are looking for.
1233        :param start datetime: The start of the time window
1234        :param end datetime: The end of the time window
1235        """
1236        res = assays_get_assay_results.sync(
1237            client=self.mlops(),
1238            json_body=assays_get_assay_results_json_body.AssaysGetAssayResultsJsonBody(
1239                assay_id, start, end
1240            ),
1241        )
1242
1243        if res is None:
1244            raise Exception("Failed to list models")
1245
1246        if not isinstance(res, List):
1247            raise Exception(res.msg)
1248
1249        if len(res) != 0 and not isinstance(
1250            res[0], AssaysGetAssayResultsResponse200Item
1251        ):
1252            raise Exception("invalid response")
1253
1254        return AssayAnalysisList([AssayAnalysis(v.to_dict()) for v in res])
1255
1256    def build_assay(
1257        self,
1258        assay_name: str,
1259        pipeline: Pipeline,
1260        model_name: str,
1261        baseline_start: datetime,
1262        baseline_end: datetime,
1263    ) -> AssayBuilder:
1264        """Creates an AssayBuilder that can be used to configure and create
1265        Assays.
1266        :param assay_name str: Human friendly name for the assay
1267        :param pipeline Pipeline: The pipeline this assay will work on
1268        :param model_name str: The model that this assay will monitor
1269        :param baseline_start datetime: The start time for the inferences to
1270            use as the baseline
1271        :param baseline_end datetime: The end time of the baseline window.
1272        the baseline. Windows start immediately after the baseline window and
1273        are run at regular intervals continously until the assay is deactivated
1274        or deleted.
1275        """
1276        assay_builder = AssayBuilder(
1277            self,
1278            assay_name,
1279            pipeline.id(),
1280            pipeline.name(),
1281            model_name,
1282            baseline_start,
1283            baseline_end,
1284        )
1285
1286        return assay_builder
1287
1288    def upload_assay(self, config: AssayConfig) -> int:
1289        """Creates an assay in the database.
1290        :param config AssayConfig: The configuration for the assay to create.
1291        :return assay_id: The identifier for the assay that was created.
1292        :rtype int
1293        """
1294        data = assays_create.sync(
1295            client=self.mlops(),
1296            json_body=AssaysCreateJsonBody.from_dict(
1297                {
1298                    **json.loads(config.to_json()),
1299                    "created_at": datetime.now(timezone.utc).isoformat(),
1300                }
1301            ),
1302        )
1303
1304        if data is None:
1305            raise Exception("Failed to create assay")
1306
1307        if not isinstance(data, AssaysCreateResponse200):
1308            raise Exception(data.msg)
1309
1310        return data.assay_id
1311
1312    def list_assays(self) -> List[Assay]:
1313        """List all assays on the platform.
1314
1315        :return: A list of all assays on the platform.
1316        :rtype: List[Assay]
1317        """
1318        res = assays_list.sync(client=self.mlops(), json_body=AssaysListJsonBody(UNSET))
1319
1320        if res is None:
1321            raise Exception("Failed to get assays")
1322
1323        if not isinstance(res, List):
1324            raise Exception(res.msg)
1325
1326        return Assays([Assay(client=self, data=v.to_dict()) for v in res])
1327
1328    def create_tag(self, tag_text: str) -> Tag:
1329        """Create a new tag with the given text."""
1330        assert tag_text is not None
1331        return Tag._create_tag(client=self, tag_text=tag_text)
1332
1333    def create_workspace(self, workspace_name: str) -> Workspace:
1334        """Create a new workspace with the current user as its first owner.
1335
1336        :param str workspace_name: Name of the workspace, must be composed of ASCII
1337           alpha-numeric characters plus dash (-)"""
1338        assert workspace_name is not None
1339        require_dns_compliance(workspace_name)
1340        return Workspace._create_workspace(client=self, name=workspace_name)
1341
1342    def list_workspaces(self) -> List[Workspace]:
1343        """List all workspaces on the platform which this user has permission see.
1344
1345        :return: A list of all workspaces on the platform.
1346        :rtype: List[Workspace]
1347        """
1348        res = workspaces_list.sync(
1349            client=self.mlops(),
1350            json_body=workspaces_list_json_body.WorkspacesListJsonBody(UNSET),
1351        )
1352
1353        if res is None:
1354            raise Exception("Failed to get workspaces")
1355
1356        if not isinstance(res, WorkspacesListResponse200):
1357            raise Exception(res.msg)
1358
1359        return Workspaces(
1360            [Workspace(client=self, data=d.to_dict()) for d in res.workspaces]
1361        )
1362
1363    def set_current_workspace(self, workspace: Workspace) -> Workspace:
1364        """Any calls involving pipelines or models will use the given workspace from then on."""
1365        assert workspace is not None
1366        if not isinstance(workspace, Workspace):
1367            raise TypeError("Workspace type was expected")
1368
1369        self._current_workspace = workspace
1370        return cast("Workspace", self._current_workspace)
1371
1372    def get_current_workspace(self) -> Workspace:
1373        """Return the current workspace.  See also `set_current_workspace`."""
1374        if self._current_workspace is None:
1375            # Is there a default? Use that or make one.
1376            default_ws = Workspace._get_user_default_workspace(self)
1377            if default_ws is not None:
1378                self._current_workspace = default_ws
1379            else:
1380                self._current_workspace = Workspace._create_user_default_workspace(self)
1381
1382        return cast("Workspace", self._current_workspace)
1383
1384    def invite_user(self, email, password=None):
1385        return User.invite_user(
1386            email, password, self.auth, self.api_endpoint, self.auth_endpoint
1387        )
1388
1389    def get_topic_name(self, pipeline_pk_id: int) -> str:
1390        return self._post_rest_api(
1391            "v1/api/plateau/get_topic_name",
1392            {
1393                "pipeline_pk_id": pipeline_pk_id,
1394            },
1395        ).json()["topic_name"]
1396
1397    def shim_token(self, token_data: auth.TokenData):
1398        fetcher = auth._RawTokenFetcher(token_data)
1399        self.auth = auth._PlatformAuth(fetcher)
1400
1401    def convert_model(
1402        self,
1403        path: Union[str, pathlib.Path],
1404        source_type: ModelConversionSource,
1405        conversion_arguments: ModelConversionArguments,
1406    ) -> Model:
1407        """
1408        Given an inbound source model, a model type (xgboost, keras, sklearn), and conversion arguments.
1409        Convert the model to onnx, and add to available models for a pipeline.
1410
1411        :param Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model.
1412        :param ModelConversionSource source: The origin model type i.e. keras, sklearn or xgboost.
1413        :param ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type.
1414        :return: An instance of the Model being converted to Onnx.
1415        :raises ModelConversionGenericException: On a generic failure, please contact our support for further assistance.
1416        :raises ModelConversionFailure: Failure in converting the model type.
1417        :raises ModelConversionUnsupportedType: Raised when the source type passed is not supported.
1418        :raises ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist.
1419        """
1420        if isinstance(path, str):
1421            path = pathlib.Path(path)
1422        if not os.path.exists(path):
1423            raise ModelConversionSourceFileNotPresent(
1424                f"The provided source file: {path} can not be found."
1425            )
1426        file_handle = open(path, "rb")
1427        files = [("files", file_handle)]
1428        base_url = self.api_endpoint
1429        workspace_id = self.get_current_workspace().id()
1430        common_headers = {
1431            "user_id": self.auth.user_id(),
1432            "user_email": self.auth.user_email(),
1433            "User-Agent": _user_agent,
1434        }
1435        curry_post = partial(
1436            requests.post,
1437            auth=self.auth,
1438            headers=common_headers,
1439            files=files,
1440        )
1441        model_id: Union[int, None] = None
1442
1443        def _handle_response(http_response) -> int:
1444            http_response.raise_for_status()
1445            response_record = http_response.json()
1446            if response_record is not None and "model_id" in response_record:
1447                return int(response_record["model_id"])
1448            else:
1449                raise ModelConversionFailure("Failed to convert keras model")
1450
1451        try:
1452            if source_type == ModelConversionSource.KERAS:
1453                assert (
1454                    type(conversion_arguments).__name__
1455                    == ConvertKerasArguments.__name__
1456                )
1457                ## This is not optimal but api-lb(envoy) -> python-api
1458                ## Gave a 50(3|4|2|0) on several other options
1459                ## dimensions being of type List[Any]
1460                ## dimensions being a base64 encoded json array
1461                ## dimensions being a comma seperated string that was proccessed server side.
1462                ## This casts the ConvertKerasArguments dict structure to a json string.
1463                ## Base64 encoding via a utf-8 binary conversion, passed as the url
1464                ## parameter config. Which is handlded server side.
1465                data = {
1466                    **conversion_arguments.to_dict(),
1467                    "workspace_id": workspace_id,
1468                }
1469                params = {
1470                    "config": base64.b64encode(json.dumps(data).encode("utf-8")).decode(
1471                        "utf-8"
1472                    )
1473                }
1474                response = curry_post(url=f"{base_url}/v1/convert/keras", params=params)
1475                model_id = _handle_response(response)
1476            elif source_type == ModelConversionSource.SKLEARN:
1477                assert (
1478                    type(conversion_arguments).__name__
1479                    == ConvertSKLearnArguments.__name__
1480                )
1481                response = curry_post(
1482                    url=f"{base_url}/v1/convert/sklearn",
1483                    params={
1484                        **conversion_arguments.to_dict(),
1485                        "workspace_id": workspace_id,
1486                    },
1487                )
1488                model_id = _handle_response(response)
1489
1490            elif source_type == ModelConversionSource.XGBOOST:
1491                assert (
1492                    type(conversion_arguments).__name__ == ConvertXGBoostArgs.__name__
1493                )
1494                response = curry_post(
1495                    url=f"{base_url}/v1/convert/xgboost",
1496                    params={
1497                        **conversion_arguments.to_dict(),
1498                        "workspace_id": workspace_id,
1499                    },
1500                )
1501                model_id = _handle_response(response)
1502            else:
1503                raise ModelConversionUnsupportedType(
1504                    f"Unsupported model source type of {source_type} passed."
1505                )
1506            if model_id is None:
1507                raise ModelConversionFailure("Failed to retrieve final model id")
1508            return Model(self, {"id": model_id})
1509        except Exception:
1510            raise ModelConversionGenericException(
1511                "This model type could not be deployed successfully. Please contact your Wallaroo support team at community@wallaroo.ai"
1512            )
1513        finally:
1514            file_handle.close()
1515
1516    def _post_rest_api_json(self, uri: str, body: dict):
1517        result = self._post_rest_api(uri, body)
1518        if result.status_code == 200:
1519            return result.json()
1520        else:
1521            raise Exception(f"{result.status_code}: {result.text}")
1522
1523    def _setup_mlops_client(self) -> "AuthenticatedClient":
1524        self._mlops = AuthenticatedClient(
1525            base_url=self.api_endpoint, token=self.auth._access_token().token
1526        )
1527        return self._mlops
1528
1529    def mlops(self) -> "AuthenticatedClient":
1530        return self._setup_mlops_client()

Client handle to a Wallaroo platform instance.

Objects of this class serve as the entrypoint to Wallaroo platform functionality.

Client( api_endpoint: str = 'http://api-lb:8080', auth_endpoint: str = '', request_timeout: int = 45, auth_type: Optional[str] = None, gql_client: Optional[gql.client.Client] = None, pg_connection_string: str = 'dbname=postgres user=postgres password=password host=postgres port=5432', interactive: Optional[bool] = None, time_format: str = '%Y-%d-%b %H:%M:%S')
    def __init__(
        self,
        api_endpoint: str = "http://api-lb:8080",
        auth_endpoint: str = "",
        request_timeout: int = 45,
        auth_type: Optional[str] = None,
        gql_client: Optional[gql.Client] = None,
        pg_connection_string: str = "dbname=postgres user=postgres password=password host=postgres port=5432",
        interactive: Optional[bool] = None,
        time_format: str = "%Y-%d-%b %H:%M:%S",
    ):
        """Create a Client handle.

        The auth type and endpoints are first resolved with `Client.get_urls`,
        so environment variables may supply or override them.

        :param str api_endpoint: Host/port of the platform API endpoint
        :param str auth_endpoint: Host/port of the platform Keycloak instance
        :param int request_timeout: Max timeout of web requests, in seconds
            (used for the GraphQL transport and stored as ``self.timeout``)
        :param str auth_type: Authentication type to use. Can be one of: "none",
            "sso", "user_password".
        :param gql.Client gql_client: Pre-built GraphQL client to use. When not
            provided (or falsy), one is created over ``RequestsHTTPTransport``
            at ``<api_endpoint>/v1/graphql`` with schema fetching enabled.
        :param str pg_connection_string: Postgres connection string
        :param bool interactive: If provided and True, some calls will print
            additional human information, or won't when False. If not provided,
            interactive defaults to True when the ``JUPYTER_SVC_SERVICE_HOST``
            or ``JUPYTERHUB_HOST`` environment variable is set, and False
            otherwise.
        :param str time_format: Preferred `strftime` format string for displaying timestamps in a human context.
        """

        # Resolve arguments against environment variables before using them.
        auth_type, api_endpoint, auth_endpoint = Client.get_urls(
            auth_type, api_endpoint, auth_endpoint
        )

        self.auth = auth.create(auth_endpoint, auth_type)

        if gql_client:
            self._gql_client = gql_client
        else:
            gql_transport = RequestsHTTPTransport(
                url=posixpath.join(api_endpoint, "v1/graphql"),
                auth=self.auth,
                timeout=request_timeout,
            )
            self._gql_client = gql.Client(
                transport=gql_transport, fetch_schema_from_transport=True
            )

        self.api_endpoint = api_endpoint

        self.auth_endpoint = auth_endpoint

        self.timeout = request_timeout

        # _setup_mlops_client reads self.api_endpoint and self.auth, so it must
        # run after both have been assigned above.
        self._setup_mlops_client()

        self.pg_connection_string = pg_connection_string

        # Lazily resolved by get_current_workspace().
        self._current_workspace: Optional[Workspace] = None

        # TODO: debate the names of these things
        self._default_ws_name: Optional[str] = None

        user_email = self.auth.user_email()
        if user_email is not None:
            self._default_ws_name = user_email + "_ws"

        # An explicit argument wins; otherwise detect Jupyter via env vars.
        if interactive is not None:
            self._interactive = interactive
        elif (
            "JUPYTER_SVC_SERVICE_HOST" in os.environ or "JUPYTERHUB_HOST" in os.environ
        ):
            self._interactive = True
        else:
            self._interactive = False

        self._time_format = time_format

Create a Client handle.

Parameters
  • str api_endpoint: Host/port of the platform API endpoint
  • str auth_endpoint: Host/port of the platform Keycloak instance
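  • int request_timeout: Max timeout of web requests, in seconds
  • gql.Client gql_client: Optional pre-built GraphQL client; when not provided, one is created for the API endpoint's v1/graphql path.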
  • str auth_type: Authentication type to use. Can be one of: "none", "sso", "user_password".
  • str pg_connection_string: Postgres connection string
  • bool interactive: If provided and True, some calls will print additional human information, or won't when False. If not provided, interactive defaults to True if running inside Jupyter and False otherwise.
  • str time_format: Preferred strftime format string for displaying timestamps in a human context.
@staticmethod
def get_urls( auth_type: Optional[str], api_endpoint: str, auth_endpoint: str) -> Tuple[Optional[str], str, str]:
101    @staticmethod
102    def get_urls(
103        auth_type: Optional[str], api_endpoint: str, auth_endpoint: str
104    ) -> Tuple[Optional[str], str, str]:
105        """Method to calculate the auth values specified as defaults,
106        as params or in ENV vars.
107        Made static to be testable without reaching out to SSO, etc."""
108
109        if auth_type is None:
110            auth_type = os.environ.get(WALLAROO_SDK_AUTH_TYPE, None)
111
112        # ideally we'd set auth_endpoint to None default value but that would
113        # make the auth_endpoint type to be Optiona[str] which messes up
114        # a lot| of type hinting and I wanted to make minimal changes without a
115        # lot of 'warnings'.
116        if len(auth_endpoint.strip()) == 0:
117            auth_endpoint = (
118                os.environ.get(WALLAROO_AUTH_URL)
119                or os.environ.get(WALLAROO_SDK_AUTH_ENDPOINT)
120                or "http://api-lb:8080"
121            )
122
123        api_endpoint = os.environ.get(WALLAROO_URL, api_endpoint)
124
125        return auth_type, api_endpoint, auth_endpoint

Method to calculate the auth values specified as defaults, as params or in ENV vars. Made static to be testable without reaching out to SSO, etc.

def list_tags(self) -> wallaroo.tag.Tags:
217    def list_tags(self) -> Tags:
218        """List all tags on the platform.
219
220        :return: A list of all tags on the platform.
221        :rtype: List[Tag]
222        """
223        res = self._gql_client.execute(
224            gql.gql(
225                """
226            query ListTags {
227              tag(order_by: {id: desc}) {
228                id
229                tag
230                model_tags {
231                  model {
232                    id
233                    model_id
234                    models_pk_id
235                    model_version
236                    
237                  }
238                }
239                pipeline_tags {
240                  pipeline {
241                    id
242                    pipeline_id
243                    pipeline_versions {
244                        id
245                        version
246                    }
247                  }
248                }
249              }
250            }
251
252
253            """
254            )
255        )
256        return Tags([Tag(client=self, data={"id": p["id"]}) for p in res["tag"]])

List all tags on the platform.

Returns

A list of all tags on the platform.

def list_models(self) -> wallaroo.models.ModelsList:
def list_models(self) -> ModelsList:
    """List the models of the current workspace.

    The current workspace's id is sent as the body of the models/list request.

    :return: The models returned for the current workspace.
    :rtype: ModelsList
    :raises Exception: If the request returns nothing or a non-200 response.
    """
    # Named explicitly so the builtin ``id`` is not shadowed.
    workspace_id = self.get_current_workspace().id()
    res = models_list.sync(
        client=self.mlops(),
        json_body=models_list_json_body.ModelsListJsonBody(workspace_id),
    )

    if res is None:
        raise Exception("Failed to list models")

    # Any non-200 response object carries an error message in ``msg``.
    if not isinstance(res, ModelsListResponse200):
        raise Exception(res.msg)

    return ModelsList([Models(client=self, data=v.to_dict()) for v in res.models])

List all models on the platform.

Returns

A list of all models on the platform.

def list_deployments(self) -> List[wallaroo.deployment.Deployment]:
def list_deployments(self) -> List[Deployment]:
    """Return every deployment on the platform, deployed or not.

    Each row of the GraphQL result (id, deploy_id, deployed flag and model
    config ids) is handed to :class:`Deployment` as its data.

    :return: A list of all deployments on the platform.
    :rtype: List[Deployment]
    """
    deployment_query = gql.gql(
        """
            query ListDeployments {
              deployment {
                id
                deploy_id
                deployed
                deployment_model_configs {
                  model_config {
                    id
                  }
                }
              }
            }
            """
    )
    rows = self._gql_client.execute(deployment_query)["deployment"]
    return [Deployment(client=self, data=row) for row in rows]

List all deployments (active or not) on the platform.

Returns

A list of all deployments on the platform.

def search_pipelines( self, search_term: Optional[str] = None, deployed: Optional[bool] = None, created_start: Optional[wallaroo.client.Datetime] = None, created_end: Optional[wallaroo.client.Datetime] = None, updated_start: Optional[wallaroo.client.Datetime] = None, updated_end: Optional[wallaroo.client.Datetime] = None) -> wallaroo.pipeline_variant.PipelineVariants:
def search_pipelines(
    self,
    search_term: Optional[str] = None,
    deployed: Optional[bool] = None,
    created_start: Optional["Datetime"] = None,
    created_end: Optional["Datetime"] = None,
    updated_start: Optional["Datetime"] = None,
    updated_end: Optional["Datetime"] = None,
) -> PipelineVariants:
    """Search for pipelines. All parameters are optional, in which case the result is the same as
    `list_pipelines()`. All times are strings to be parsed by `datetime.fromisoformat`. Example:

         myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo")

    :param str search_term: Will be matched against tags and model names. Example: "footag123".
    :param bool deployed: Pipeline was deployed or not
    :param str created_start: Pipeline was created at or after this time
    :param str created_end: Pipeline was created at or before this time
    :param str updated_start: Pipeline was updated at or after this time
    :param str updated_end: Pipeline was updated at or before this time

    :return: The pipeline variants matching the search criteria.
    :rtype: PipelineVariants
    """
    return PipelineVariants(
        self._search_pipelines(
            search_term,
            deployed,
            # NOTE(review): third positional argument is always None here —
            # presumably an optional owner/user filter; confirm against
            # _search_pipelines.
            None,
            created_start,
            created_end,
            updated_start,
            updated_end,
        )
    )

Search for pipelines. All parameters are optional, in which case the result is the same as list_pipelines(). All times are strings to be parsed by datetime.fromisoformat. Example:

 myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo")
Parameters
  • str search_term: Will be matched against tags and model names. Example: "footag123".
  • bool deployed: Pipeline was deployed or not
  • str created_start: Pipeline was created at or after this time
  • str created_end: Pipeline was created at or before this time
  • str updated_start: Pipeline was updated at or after this time
  • str updated_end: Pipeline was updated at or before this time
Returns

A list of all pipelines on the platform.

def search_my_models( self, search_term: Optional[str] = None, uploaded_time_start: Optional[wallaroo.client.Datetime] = None, uploaded_time_end: Optional[wallaroo.client.Datetime] = None) -> wallaroo.model.ModelVersions:
def search_my_models(
    self,
    search_term: Optional[str] = None,
    uploaded_time_start: Optional["Datetime"] = None,
    uploaded_time_end: Optional["Datetime"] = None,
) -> ModelVersions:
    """Search models owned by the currently authenticated user.

    params:
         search_term: Searches the following metadata: names, shas, versions, file names, and tags
         uploaded_time_start: Inclusive time of upload
         uploaded_time_end: Inclusive time of upload
    """
    # Restrict the generic model search to the caller's own user id.
    owner_id = self.auth.user_id()
    matches = self._search_models(
        search_term=search_term,
        user_id=owner_id,
        start=uploaded_time_start,
        end=uploaded_time_end,
    )
    return ModelVersions(matches)

Search models owned by you. Parameters: search_term: searches the following metadata: names, shas, versions, file names, and tags; uploaded_time_start: inclusive start of the upload-time window; uploaded_time_end: inclusive end of the upload-time window.

def search_models( self, search_term: Optional[str] = None, uploaded_time_start: Optional[wallaroo.client.Datetime] = None, uploaded_time_end: Optional[wallaroo.client.Datetime] = None) -> wallaroo.model.ModelVersions:
def search_models(
    self,
    search_term: Optional[str] = None,
    uploaded_time_start: Optional["Datetime"] = None,
    uploaded_time_end: Optional["Datetime"] = None,
) -> ModelVersions:
    """Search all models the current user has access to.

    params:
         search_term: Searches the following metadata: names, shas, versions, file names, and tags
         uploaded_time_start: Inclusive time of upload
         uploaded_time_end: Inclusive time of upload
    """
    # Unlike search_my_models, no user_id filter is passed.
    matches = self._search_models(
        search_term=search_term,
        start=uploaded_time_start,
        end=uploaded_time_end,
    )
    return ModelVersions(matches)

Search all models you have access to. Parameters: search_term: searches the following metadata: names, shas, versions, file names, and tags; uploaded_time_start: inclusive start of the upload-time window; uploaded_time_end: inclusive end of the upload-time window.

def get_user_by_email(self, email: str) -> Optional[wallaroo.user.User]:
def get_user_by_email(self, email: str) -> Optional[User]:
    """Find a user by email.

    Queries the Keycloak admin users endpoint of the ``master`` realm. That
    endpoint's ``email`` filter matches any user whose email *contains* the
    given string, so the results are narrowed here to the user whose email is
    exactly ``email`` (compared case-insensitively).

    :param str email: The email address to look up.
    :return: The matching User, or None if no user has exactly that email.
    :rtype: Optional[User]
    :raises Exception: If the lookup request does not return HTTP 200.
    """
    assert email is not None
    escaped_email = quote_plus(email)
    url = (
        f"{self.auth_endpoint}/auth/admin/realms/master/users?email={escaped_email}"
    )
    headers = {
        "Content-Type": "application/json",
        "Authorization": self.auth._bearer_token_str(),
        "User-Agent": _user_agent,
    }
    resp = requests.get(url, headers=headers, data={})
    if resp.status_code != 200:
        # Previously an error payload fell through to ``jresp[0]`` and
        # surfaced as an unrelated KeyError.
        raise Exception(
            f"Failed to look up user by email: {resp.status_code} {resp.text}"
        )

    wanted = email.lower()
    for user_rep in resp.json():
        # A substring match (e.g. "bob@x.com" vs "alice.bob@x.com") must not be
        # returned: deactivate_user/activate_user act on whatever comes back.
        if (user_rep.get("email") or "").lower() == wanted:
            return User(client=self, data=user_rep)
    return None

Find a user by email

def deactivate_user(self, email: str) -> None:
def deactivate_user(self, email: str) -> None:
    """Deactivates an existing user of the platform

    Deactivated users cannot log into the platform.
    Deactivated users do not count towards the number of allotted user seats from the license.

    The Models and Pipelines owned by the deactivated user are not removed from the platform.

    :param str email: The email address of the user to deactivate.

    :return: None
    :rtype: None
    :raises Exception: If the caller tries to deactivate themselves or the
        ``admin`` user.
    :raises EntityNotFoundError: If no user has this email, or the user's
        record could not be read or updated.
    """

    if self.auth.user_email() == email:
        raise Exception("A user may not deactivate themselves.")

    user = self.get_user_by_email(email)

    if user is None:
        raise EntityNotFoundError("User", {"email": email})

    if user.username() == "admin":
        raise Exception("Admin user may not be deactivated.")

    url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"
    headers = {
        "Content-Type": "application/json",
        "Authorization": self.auth._bearer_token_str(),
        "User-Agent": _user_agent,
    }

    # Get the current full user representation to return in the mutation due to keycloak bug
    get_user_response = requests.get(url, headers=headers, data={})
    if get_user_response.status_code != 200:
        # Don't PUT an error payload back as though it were the user record.
        raise EntityNotFoundError("User", {"email": email})

    cur_user_rep = get_user_response.json()
    cur_user_rep["enabled"] = False

    resp = requests.put(url, headers=headers, json=cur_user_rep)

    # The update is considered successful only on 204 No Content.
    if resp.status_code != 204:
        raise EntityNotFoundError("User", {"email": email})
    return None

Deactivates an existing user of the platform

Deactivated users cannot log into the platform. Deactivated users do not count towards the number of allotted user seats from the license.

The Models and Pipelines owned by the deactivated user are not removed from the platform.

Parameters
  • str email: The email address of the user to deactivate.
Returns

None

def activate_user(self, email: str) -> None:
def activate_user(self, email: str) -> None:
    """Activates an existing user of the platform that had been previously deactivated.

    Activated users can log into the platform.

    :param str email: The email address of the user to activate.

    :return: None
    :rtype: None
    :raises EntityNotFoundError: If no user has this email, or the user's
        record could not be read or updated.
    """
    user = self.get_user_by_email(email)

    if user is None:
        raise EntityNotFoundError("User", {"email": email})

    url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"

    headers = {
        "Content-Type": "application/json",
        "Authorization": self.auth._bearer_token_str(),
        "User-Agent": _user_agent,
    }

    # Get the current full user representation to return in the mutation due to keycloak bug
    get_user_response = requests.get(url, headers=headers, data={})
    if get_user_response.status_code != 200:
        # Don't PUT an error payload back as though it were the user record.
        raise EntityNotFoundError("User", {"email": email})

    cur_user_rep = get_user_response.json()
    cur_user_rep["enabled"] = True

    resp = requests.put(url, headers=headers, json=cur_user_rep)

    # The update is considered successful only on 204 No Content.
    if resp.status_code != 204:
        raise EntityNotFoundError("User", {"email": email})
    return None

Activates an existing user of the platform that had been previously deactivated.

Activated users can log into the platform.

Parameters
  • str email: The email address of the user to activate.
Returns

None

def list_users(self) -> List[wallaroo.user.User]:
def list_users(self) -> List[User]:
    """Return every User on the platform.

    The raw user records come from ``User.list_users`` using this client's
    auth, and each is wrapped in a :class:`User` bound to this client.

    :return: A list of all Users on the platform.
    :rtype: List[User]
    """
    raw_users = User.list_users(auth=self.auth)
    return [User(client=self, data=raw) for raw in raw_users]

List of all Users on the platform

Returns

A list of all Users on the platform.

def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> wallaroo.model.Model:
def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> Model:
    """Upload a model defined by a file as a new model variant.

    :param str name: The name of the model of which this is a variant.
        Names must be ASCII alpha-numeric characters or dash (-) only.
    :param Union[str, pathlib.Path] path: Path of the model file to upload.
    :return: The created Model.
    :rtype: Model
    """
    # pathlib.Path() accepts both str and Path, so normalise unconditionally.
    model_path = pathlib.Path(path)
    # The file's base name is sent as the upload filename; the file handle is
    # closed once the upload call returns.
    with model_path.open("rb") as f:
        return self._upload_model_stream(name, {"filename": model_path.name}, f)

Upload a model defined by a file as a new model variant.

Parameters
  • str name: The name of the model of which this is a variant. Names must be ASCII alpha-numeric characters or dash (-) only.
  • Union[str, pathlib.Path] path: Path of the model file to upload.
Returns

The created Model.

def register_model_image(self, name: str, image: str) -> wallaroo.model.Model:
def register_model_image(self, name: str, image: str) -> Model:
    """Registers an MLFlow model as a new model.

    :param str name: The name of the model of which this is a variant.
        Names must be ASCII alpha-numeric characters or dash (-) only.
    :param str image: Image name of the MLFlow model to register.
    :return: The created Model.
    :rtype: Model
    """
    # The model is registered by image reference only; no file is read here.
    data = {"image_path": image}
    return self._upload_model(name, data)

Registers an MLFlow model as a new model.

Parameters
  • str name: The name of the model of which this is a variant. Names must be ASCII alpha-numeric characters or dash (-) only.
  • str image: Image name of the MLFlow model to register.
Returns

The created Model.

def model_by_name(self, model_class: str, model_name: str) -> wallaroo.model.Model:
def model_by_name(self, model_class: str, model_name: str) -> Model:
    """Look up a single Model by its class and variant name.

    ``model_class`` is matched against the ``model_id`` column and
    ``model_name`` against ``model_version``; the first match is returned.

    :param str model_class: Name of the model class.
    :param str model_name: Name of the variant within the specified model class.
    :return: The Model with the corresponding model and variant name.
    :rtype: Model
    :raises EntityNotFoundError: If no model matches both names.
    """
    lookup = gql.gql(
        """
                query ModelByName($model_id: String!, $model_version: String!) {
                  model(where: {_and: [{model_id: {_eq: $model_id}}, {model_version: {_eq: $model_version}}]}) {
                    id
                    model_id
                    model_version
                  }
                }
                """
    )
    variables = {
        "model_id": model_class,
        "model_version": model_name,
    }
    matches = self._gql_client.execute(lookup, variable_values=variables)["model"]
    if not matches:
        raise EntityNotFoundError(
            "Model", {"model_class": model_class, "model_name": model_name}
        )
    first_id = matches[0]["id"]
    return Model(client=self, data={"id": first_id})

Fetch a Model by name.

Parameters
  • str model_class: Name of the model class.
  • str model_name: Name of the variant within the specified model class.
Returns

The Model with the corresponding model and variant name.

def deployment_by_name(self, deployment_name: str) -> wallaroo.deployment.Deployment:
def deployment_by_name(self, deployment_name: str) -> Deployment:
    """Look up a Deployment by its ``deploy_id``.

    When several deployments share the name, the first row returned by the
    query is used.

    :param str deployment_name: Name of the deployment.
    :return: The Deployment with the corresponding name.
    :rtype: Deployment
    :raises EntityNotFoundError: If no deployment has this name.
    """
    lookup = gql.gql(
        """
                query DeploymentByName($deployment_name: String!) {
                  deployment(where: {deploy_id: {_eq: $deployment_name}}) {
                    id
                  }
                }
                """
    )
    variables = {"deployment_name": deployment_name}
    matches = self._gql_client.execute(lookup, variable_values=variables)[
        "deployment"
    ]
    if not matches:
        raise EntityNotFoundError("Deployment", {"deployment_name": deployment_name})
    first_id = matches[0]["id"]
    return Deployment(client=self, data={"id": first_id})

Fetch a Deployment by name.

Parameters
  • str deployment_name: Name of the deployment.
Returns

The Deployment with the corresponding name.

def pipelines_by_name(self, pipeline_name: str) -> List[wallaroo.pipeline.Pipeline]:
def pipelines_by_name(self, pipeline_name: str) -> List[Pipeline]:
    """Fetch Pipelines by name.

    :param str pipeline_name: Name (``pipeline_id``) of the pipeline.
    :return: Every Pipeline with that name, most recently created first.
    :rtype: List[Pipeline]
    :raises EntityNotFoundError: If no pipeline has this name.
    """
    res = self._gql_client.execute(
        gql.gql(
            """
                query PipelineByName($pipeline_name: String!) {
                  pipeline(where: {pipeline_id: {_eq: $pipeline_name}}, order_by: {created_at: desc}) {
                    id
                  }
                }
                """
        ),
        variable_values={
            "pipeline_name": pipeline_name,
        },
    )
    assert "pipeline" in res
    pipelines = res["pipeline"]
    if not pipelines:
        raise EntityNotFoundError("Pipeline", {"pipeline_name": pipeline_name})
    return [Pipeline(client=self, data={"id": p["id"]}) for p in pipelines]

Fetch Pipelines by name.

Parameters
  • str pipeline_name: Name of the pipeline.
Returns

A list of all Pipelines with the corresponding name, most recently created first.

def list_pipelines(self) -> List[wallaroo.pipeline.Pipeline]:
def list_pipelines(self) -> List[Pipeline]:
    """Return every pipeline on the platform, highest id first.

    Each row (id plus its tags) is handed to :class:`Pipeline` as its data,
    and the results are wrapped in a :class:`Pipelines` collection.

    :return: A list of all pipelines on the platform.
    :rtype: List[Pipeline]
    """
    pipeline_query = gql.gql(
        """
            query ListPipelines {
              pipeline(order_by: {id: desc}) {
                id
                pipeline_tags {
                  tag {
                    id
                    tag
                  }
                }
              }
            }
            """
    )
    rows = self._gql_client.execute(pipeline_query)["pipeline"]
    wrapped = [Pipeline(client=self, data=row) for row in rows]
    return Pipelines(wrapped)

List all pipelines on the platform.

Returns

A list of all pipelines on the platform.

def build_pipeline(self, pipeline_name: str) -> wallaroo.pipeline.Pipeline:
def build_pipeline(self, pipeline_name: str) -> "Pipeline":
    """Create a pipeline named `pipeline_name` in the current workspace.

    The pipeline is created through the pipelines/create API with an empty
    definition, and the new Pipeline is returned.

    :param str pipeline_name: Name of the pipeline, must be composed of ASCII
        alpha-numeric characters plus dash (-).
    :return: The created Pipeline.
    :rtype: Pipeline
    :raises Exception: If the create request returns nothing or a non-200
        response. Names that fail ``require_dns_compliance`` are rejected
        before any request is made.
    """

    require_dns_compliance(pipeline_name)

    # TODO: Needs to handle visibility?
    data = pipelines_create.sync(
        client=self.mlops(),
        json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
            pipeline_name,
            self.get_current_workspace().id(),
            pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
                {}
            ),
        ),
    )

    if data is None:
        raise Exception("Failed to create pipeline")

    # Any non-200 response object carries an error message in ``msg``.
    if not isinstance(data, PipelinesCreateResponse200):
        raise Exception(data.msg)

    return Pipeline(client=self, data={"id": data.pipeline_pk_id})

Creates a pipeline with the given pipeline_name in the current workspace and returns the new Pipeline.

When completed, the pipeline can be uploaded with .upload()

Parameters
  • pipeline_name string: Name of the pipeline, must be composed of ASCII alpha-numeric characters plus dash (-).
def create_value_split_experiment( self, name: str, meta_key: str, default_model: wallaroo.model_config.ModelConfig, challenger_models: List[Tuple[Any, wallaroo.model_config.ModelConfig]]) -> wallaroo.pipeline.Pipeline:
 994    def create_value_split_experiment(
 995        self,
 996        name: str,
 997        meta_key: str,
 998        default_model: ModelConfig,
 999        challenger_models: List[Tuple[Any, ModelConfig]],
1000    ) -> Pipeline:
1001        """Creates a new PipelineVariant of a "value-split experiment" type.
1002        :param str name: Name of the Pipeline
1003        :param meta_key str: Inference input key on which to redirect inputs to
1004            experiment models.
1005        :param default_model ModelConfig: Model to send inferences by default.
1006        :param challenger_models List[Tuple[Any, ModelConfig]]: A list of
1007            meta_key values -> Models to send inferences. If the inference data
1008            referred to by meta_key is equal to one of the keys in this tuple,
1009            that inference is redirected to the corresponding model instead of
1010            the default model.
1011        """
1012        args = [meta_key, default_model.model().name()]
1013        for v, m in challenger_models:
1014            args.append(v)
1015            args.append(m.model().name())
1016        step = {
1017            "id": "metavalue_split",
1018            "operation": "map",
1019            "args": args,
1020        }
1021        definition = {"id": name, "steps": [step]}
1022        # TODO: This seems like a one-to-one replace, find appropriate test.
1023        data = self._gql_client.execute(
1024            gql.gql(
1025                """
1026            mutation CreatePipeline(
1027                $pipeline_id: String,
1028                $version: String,
1029                $definition: jsonb,
1030                $workspace_id: bigint
1031            ) {
1032                insert_pipeline(
1033                    objects: {
1034                    pipeline_versions: {
1035                        data: { definition: $definition }
1036                    }
1037                    pipeline_id: $pipeline_id
1038                    }
1039                ) {
1040                    returning {
1041                        id
1042                    }
1043                }
1044            }
1045            """
1046            ),
1047            variable_values={
1048                "pipeline_id": name,
1049                "definition": definition,
1050                "workspace_id": self.get_current_workspace().id(),
1051            },
1052        )
1053        return Pipeline(
1054            client=self,
1055            data=data["insert_pipeline"]["returning"][0],
1056        )

Creates a new PipelineVariant of a "value-split experiment" type.

Parameters
  • str name: Name of the Pipeline
  • meta_key str: Inference input key on which to redirect inputs to experiment models.
  • default_model ModelConfig: Model to send inferences by default.
  • challenger_models List[Tuple[Any, ModelConfig]]: A list of meta_key values -> Models to send inferences. If the inference data referred to by meta_key is equal to one of the keys in this tuple, that inference is redirected to the corresponding model instead of the default model.
def get_logs( self, topic: str, limit: int = 100) -> Tuple[wallaroo.logs.LogEntries, str]:
def get_logs(self, topic: str, limit: int = 100) -> Tuple[LogEntries, str]:
    """Fetch recent inference log entries for a topic.

    ``limit`` is split evenly over the topic's partitions: each partition is
    read from ``floor(limit / n_partitions)`` records before its end (clamped
    at 0), and a single page of records is requested.

    :param str topic: The name of the topic to query.
    :param int limit: Maximum number of records to request.
    :return: The parsed LogEntries and the status string reported by the
        server (the string "None" when the response has no status).
    :rtype: Tuple[LogEntries, str]
    :raises EntityNotFoundError: If the partitions or records request does not
        return HTTP 200.
    """
    base = self.api_endpoint + "/v1/logs/topic/" + topic
    headers = {"User-Agent": _user_agent}
    resp = requests.get(base, auth=self.auth, headers=headers)
    if resp.status_code != 200:
        raise EntityNotFoundError(
            f"Could not get partitions {resp.text}", {"url": base}
        )
    partitions = resp.json()["partitions"]

    iterator = {
        k: max(0, span["end"] - math.floor(limit / len(partitions)))
        for k, span in partitions.items()
    }

    response = requests.post(
        base + "/records",
        params={"limit": limit},
        json=iterator,
        auth=self.auth,
        headers=headers,
    )
    if response.status_code != 200:
        raise EntityNotFoundError(
            f"Could not get records {response.text}",
            {"url": base, "iterator": str(iterator)},
        )
    result = response.json()

    return (
        LogEntries([LogEntry(json.loads(record)) for record in result["records"]]),
        result.get("status", "None"),
    )
def security_logs(self, limit: int) -> List[dict]:
def security_logs(self, limit: int) -> List[dict]:
    """Not available in this release; always raises.

    :param int limit: Unused.
    :raises RuntimeError: Always.
    """
    message = "security_log() is not available in this release."
    raise RuntimeError(message)

This function is not available in this release

def get_raw_logs( self, topic: str, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, limit: int = 100000, parse: bool = False, verbose: bool = False) -> List[Dict[str, Any]]:
def get_raw_logs(
    self,
    topic: str,
    start: Optional[datetime] = None,
    end: Optional[datetime] = None,
    limit: int = 100_000,
    parse: bool = False,
    verbose: bool = False,
) -> List[Dict[str, Any]]:
    """Gets logs from Plateau for a particular time window without attempting
    to convert them to Inference LogEntries. Logs can be returned as strings
    or the json parsed into lists and dicts.

    Pages through the topic's records, following the server's ``next``
    iterator, until a page comes back empty.

    :param topic str: The name of the topic to query
    :param start Optional[datetime]: The start of the time window. Converted
        to UTC before sending; a naive datetime is treated as local time.
    :param end Optional[datetime]: The end of the time window (same handling
        as ``start``).
    :param limit int: Sent as the per-request ``limit`` and used to pick each
        partition's starting offset; must be <= 1,000,000. Note retrieving
        many records may be a performance bottleneck.
    :param parse bool: Whether to attempt to parse the string as a json object.
    :param verbose bool: Prints out info to help diagnose issues.
    :raises EntityNotFoundError: If the partitions or a records request does
        not return HTTP 200.
    """

    assert limit <= 1_000_000

    base = self.api_endpoint + "/v1/logs/topic/" + topic
    headers = {"User-Agent": _user_agent}
    resp = requests.get(base, auth=self.auth, headers=headers)
    if resp.status_code != 200:
        raise EntityNotFoundError(
            f"Could not get partitions {resp.text}", {"url": base}
        )
    data = resp.json()
    partitions = data["partitions"]

    if verbose:
        print(f"Got partitions {partitions}")

    params: Dict[str, Any] = {"limit": limit}
    if start is not None:
        start_str = start.astimezone(tz=timezone.utc).isoformat()
        params["time.start"] = start_str
    if end is not None:
        end_str = end.astimezone(tz=timezone.utc).isoformat()
        params["time.end"] = end_str

    # Start each partition floor(limit / n_partitions) records before its end.
    # Named ``iterator`` rather than ``next`` to avoid shadowing the builtin.
    iterator: Optional[Any] = {
        k: max(0, span["end"] - math.floor(limit / len(partitions)))
        for k, span in partitions.items()
    }

    if verbose:
        print("Using params: ", params)

    records = []
    while iterator is not None:
        response = requests.post(
            base + "/records",
            params=params,
            json=iterator,
            auth=self.auth,
            headers=headers,
        )
        if response.status_code != 200:
            # Report the failing records response, not the earlier partitions one.
            raise EntityNotFoundError(
                f"Could not get records {response.text}",
                {"url": base, "params": str(params), "iterator": str(iterator)},
            )

        if verbose:
            print("response: ", response)

        result = response.json()
        result_records = result["records"]
        if result_records:
            records.extend(result_records)
            iterator = result["next"]
        else:
            iterator = None
    if parse:
        return [json.loads(r) for r in records]
    return records

Gets logs from Plateau for a particular time window without attempting to convert them to Inference LogEntries. Logs can be returned as strings or the json parsed into lists and dicts.

Parameters
  • topic str: The name of the topic to query
  • start Optional[datetime]: The start of the time window
  • end Optional[datetime]: The end of the time window
  • limit int: The number of records to retrieve. Note retrieving many records may be a performance bottleneck.
  • parse bool: Whether to attempt to parse the string as a json object.
  • verbose bool: Prints out info to help diagnose issues.
def get_raw_pipeline_inference_logs( self, topic: str, start: datetime.datetime, end: datetime.datetime, model_name: Optional[str] = None, limit: int = 100000, verbose: bool = False) -> List[Dict[str, Any]]:
def get_raw_pipeline_inference_logs(
    self,
    topic: str,
    start: datetime,
    end: datetime,
    model_name: Optional[str] = None,
    limit: int = 100_000,
    verbose: bool = False,
) -> List[Dict[str, Any]]:
    """Gets logs from Plateau for a particular time window and filters them for
    the model specified.

    The logs are fetched (and JSON-parsed) with ``get_raw_logs``, then filtered
    client-side to ``model_name`` (if given) and to the half-open window
    ``[start, end)``.

    :param topic str: The name of the topic to query
    :param start datetime: The start of the time window (inclusive)
    :param end datetime: The end of the time window (exclusive)
    :param model_name Optional[str]: The name of the specific model to filter
        if any
    :param limit int: The number of records to retrieve. Note retrieving many
        records may be a performance bottleneck.
    :param verbose bool: Prints out info to help diagnose issues.
    :return: The parsed log records that pass both filters.
    """
    logs = self.get_raw_logs(
        topic,
        start=start,
        end=end,
        limit=limit,
        parse=True,
        verbose=verbose,
    )

    if verbose:
        print(f"Got {len(logs)} initial logs")

    if model_name:
        logs = [log for log in logs if log["model_name"] == model_name]

    # inference results are a unix timestamp in millis - filter by that
    start_ts = int(start.timestamp() * 1000)
    end_ts = int(end.timestamp() * 1000)
    return [log for log in logs if start_ts <= log["time"] < end_ts]

Gets logs from Plateau for a particular time window and filters them for the model specified.

Parameters
  • topic str: The name of the topic to query
  • start datetime: The start of the time window (inclusive)
  • end datetime: The end of the time window (exclusive)
  • model_name Optional[str]: The name of the specific model to filter if any
  • limit int: The number of records to retrieve. Note retrieving many records may be a performance bottleneck.
  • verbose bool: Prints out info to help diagnose issues.
def get_pipeline_inference_dataframe( self, topic: str, start: datetime.datetime, end: datetime.datetime, model_name: Optional[str] = None, limit: int = 100000, verbose=False) -> pandas.core.frame.DataFrame:
def get_pipeline_inference_dataframe(
    self,
    topic: str,
    start: datetime,
    end: datetime,
    model_name: Optional[str] = None,
    limit: int = 100_000,
    verbose=False,
) -> pd.DataFrame:
    """Fetch a pipeline's inference logs for a time window as a DataFrame.

    The raw logs come from ``get_raw_pipeline_inference_logs`` and are then
    converted with ``inference_logs_to_dataframe``.

    :param topic str: The name of the topic to query
    :param start datetime: The start of the time window
    :param end datetime: The end of the time window
    :param model_name Optional[str]: If given, only logs for this model are kept
    :param limit int: Maximum number of raw records to retrieve; large values
        may be slow
    :param verbose bool: Print diagnostic information while fetching
    :return: The inference logs converted to a DataFrame.
    :rtype: pd.DataFrame
    """
    raw_logs = self.get_raw_pipeline_inference_logs(
        topic, start, end, model_name, limit, verbose
    )
    frame = inference_logs_to_dataframe(raw_logs)
    return frame
def get_assay_results( self, assay_id: int, start: datetime.datetime, end: datetime.datetime) -> wallaroo.assay.AssayAnalysisList:
def get_assay_results(
    self,
    assay_id: int,
    start: datetime,
    end: datetime,
) -> AssayAnalysisList:
    """Gets the assay results for a particular time window, parses them, and returns an
    AssayAnalysisList of AssayAnalysis.

    :param assay_id int: The id of the assay we are looking for.
    :param start datetime: The start of the time window
    :param end datetime: The end of the time window
    :return: One AssayAnalysis per result item returned by the API.
    :rtype: AssayAnalysisList
    :raises Exception: If the API returns nothing, returns a non-list error
        response (its ``msg`` becomes the message), or returns any item that
        is not an ``AssaysGetAssayResultsResponse200Item``.
    """
    res = assays_get_assay_results.sync(
        client=self.mlops(),
        json_body=assays_get_assay_results_json_body.AssaysGetAssayResultsJsonBody(
            assay_id, start, end
        ),
    )

    if res is None:
        raise Exception("Failed to get assay results")

    if not isinstance(res, list):
        raise Exception(res.msg)

    # Validate every item, not just the first, before calling .to_dict() on it.
    if not all(isinstance(item, AssaysGetAssayResultsResponse200Item) for item in res):
        raise Exception("invalid response")

    return AssayAnalysisList([AssayAnalysis(item.to_dict()) for item in res])

Gets the assay results for a particular time window, parses them, and returns an AssayAnalysisList of AssayAnalysis.

Parameters
  • assay_id int: The id of the assay we are looking for.
  • start datetime: The start of the time window
  • end datetime: The end of the time window
def build_assay( self, assay_name: str, pipeline: wallaroo.pipeline.Pipeline, model_name: str, baseline_start: datetime.datetime, baseline_end: datetime.datetime) -> wallaroo.assay_config.AssayBuilder:
def build_assay(
    self,
    assay_name: str,
    pipeline: Pipeline,
    model_name: str,
    baseline_start: datetime,
    baseline_end: datetime,
) -> AssayBuilder:
    """Creates an AssayBuilder that can be used to configure and create
    Assays.

    :param assay_name str: Human friendly name for the assay
    :param pipeline Pipeline: The pipeline this assay will work on
    :param model_name str: The model that this assay will monitor
    :param baseline_start datetime: The start time for the inferences to
        use as the baseline
    :param baseline_end datetime: The end time of the baseline window.
        Windows start immediately after the baseline window and are run at
        regular intervals continuously until the assay is deactivated or
        deleted.
    :return: A builder seeded with this client, the pipeline's id and name,
        the model name, and the baseline window.
    :rtype: AssayBuilder
    """
    return AssayBuilder(
        self,
        assay_name,
        pipeline.id(),
        pipeline.name(),
        model_name,
        baseline_start,
        baseline_end,
    )

Creates an AssayBuilder that can be used to configure and create Assays.

Parameters
  • assay_name str: Human friendly name for the assay
  • pipeline Pipeline: The pipeline this assay will work on
  • model_name str: The model that this assay will monitor
  • baseline_start datetime: The start time for the inferences to use as the baseline
  • baseline_end datetime: The end time of the baseline window. Windows start immediately after the baseline window and are run at regular intervals continuously until the assay is deactivated or deleted.
def upload_assay(self, config: wallaroo.assay_config.AssayConfig) -> int:
def upload_assay(self, config: AssayConfig) -> int:
    """Creates an assay in the database.

    The config is serialized to JSON, stamped with the current UTC time as
    ``created_at``, and submitted through ``assays_create``.

    :param config AssayConfig: The configuration for the assay to create.
    :return: The identifier for the assay that was created.
    :rtype: int
    :raises Exception: If the API returns nothing, or returns a response that
        is not ``AssaysCreateResponse200`` (its ``msg`` becomes the message).
    """
    data = assays_create.sync(
        client=self.mlops(),
        json_body=AssaysCreateJsonBody.from_dict(
            {
                **json.loads(config.to_json()),
                "created_at": datetime.now(timezone.utc).isoformat(),
            }
        ),
    )

    if data is None:
        raise Exception("Failed to create assay")

    if not isinstance(data, AssaysCreateResponse200):
        raise Exception(data.msg)

    return data.assay_id

Creates an assay in the database.

Parameters
  • config AssayConfig: The configuration for the assay to create.
Returns

The identifier (an int) for the assay that was created.

def list_assays(self) -> List[wallaroo.assay.Assay]:
def list_assays(self) -> List[Assay]:
    """Return every assay on the platform.

    :return: All assays, each wrapped in an ``Assay`` bound to this client.
    :rtype: List[Assay]
    :raises Exception: If the API returns nothing, or returns a non-list
        error response (its ``msg`` becomes the message).
    """
    response = assays_list.sync(
        client=self.mlops(),
        json_body=AssaysListJsonBody(UNSET),
    )
    if response is None:
        raise Exception("Failed to get assays")
    if isinstance(response, List):
        wrapped = [Assay(client=self, data=item.to_dict()) for item in response]
        return Assays(wrapped)
    raise Exception(response.msg)

List all assays on the platform.

Returns

A list of all assays on the platform.

def create_tag(self, tag_text: str) -> wallaroo.tag.Tag:
def create_tag(self, tag_text: str) -> Tag:
    """Create a new tag whose text is ``tag_text``.

    :param str tag_text: Text of the new tag; must not be None.
    :return: The newly created tag.
    :rtype: Tag
    """
    assert tag_text is not None
    new_tag = Tag._create_tag(client=self, tag_text=tag_text)
    return new_tag

Create a new tag with the given text.

def create_workspace(self, workspace_name: str) -> wallaroo.workspace.Workspace:
def create_workspace(self, workspace_name: str) -> Workspace:
    """Create a new workspace whose first owner is the current user.

    :param str workspace_name: Name of the workspace, must be composed of ASCII
       alpha-numeric characters plus dash (-); checked with
       ``require_dns_compliance`` before creation.
    :return: The newly created workspace.
    :rtype: Workspace
    """
    assert workspace_name is not None
    require_dns_compliance(workspace_name)
    created = Workspace._create_workspace(client=self, name=workspace_name)
    return created

Create a new workspace with the current user as its first owner.

Parameters
  • str workspace_name: Name of the workspace, must be composed of ASCII alpha-numeric characters plus dash (-)
def list_workspaces(self) -> List[wallaroo.workspace.Workspace]:
def list_workspaces(self) -> List[Workspace]:
    """List all workspaces on the platform which this user has permission to see.

    :return: A list of all workspaces on the platform.
    :rtype: List[Workspace]
    :raises Exception: If the API returns nothing, or returns a response that
        is not ``WorkspacesListResponse200`` (its ``msg`` becomes the message).
    """
    res = workspaces_list.sync(
        client=self.mlops(),
        json_body=workspaces_list_json_body.WorkspacesListJsonBody(UNSET),
    )

    if res is None:
        raise Exception("Failed to get workspaces")

    if not isinstance(res, WorkspacesListResponse200):
        raise Exception(res.msg)

    return Workspaces(
        [Workspace(client=self, data=d.to_dict()) for d in res.workspaces]
    )

List all workspaces on the platform which this user has permission to see.

Returns

A list of all workspaces on the platform.

def set_current_workspace( self, workspace: wallaroo.workspace.Workspace) -> wallaroo.workspace.Workspace:
def set_current_workspace(self, workspace: Workspace) -> Workspace:
    """Make ``workspace`` the one used by subsequent pipeline and model calls.

    :param Workspace workspace: The workspace to switch to; must not be None.
    :return: The workspace that is now current.
    :rtype: Workspace
    :raises TypeError: If ``workspace`` is not a ``Workspace`` instance.
    """
    assert workspace is not None
    if isinstance(workspace, Workspace):
        self._current_workspace = workspace
        return cast("Workspace", self._current_workspace)
    raise TypeError("Workspace type was expected")

Any calls involving pipelines or models will use the given workspace from then on.

def get_current_workspace(self) -> wallaroo.workspace.Workspace:
def get_current_workspace(self) -> Workspace:
    """Return the current workspace.  See also `set_current_workspace`.

    If none has been set yet, the user's default workspace (from
    ``Workspace._get_user_default_workspace``) is adopted. If that returns
    None, one is made via ``Workspace._create_user_default_workspace``. The
    chosen workspace is stored on the client for later calls.
    """
    if self._current_workspace is not None:
        return cast("Workspace", self._current_workspace)

    existing_default = Workspace._get_user_default_workspace(self)
    # The conditional expression only calls the create helper when no default exists.
    self._current_workspace = (
        existing_default
        if existing_default is not None
        else Workspace._create_user_default_workspace(self)
    )
    return cast("Workspace", self._current_workspace)

Return the current workspace. See also set_current_workspace.

def invite_user(self, email, password=None):
def invite_user(self, email, password=None):
    """Invite a user by email through ``User.invite_user``.

    This client's auth object, API endpoint, and auth endpoint are forwarded
    along with the email and optional password.

    :param email: Email address of the user to invite.
    :param password: Optional password, passed through unchanged.
    :return: Whatever ``User.invite_user`` returns.
    """
    return User.invite_user(
        email,
        password,
        self.auth,
        self.api_endpoint,
        self.auth_endpoint,
    )
def get_topic_name(self, pipeline_pk_id: int) -> str:
def get_topic_name(self, pipeline_pk_id: int) -> str:
    """Look up the Plateau topic name for a pipeline.

    :param int pipeline_pk_id: Primary-key id of the pipeline.
    :return: The ``topic_name`` field of the JSON response from
        ``v1/api/plateau/get_topic_name``.
    :rtype: str
    """
    payload = {"pipeline_pk_id": pipeline_pk_id}
    response = self._post_rest_api("v1/api/plateau/get_topic_name", payload)
    body = response.json()
    return body["topic_name"]
def shim_token(self, token_data: wallaroo.auth.TokenData):
def shim_token(self, token_data: auth.TokenData):
    """Replace this client's authentication with one backed by raw token data.

    ``token_data`` is wrapped in ``auth._RawTokenFetcher``, and the resulting
    ``auth._PlatformAuth`` is installed as ``self.auth``.

    :param auth.TokenData token_data: Token data to authenticate with.
    """
    self.auth = auth._PlatformAuth(auth._RawTokenFetcher(token_data))
def convert_model(
    self,
    path: Union[str, pathlib.Path],
    source_type: ModelConversionSource,
    conversion_arguments: ModelConversionArguments,
) -> Model:
    """
    Given an inbound source model, a model type (xgboost, keras, sklearn), and
    conversion arguments, convert the model to ONNX and add it to the models
    available for a pipeline.

    :param Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model.
    :param ModelConversionSource source_type: The origin model type, i.e. keras, sklearn or xgboost.
    :param ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type.
    :return: An instance of the Model converted to ONNX.
    :raises ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist.
    :raises ModelConversionGenericException: Raised for any failure while
        sending or interpreting the conversion request. This covers an
        unsupported ``source_type``, ``conversion_arguments`` of the wrong
        kind, an HTTP error status, and a response without ``model_id``. The
        underlying error is attached as ``__cause__``. Please contact our
        support for further assistance.

    Errors raised while preparing the request (opening the file, resolving
    the current workspace, reading the user id/email) propagate unchanged.
    """
    if isinstance(path, str):
        path = pathlib.Path(path)
    if not os.path.exists(path):
        raise ModelConversionSourceFileNotPresent(
            f"The provided source file: {path} can not be found."
        )

    # ``with`` closes the upload handle on every path, including failures in
    # the workspace/auth lookups below that happen before the ``try``.
    with open(path, "rb") as file_handle:
        files = [("files", file_handle)]
        base_url = self.api_endpoint
        workspace_id = self.get_current_workspace().id()
        common_headers = {
            "user_id": self.auth.user_id(),
            "user_email": self.auth.user_email(),
            "User-Agent": _user_agent,
        }
        curry_post = partial(
            requests.post,
            auth=self.auth,
            headers=common_headers,
            files=files,
        )

        def _require_arguments(expected: type) -> None:
            """Fail unless conversion_arguments is of the ``expected`` class."""
            # Matched on class name rather than isinstance. NOTE(review):
            # presumably this tolerates duplicate class objects (e.g. module
            # reloads); confirm before tightening. An explicit raise is used
            # instead of ``assert`` so the check also runs under ``python -O``.
            actual_name = type(conversion_arguments).__name__
            if actual_name != expected.__name__:
                raise ModelConversionFailure(
                    f"Expected {expected.__name__} for {source_type}, got {actual_name}"
                )

        def _handle_response(http_response) -> int:
            """Return the new model id from a conversion response, or raise."""
            http_response.raise_for_status()
            response_record = http_response.json()
            if response_record is not None and "model_id" in response_record:
                return int(response_record["model_id"])
            raise ModelConversionFailure(f"Failed to convert {source_type} model")

        try:
            if source_type == ModelConversionSource.KERAS:
                _require_arguments(ConvertKerasArguments)
                # This is not optimal, but api-lb (envoy) -> python-api
                # returned 50(3|4|2|0) errors for several other encodings of
                # the dimensions: a List[Any], a base64-encoded JSON array,
                # and a comma-separated string processed server side.
                # Instead, the ConvertKerasArguments dict (plus workspace_id)
                # is serialized to JSON, base64-encoded via UTF-8 bytes, and
                # sent as the ``config`` URL parameter, which the server
                # decodes.
                data = {
                    **conversion_arguments.to_dict(),
                    "workspace_id": workspace_id,
                }
                params = {
                    "config": base64.b64encode(json.dumps(data).encode("utf-8")).decode(
                        "utf-8"
                    )
                }
                response = curry_post(url=f"{base_url}/v1/convert/keras", params=params)
            elif source_type == ModelConversionSource.SKLEARN:
                _require_arguments(ConvertSKLearnArguments)
                response = curry_post(
                    url=f"{base_url}/v1/convert/sklearn",
                    params={
                        **conversion_arguments.to_dict(),
                        "workspace_id": workspace_id,
                    },
                )
            elif source_type == ModelConversionSource.XGBOOST:
                _require_arguments(ConvertXGBoostArgs)
                response = curry_post(
                    url=f"{base_url}/v1/convert/xgboost",
                    params={
                        **conversion_arguments.to_dict(),
                        "workspace_id": workspace_id,
                    },
                )
            else:
                raise ModelConversionUnsupportedType(
                    f"Unsupported model source type of {source_type} passed."
                )

            model_id = _handle_response(response)
            return Model(self, {"id": model_id})
        except Exception as err:
            # Chain the original error so the real cause stays visible.
            raise ModelConversionGenericException(
                "This model type could not be deployed successfully. Please contact your Wallaroo support team at community@wallaroo.ai"
            ) from err

Given an inbound source model, a model type (xgboost, keras, sklearn), and conversion arguments, convert the model to ONNX and add it to the models available for a pipeline.

Parameters
  • Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model.
  • ModelConversionSource source_type: The origin model type, i.e. keras, sklearn or xgboost.
  • ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type.
Returns

An instance of the Model converted to ONNX.

Raises
  • ModelConversionGenericException: On a generic failure, please contact our support for further assistance.
  • ModelConversionFailure: Failure in converting the model type.
  • ModelConversionUnsupportedType: Raised when the source type passed is not supported.
  • ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist.
def mlops(self) -> "AuthenticatedClient":
    """Return an ML Ops API client for this connection.

    A fresh client is obtained from ``self._setup_mlops_client()`` on every
    call.
    """
    client = self._setup_mlops_client()
    return client