wallaroo.client
import base64
import json
import math
import os
import pathlib
import posixpath
from datetime import datetime, timezone
from functools import partial
from typing import Any, Dict, List, NewType, Optional, Tuple, Union, cast
from urllib.parse import quote_plus

import gql  # type: ignore
import pandas as pd
import requests
from gql.transport.requests import RequestsHTTPTransport

from wallaroo.models import Models, ModelsList

from . import auth
from .assay import Assay, AssayAnalysis, AssayAnalysisList, Assays
from .assay_config import AssayBuilder, AssayConfig
from .checks import require_dns_compliance
from .deployment import Deployment
from .inference_decode import inference_logs_to_dataframe
from .logs import LogEntries, LogEntry
from .model import Model, ModelVersions
from .model_config import ModelConfig
from .ModelConversion import (
    ConvertKerasArguments,
    ConvertSKLearnArguments,
    ConvertXGBoostArgs,
    ModelConversionArguments,
    ModelConversionFailure,
    ModelConversionGenericException,
    ModelConversionSource,
    ModelConversionSourceFileNotPresent,
    ModelConversionUnsupportedType,
)
from .object import EntityNotFoundError, ModelUploadError
from .pipeline import Pipeline, Pipelines
from .pipeline_config import PipelineConfig
from .pipeline_variant import PipelineVariant, PipelineVariants
from .tag import Tag, Tags
from .user import User
from .version import _user_agent
from .visibility import _Visibility
from .wallaroo_ml_ops_api_client.api.assay import (
    assays_create,
    assays_get_assay_results,
    assays_list,
)
from .wallaroo_ml_ops_api_client.api.model import models_list
from .wallaroo_ml_ops_api_client.api.pipeline import pipelines_create
from .wallaroo_ml_ops_api_client.api.workspace import workspaces_list
from .wallaroo_ml_ops_api_client.client import AuthenticatedClient
from .wallaroo_ml_ops_api_client.models import (
    assays_get_assay_results_json_body,
    models_list_json_body,
    pipelines_create_json_body,
    pipelines_create_json_body_definition,
    workspaces_list_json_body,
)
from .wallaroo_ml_ops_api_client.models.assays_create_json_body import (
    AssaysCreateJsonBody,
)
from .wallaroo_ml_ops_api_client.models.assays_create_response_200 import (
    AssaysCreateResponse200,
)
from .wallaroo_ml_ops_api_client.models.assays_get_assay_results_response_200_item import (
    AssaysGetAssayResultsResponse200Item,
)
from .wallaroo_ml_ops_api_client.models.assays_list_json_body import AssaysListJsonBody
from .wallaroo_ml_ops_api_client.models.models_list_response_200 import (
    ModelsListResponse200,
)
from .wallaroo_ml_ops_api_client.models.pipelines_create_response_200 import (
    PipelinesCreateResponse200,
)
from .wallaroo_ml_ops_api_client.models.workspaces_list_response_200 import (
    WorkspacesListResponse200,
)
from .wallaroo_ml_ops_api_client.types import UNSET
from .workspace import Workspace, Workspaces

Datetime = NewType("Datetime", datetime)

WALLAROO_SDK_AUTH_TYPE = "WALLAROO_SDK_AUTH_TYPE"
WALLAROO_SDK_AUTH_ENDPOINT = "WALLAROO_SDK_AUTH_ENDPOINT"
WALLAROO_URL = "WALLAROO_URL"
WALLAROO_AUTH_URL = "WALLAROO_AUTH_URL"


class Client(object):
    """Client handle to a Wallaroo platform instance.

    Objects of this class serve as the entrypoint to Wallaroo platform
    functionality.
    """

    @staticmethod
    def get_urls(
        auth_type: Optional[str], api_endpoint: str, auth_endpoint: str
    ) -> Tuple[Optional[str], str, str]:
        """Resolve the auth values specified as defaults, as params, or in
        env vars.
        Made static to be testable without reaching out to SSO, etc."""

        if auth_type is None:
            auth_type = os.environ.get(WALLAROO_SDK_AUTH_TYPE, None)

        # Ideally we'd give auth_endpoint a default value of None, but that
        # would make its type Optional[str], which messes up a lot of type
        # hinting; we wanted minimal changes without a lot of warnings.
        if len(auth_endpoint.strip()) == 0:
            auth_endpoint = (
                os.environ.get(WALLAROO_AUTH_URL)
                or os.environ.get(WALLAROO_SDK_AUTH_ENDPOINT)
                or "http://api-lb:8080"
            )

        api_endpoint = os.environ.get(WALLAROO_URL, api_endpoint)

        return auth_type, api_endpoint, auth_endpoint

    def __init__(
        self,
        api_endpoint: str = "http://api-lb:8080",
        auth_endpoint: str = "",
        request_timeout: int = 45,
        auth_type: Optional[str] = None,
        gql_client: Optional[gql.Client] = None,
        pg_connection_string: str = "dbname=postgres user=postgres password=password host=postgres port=5432",
        interactive: Optional[bool] = None,
        time_format: str = "%Y-%d-%b %H:%M:%S",
    ):
        """Create a Client handle.

        :param str api_endpoint: Host/port of the platform API endpoint
        :param str auth_endpoint: Host/port of the platform Keycloak instance
        :param int request_timeout: Max timeout of web requests, in seconds
        :param str auth_type: Authentication type to use. Can be one of:
            "none", "sso", "user_password".
        :param str pg_connection_string: Postgres connection string
        :param bool interactive: If provided and True, some calls will print
            additional human-readable information; if False, they won't. If
            not provided, interactive defaults to True when running inside
            Jupyter and False otherwise.
        :param str time_format: Preferred `strftime` format string for
            displaying timestamps in a human context.
        """

        auth_type, api_endpoint, auth_endpoint = Client.get_urls(
            auth_type, api_endpoint, auth_endpoint
        )

        self.auth = auth.create(auth_endpoint, auth_type)

        if gql_client:
            self._gql_client = gql_client
        else:
            gql_transport = RequestsHTTPTransport(
                url=posixpath.join(api_endpoint, "v1/graphql"),
                auth=self.auth,
                timeout=request_timeout,
            )
            self._gql_client = gql.Client(
                transport=gql_transport, fetch_schema_from_transport=True
            )

        self.api_endpoint = api_endpoint

        self.auth_endpoint = auth_endpoint

        self.timeout = request_timeout

        self._setup_mlops_client()

        self.pg_connection_string = pg_connection_string

        self._current_workspace: Optional[Workspace] = None

        # TODO: debate the names of these things
        self._default_ws_name: Optional[str] = None

        user_email = self.auth.user_email()
        if user_email is not None:
            self._default_ws_name = user_email + "_ws"

        if interactive is not None:
            self._interactive = interactive
        elif (
            "JUPYTER_SVC_SERVICE_HOST" in os.environ or "JUPYTERHUB_HOST" in os.environ
        ):
            self._interactive = True
        else:
            self._interactive = False

        self._time_format = time_format
147 """ 148 149 auth_type, api_endpoint, auth_endpoint = Client.get_urls( 150 auth_type, api_endpoint, auth_endpoint 151 ) 152 153 self.auth = auth.create(auth_endpoint, auth_type) 154 155 if gql_client: 156 self._gql_client = gql_client 157 else: 158 gql_transport = RequestsHTTPTransport( 159 url=posixpath.join(api_endpoint, "v1/graphql"), 160 auth=self.auth, 161 timeout=request_timeout, 162 ) 163 self._gql_client = gql.Client( 164 transport=gql_transport, fetch_schema_from_transport=True 165 ) 166 167 self.api_endpoint = api_endpoint 168 169 self.auth_endpoint = auth_endpoint 170 171 self.timeout = request_timeout 172 173 self._setup_mlops_client() 174 175 self.pg_connection_string = pg_connection_string 176 177 self._current_workspace: Optional[Workspace] = None 178 179 # TODO: debate the names of these things 180 self._default_ws_name: Optional[str] = None 181 182 user_email = self.auth.user_email() 183 if user_email is not None: 184 self._default_ws_name = user_email + "_ws" 185 186 if interactive is not None: 187 self._interactive = interactive 188 elif ( 189 "JUPYTER_SVC_SERVICE_HOST" in os.environ or "JUPYTERHUB_HOST" in os.environ 190 ): 191 self._interactive = True 192 else: 193 self._interactive = False 194 195 self._time_format = time_format 196 197 def _get_rest_api(self, path: str, params: dict): 198 headers = { 199 "authorization": self.auth._bearer_token_str(), 200 "user-agent": _user_agent, 201 } 202 203 url = f"{self.api_endpoint}/{path}" 204 205 return requests.get(url, headers=headers, params=params) 206 207 def _post_rest_api(self, path: str, body: dict): 208 headers = { 209 "authorization": self.auth._bearer_token_str(), 210 "user-agent": _user_agent, 211 } 212 213 url = f"{self.api_endpoint}/{path}" 214 return requests.post(url, headers=headers, json=body) 215 216 def list_tags(self) -> Tags: 217 """List all tags on the platform. 218 219 :return: A list of all tags on the platform. 220 :rtype: List[Tag] 221 """ 222 res = self._gql_client.execute( 223 gql.gql( 224 """ 225 query ListTags { 226 tag(order_by: {id: desc}) { 227 id 228 tag 229 model_tags { 230 model { 231 id 232 model_id 233 models_pk_id 234 model_version 235 236 } 237 } 238 pipeline_tags { 239 pipeline { 240 id 241 pipeline_id 242 pipeline_versions { 243 id 244 version 245 } 246 } 247 } 248 } 249 } 250 251 252 """ 253 ) 254 ) 255 return Tags([Tag(client=self, data={"id": p["id"]}) for p in res["tag"]]) 256 257 def list_models(self) -> ModelsList: 258 """List all models on the platform. 259 260 :return: A list of all models on the platform. 261 :rtype: List[Model] 262 """ 263 id = self.get_current_workspace().id() 264 res = models_list.sync( 265 client=self.mlops(), 266 json_body=models_list_json_body.ModelsListJsonBody(id), 267 ) 268 269 if res is None: 270 raise Exception("Failed to list models") 271 272 if not isinstance(res, ModelsListResponse200): 273 raise Exception(res.msg) 274 275 return ModelsList([Models(client=self, data=v.to_dict()) for v in res.models]) 276 277 def list_deployments(self) -> List[Deployment]: 278 """List all deployments (active or not) on the platform. 279 280 :return: A list of all deployments on the platform. 

    """
    # Removed until we figure out what pipeline ownership means
    #
    # def search_my_pipelines(
    #     self,
    #     search_term: Optional[str] = None,
    #     deployed: Optional[bool] = None,
    #     created_start: Optional["Datetime"] = None,
    #     created_end: Optional["Datetime"] = None,
    #     updated_start: Optional["Datetime"] = None,
    #     updated_end: Optional["Datetime"] = None,
    # ) -> List[Pipeline]:
    #     user_id = self.auth.user_id()
    #     return Pipelines(
    #         self._search_pipelines(
    #             search_term,
    #             deployed,
    #             user_id,
    #             created_start,
    #             created_end,
    #             updated_start,
    #             updated_end,
    #         )
    #     )
    """

    def search_pipelines(
        self,
        search_term: Optional[str] = None,
        deployed: Optional[bool] = None,
        created_start: Optional["Datetime"] = None,
        created_end: Optional["Datetime"] = None,
        updated_start: Optional["Datetime"] = None,
        updated_end: Optional["Datetime"] = None,
    ) -> PipelineVariants:
        """Search for pipelines. All parameters are optional; if none are
        given, the result is the same as `list_pipelines()`. All times are
        strings to be parsed by `datetime.fromisoformat`. Example:

            myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo")

        :param str search_term: Will be matched against tags and model names.
            Example: "footag123".
        :param bool deployed: Pipeline was deployed or not
        :param str created_start: Pipeline was created at or after this time
        :param str created_end: Pipeline was created at or before this time
        :param str updated_start: Pipeline was updated at or after this time
        :param str updated_end: Pipeline was updated at or before this time

        :return: A list of the pipelines matching the search criteria.
        :rtype: List[PipelineVariant]
        """
        return PipelineVariants(
            self._search_pipelines(
                search_term,
                deployed,
                None,
                created_start,
                created_end,
                updated_start,
                updated_end,
            )
        )
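
    # A hedged sketch of a windowed search, assuming a connected client `wl`:
    #
    #     recent = wl.search_pipelines(
    #         deployed=True,
    #         created_start="2022-04-01 00:00:00+00:00",
    #         created_end="2022-04-19 13:17:59+00:00",
    #     )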

    def _search_pipelines(
        self,
        search_term: Optional[str] = None,
        deployed: Optional[bool] = None,
        user_id: Optional[str] = None,
        created_start: Optional["Datetime"] = None,
        created_end: Optional["Datetime"] = None,
        updated_start: Optional["Datetime"] = None,
        updated_end: Optional["Datetime"] = None,
    ) -> List[PipelineVariant]:
        (query, params) = self._generate_search_pipeline_query(
            search_term=search_term,
            deployed=deployed,
            user_id=user_id,
            created_start=created_start,
            created_end=created_end,
            updated_start=updated_start,
            updated_end=updated_end,
        )
        q = gql.gql(query)
        data = self._gql_client.execute(q, variable_values=params)
        pipelines = []
        if data["search_pipelines"]:
            for p in data["search_pipelines"]:
                pipelines.append(PipelineVariant(self, p))
        return pipelines

    def _generate_search_pipeline_query(
        self,
        search_term: Optional[str] = None,
        deployed: Optional[bool] = None,
        user_id: Optional[str] = None,
        created_start: Optional["Datetime"] = None,
        created_end: Optional["Datetime"] = None,
        updated_start: Optional["Datetime"] = None,
        updated_end: Optional["Datetime"] = None,
    ):
        filters = []
        query_params = []
        params: Dict[str, Any] = {}
        search = ""
        if search_term:
            search = search_term
        params["search_term"] = search
        query_params.append("$search_term: String!")

        if user_id:
            filters.append("owner_id: {_eq: $user_id}")
            params["user_id"] = user_id
            query_params.append("$user_id: String!")

        if deployed is not None:
            filters.append("pipeline: {deployment: {deployed: {_eq: $deployed}}}")
            params["deployed"] = deployed
            query_params.append("$deployed: Boolean")

        self._generate_time_range_graphql(
            "created_at",
            start=created_start,
            end=created_end,
            filters=filters,
            query_params=query_params,
            params=params,
        )
        self._generate_time_range_graphql(
            "updated_at",
            start=updated_start,
            end=updated_end,
            filters=filters,
            query_params=query_params,
            params=params,
        )

        where_clause_str = self._generate_where_clause_str(filters)
        query_param_str = self._generate_query_param_str(query_params)
        query = f"""
        query GetPipelines({query_param_str}) {{
          search_pipelines(args: {{search: $search_term}}, distinct_on: id{where_clause_str}, order_by: {{id: desc}}) {{
            id
            created_at
            pipeline_pk_id
            updated_at
            version
            pipeline {{
              id
              pipeline_id
              pipeline_tags {{
                id
                tag {{
                  id
                  tag
                }}
              }}
            }}
          }}
        }}
        """
        return (query, params)

    def _generate_where_clause_str(self, filters: List[str]) -> str:
        where_clause_str = ""
        filters_len = len(filters)
        if filters_len > 0:
            if filters_len > 1:
                where_clause_str = f""", where: {{_and: {{ {", ".join(filters)} }}}}"""
            else:
                where_clause_str = f", where: {{{filters[0]}}}"
        return where_clause_str

    def _generate_query_param_str(self, query_params: List[str]):
        return ", ".join(query_params)

    def _generate_time_range_graphql(
        self,
        field: str,
        start: Optional["Datetime"],
        end: Optional["Datetime"],
        filters: List[str],
        query_params: List[str],
        params: Dict[str, Any],
    ):
        if start and not end:
            filters.append(f"{field}: {{_gte: $start_{field}}}")
            params[f"start_{field}"] = start
            query_params.append(f"$start_{field}: timestamptz!")
        elif end and not start:
            filters.append(f"{field}: {{_lte: $end_{field}}}")
            params[f"end_{field}"] = end
            query_params.append(f"$end_{field}: timestamptz!")
        elif start and end:
            filters.append(f"{field}: {{_gte: $start_{field}, _lte: $end_{field}}}")
            params[f"start_{field}"] = start
            params[f"end_{field}"] = end
            query_params.append(f"$start_{field}: timestamptz!")
            query_params.append(f"$end_{field}: timestamptz!")

    def search_my_models(
        self,
        search_term: Optional[str] = None,
        uploaded_time_start: Optional["Datetime"] = None,
        uploaded_time_end: Optional["Datetime"] = None,
    ) -> ModelVersions:
        """Search models owned by you.

        :param search_term: Searches the following metadata: names, shas,
            versions, file names, and tags
        :param uploaded_time_start: Inclusive time of upload
        :param uploaded_time_end: Inclusive time of upload
        """
        user_id = self.auth.user_id()
        return ModelVersions(
            self._search_models(
                search_term=search_term,
                user_id=user_id,
                start=uploaded_time_start,
                end=uploaded_time_end,
            )
        )

    def search_models(
        self,
        search_term: Optional[str] = None,
        uploaded_time_start: Optional["Datetime"] = None,
        uploaded_time_end: Optional["Datetime"] = None,
    ) -> ModelVersions:
        """Search all models you have access to.

        :param search_term: Searches the following metadata: names, shas,
            versions, file names, and tags
        :param uploaded_time_start: Inclusive time of upload
        :param uploaded_time_end: Inclusive time of upload
        """
        return ModelVersions(
            self._search_models(
                search_term=search_term,
                start=uploaded_time_start,
                end=uploaded_time_end,
            )
        )
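
    # A hedged sketch of searching model versions, assuming a connected
    # client `wl` (the search term is hypothetical):
    #
    #     from datetime import datetime, timezone
    #     mine = wl.search_my_models(search_term="churn")
    #     recent = wl.search_models(
    #         uploaded_time_start=datetime(2022, 4, 1, tzinfo=timezone.utc),
    #     )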

    def _search_models(
        self, search_term=None, user_id=None, start=None, end=None
    ) -> List[Model]:
        (query, params) = self._generate_model_query(
            search_term=search_term,
            user_id=user_id,
            start=start,
            end=end,
        )

        q = gql.gql(query)

        data = self._gql_client.execute(q, variable_values=params)
        models = []
        if data["search_models"]:
            for m in data["search_models"]:
                models.append(Model(self, m))
        return models

    def _generate_model_query(
        self,
        search_term=None,
        user_id=None,
        start=None,
        end=None,
    ):
        filters = []
        query_params = []
        params = {}
        search = ""
        if search_term:
            search = search_term
        params["search_term"] = search
        query_params.append("$search_term: String!")
        if user_id:
            filters.append("owner_id: {_eq: $user_id}")
            params["user_id"] = user_id
            query_params.append("$user_id: String!")

        self._generate_time_range_graphql(
            "created_at",
            start=start,
            end=end,
            filters=filters,
            params=params,
            query_params=query_params,
        )

        where_clause_str = self._generate_where_clause_str(filters)
        query_param_str = self._generate_query_param_str(query_params)
        query = f"""
        query GetModels({query_param_str}) {{
          search_models(args: {{search: $search_term}}{where_clause_str}, order_by: {{created_at: desc}}) {{
            id
          }}
        }}
        """
        return (query, params)

    def get_user_by_email(self, email: str) -> Optional[User]:
        """Find a user by email"""
        assert email is not None
        escaped_email = quote_plus(email)
        url = (
            f"{self.auth_endpoint}/auth/admin/realms/master/users?email={escaped_email}"
        )
        headers = {
            "Content-Type": "application/json",
            "Authorization": self.auth._bearer_token_str(),
            "User-Agent": _user_agent,
        }
        resp = requests.get(url, headers=headers, data={})
        jresp = resp.json()
        return None if jresp == [] else User(client=self, data=jresp[0])

    def deactivate_user(self, email: str) -> None:
        """Deactivates an existing user of the platform.

        Deactivated users cannot log into the platform.
        Deactivated users do not count towards the number of allotted user
        seats from the license.

        The Models and Pipelines owned by the deactivated user are not
        removed from the platform.

        :param str email: The email address of the user to deactivate.

        :return: None
        :rtype: None
        """

        if self.auth.user_email() == email:
            raise Exception("A user may not deactivate themselves.")

        user = self.get_user_by_email(email)

        if user is None:
            raise EntityNotFoundError("User", {"email": email})

        if user.username() == "admin":
            raise Exception("Admin user may not be deactivated.")

        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": self.auth._bearer_token_str(),
            "User-Agent": _user_agent,
        }

        # Get the current full user representation to send back in the
        # update, due to a Keycloak bug
        get_user_response = requests.get(url, headers=headers, data={})

        cur_user_rep = get_user_response.json()
        cur_user_rep["enabled"] = False

        resp = requests.put(url, headers=headers, json=cur_user_rep)

        if resp.status_code != 204:
            raise EntityNotFoundError("User", {"email": email})
        return None

    def activate_user(self, email: str) -> None:
        """Activates an existing user of the platform that had been
        previously deactivated.

        Activated users can log into the platform.

        :param str email: The email address of the user to activate.

        :return: None
        :rtype: None
        """
        user = self.get_user_by_email(email)

        if user is None:
            raise EntityNotFoundError("User", {"email": email})

        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}"

        headers = {
            "Content-Type": "application/json",
            "Authorization": self.auth._bearer_token_str(),
            "User-Agent": _user_agent,
        }

        # Get the current full user representation to send back in the
        # update, due to a Keycloak bug
        get_user_response = requests.get(url, headers=headers, data={})

        cur_user_rep = get_user_response.json()
        cur_user_rep["enabled"] = True

        resp = requests.put(url, headers=headers, json=cur_user_rep)

        if resp.status_code != 204:
            raise EntityNotFoundError("User", {"email": email})
        return None
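
    # A hedged admin-flow sketch, assuming a connected client `wl` whose
    # user has Keycloak admin rights (the email is hypothetical):
    #
    #     wl.deactivate_user("analyst@example.com")  # frees a license seat
    #     wl.activate_user("analyst@example.com")    # restores access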

    def _get_user_by_id(self, id: str) -> Optional[User]:
        assert id is not None
        url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{id}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": self.auth._bearer_token_str(),
            "User-Agent": _user_agent,
        }
        resp = requests.get(url, headers=headers, data={})
        jresp = resp.json()
        return None if jresp == [] else User(client=self, data=jresp)

    def list_users(self) -> List[User]:
        """List all Users on the platform.

        :return: A list of all Users on the platform.
        :rtype: List[User]
        """
        resp = User.list_users(auth=self.auth)
        return [User(client=self, data=u) for u in resp]

    def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> Model:
        """Upload a model defined by a file as a new model variant.

        :param str name: The name of the model of which this is a variant.
            Names must be ASCII alpha-numeric characters or dash (-) only.
        :param Union[str, pathlib.Path] path: Path of the model file to
            upload.
        :return: The created Model.
        :rtype: Model
        """

        _Visibility.PRIVATE
        if isinstance(path, str):
            path = pathlib.Path(path)
        with path.open("rb") as f:
            return self._upload_model_stream(name, {"filename": path.name}, f)

    def _upload_model_stream(self, name: str, data: Dict[str, Any], file: Any):
        require_dns_compliance(name)
        endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload_stream")
        data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()}
        headers = {"User-Agent": _user_agent}

        res = requests.post(
            endpoint, auth=self.auth, params=data, data=file, headers=headers
        )
        if res.status_code != 200:
            raise ModelUploadError(res.text)

        res_dict = json.loads(res.text)
        return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0])

    def register_model_image(self, name: str, image: str) -> Model:
        """Registers an MLFlow model as a new model.

        :param str name: The name of the model of which this is a variant.
            Names must be ASCII alpha-numeric characters or dash (-) only.
        :param str image: Image name of the MLFlow model to register.
        :return: The created Model.
        :rtype: Model
        """
        data = {
            "image_path": image,
        }
        return self._upload_model(name, data)

    def _upload_model(
        self, name: str, data: Dict[str, Any], files: Dict[str, Tuple[str, bytes]] = {}
    ):
        require_dns_compliance(name)
        endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload")
        data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()}
        if len(files) == 0:
            files = {"dummy": ("none", b"")}

        headers = {"User-Agent": _user_agent}

        res = requests.post(
            endpoint, files=files, auth=self.auth, data=data, headers=headers
        )
        if res.status_code != 200:
            raise ModelUploadError(res.text)

        res_dict = json.loads(res.text)
        return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0])
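
    # A hedged upload sketch, assuming a connected client `wl` and a local
    # ONNX file (path, model name, and image are hypothetical):
    #
    #     model = wl.upload_model("ccfraud", "./models/ccfraud.onnx")
    #     mlflow_model = wl.register_model_image(
    #         "ccfraud-mlflow", "registry.example.com/mlflow/ccfraud:latest"
    #     )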

    def model_by_name(self, model_class: str, model_name: str) -> Model:
        """Fetch a Model by name.

        :param str model_class: Name of the model class.
        :param str model_name: Name of the variant within the specified model
            class.
        :return: The Model with the corresponding model and variant name.
        :rtype: Model
        """
        res = self._gql_client.execute(
            gql.gql(
                """
            query ModelByName($model_id: String!, $model_version: String!) {
              model(where: {_and: [{model_id: {_eq: $model_id}}, {model_version: {_eq: $model_version}}]}) {
                id
                model_id
                model_version
              }
            }
            """
            ),
            variable_values={
                "model_id": model_class,
                "model_version": model_name,
            },
        )
        if not res["model"]:
            raise EntityNotFoundError(
                "Model", {"model_class": model_class, "model_name": model_name}
            )
        return Model(client=self, data={"id": res["model"][0]["id"]})

    def deployment_by_name(self, deployment_name: str) -> Deployment:
        """Fetch a Deployment by name.

        :param str deployment_name: Name of the deployment.
        :return: The Deployment with the corresponding name.
        :rtype: Deployment
        """
        res = self._gql_client.execute(
            gql.gql(
                """
            query DeploymentByName($deployment_name: String!) {
              deployment(where: {deploy_id: {_eq: $deployment_name}}) {
                id
              }
            }
            """
            ),
            variable_values={
                "deployment_name": deployment_name,
            },
        )
        if not res["deployment"]:
            raise EntityNotFoundError(
                "Deployment", {"deployment_name": deployment_name}
            )
        return Deployment(client=self, data={"id": res["deployment"][0]["id"]})

    def pipelines_by_name(self, pipeline_name: str) -> List[Pipeline]:
        """Fetch Pipelines by name.

        :param str pipeline_name: Name of the pipeline.
        :return: The Pipelines with the corresponding name.
        :rtype: List[Pipeline]
        """
        res = self._gql_client.execute(
            gql.gql(
                """
            query PipelineByName($pipeline_name: String!) {
              pipeline(where: {pipeline_id: {_eq: $pipeline_name}}, order_by: {created_at: desc}) {
                id
              }
            }
            """
            ),
            variable_values={
                "pipeline_name": pipeline_name,
            },
        )
        assert "pipeline" in res
        length = len(res["pipeline"])
        if length < 1:
            raise EntityNotFoundError("Pipeline", {"pipeline_name": pipeline_name})
        return [Pipeline(client=self, data={"id": p["id"]}) for p in res["pipeline"]]

    def list_pipelines(self) -> List[Pipeline]:
        """List all pipelines on the platform.

        :return: A list of all pipelines on the platform.
        :rtype: List[Pipeline]
        """
        res = self._gql_client.execute(
            gql.gql(
                """
            query ListPipelines {
              pipeline(order_by: {id: desc}) {
                id
                pipeline_tags {
                  tag {
                    id
                    tag
                  }
                }
              }
            }
            """
            )
        )
        return Pipelines([Pipeline(client=self, data=d) for d in res["pipeline"]])

    def build_pipeline(self, pipeline_name: str) -> "Pipeline":
        """Starts building a pipeline with the given `pipeline_name`,
        returning a :py:class:`PipelineConfigBuilder`.

        When completed, the pipeline can be uploaded with `.upload()`.

        :param str pipeline_name: Name of the pipeline, must be composed of
            ASCII alpha-numeric characters plus dash (-).
        """

        require_dns_compliance(pipeline_name)

        _Visibility.PRIVATE

        # TODO: Needs to handle visibility?
        data = pipelines_create.sync(
            client=self.mlops(),
            json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
                pipeline_name,
                self.get_current_workspace().id(),
                pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
                    {}
                ),
            ),
        )

        if data is None:
            raise Exception("Failed to create pipeline")

        if not isinstance(data, PipelinesCreateResponse200):
            raise Exception(data.msg)

        return Pipeline(client=self, data={"id": data.pipeline_pk_id})
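
    # A hedged pipeline-building sketch, assuming a connected client `wl`
    # (the pipeline name is hypothetical):
    #
    #     pipeline = wl.build_pipeline("ccfraud-pipeline")
    #     # ...configure steps on the returned object, then, per the
    #     # docstring above, upload the finished configuration:
    #     # pipeline.upload()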

    def _upload_pipeline_variant(
        self,
        name: str,
        config: PipelineConfig,
    ) -> Pipeline:
        """Creates a new PipelineVariant with the specified configuration.

        :param str name: Name of the Pipeline. Must be unique across all
            Pipelines.
        :param PipelineConfig config: Pipeline configuration.
        """
        definition = config.to_json()
        _Visibility.PRIVATE

        data = pipelines_create.sync(
            client=self.mlops(),
            json_body=pipelines_create_json_body.PipelinesCreateJsonBody(
                name,
                self.get_current_workspace().id(),
                pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict(
                    definition
                ),
            ),
        )

        if data is None:
            # TODO: Generalize
            raise Exception("Failed to create pipeline")

        if not isinstance(data, PipelinesCreateResponse200):
            raise Exception(data.msg)

        for alert_config in config.alert_configurations:
            self._gql_client.execute(
                gql.gql(
                    """
                mutation CreateAlertConfiguration(
                    $pipeline_version_id: bigint,
                    $name: String,
                    $expression: String,
                    $notifications: jsonb
                ) {
                    insert_alert_configuration(objects: {
                        name: $name,
                        expression: $expression,
                        notifications: $notifications,
                        pipeline_version_pk_id: $pipeline_version_id
                    }) {
                        returning { id }
                    }
                }
                """
                ),
                variable_values={
                    **alert_config.to_json(),
                    "pipeline_version_id": data.pipeline_pk_id,
                },
            )

        pipeline_data = data.to_dict()
        pipeline_data["id"] = data.pipeline_pk_id

        return Pipeline(
            client=self,
            data=pipeline_data,
        )

    def create_value_split_experiment(
        self,
        name: str,
        meta_key: str,
        default_model: ModelConfig,
        challenger_models: List[Tuple[Any, ModelConfig]],
    ) -> Pipeline:
        """Creates a new PipelineVariant of a "value-split experiment" type.

        :param str name: Name of the Pipeline
        :param str meta_key: Inference input key on whose values inputs are
            redirected to experiment models.
        :param ModelConfig default_model: Model to send inferences to by
            default.
        :param List[Tuple[Any, ModelConfig]] challenger_models: A list of
            meta_key values and the Models to send matching inferences to. If
            the inference data referred to by meta_key is equal to one of the
            keys in this list, that inference is redirected to the
            corresponding model instead of the default model.
        """
        args = [meta_key, default_model.model().name()]
        for v, m in challenger_models:
            args.append(v)
            args.append(m.model().name())
        step = {
            "id": "metavalue_split",
            "operation": "map",
            "args": args,
        }
        definition = {"id": name, "steps": [step]}
        # TODO: This seems like a one-to-one replace, find appropriate test.
        data = self._gql_client.execute(
            gql.gql(
                """
            mutation CreatePipeline(
                $pipeline_id: String,
                $version: String,
                $definition: jsonb,
                $workspace_id: bigint
            ) {
                insert_pipeline(
                    objects: {
                        pipeline_versions: {
                            data: { definition: $definition }
                        }
                        pipeline_id: $pipeline_id
                    }
                ) {
                    returning {
                        id
                    }
                }
            }
            """
            ),
            variable_values={
                "pipeline_id": name,
                "definition": definition,
                "workspace_id": self.get_current_workspace().id(),
            },
        )
        return Pipeline(
            client=self,
            data=data["insert_pipeline"]["returning"][0],
        )
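
    # A hedged value-split sketch, assuming a connected client `wl` and
    # already-configured ModelConfig objects (all names hypothetical):
    #
    #     pipeline = wl.create_value_split_experiment(
    #         name="card-type-split",
    #         meta_key="card_type",
    #         default_model=control_config,
    #         challenger_models=[("gold", gold_config), ("platinum", plat_config)],
    #     )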

    def get_logs(self, topic: str, limit: int = 100) -> Tuple[LogEntries, str]:
        base = f"{self.api_endpoint}/v1/logs/topic/{topic}"
        headers = {"User-Agent": _user_agent}
        partitions = requests.get(base, auth=self.auth, headers=headers).json()[
            "partitions"
        ]

        iterator = {
            k: max(0, span["end"] - math.floor(limit / len(partitions)))
            for k, span in partitions.items()
        }

        response = requests.post(
            base + "/records",
            params={"limit": limit},
            json=iterator,
            auth=self.auth,
            headers=headers,
        ).json()

        return (
            LogEntries([LogEntry(json.loads(l)) for l in response["records"]]),
            response.get("status", "None"),
        )

    def security_logs(self, limit: int) -> List[dict]:
        """This function is not available in this release"""
        raise RuntimeError("security_logs() is not available in this release.")
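
    # A hedged sketch of pulling recent log entries, assuming a connected
    # client `wl` and a pipeline primary key id `pk_id`:
    #
    #     topic = wl.get_topic_name(pk_id)
    #     entries, status = wl.get_logs(topic, limit=50)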
1105 """ 1106 1107 assert limit <= 1_000_000 1108 1109 base = self.api_endpoint + f"/v1/logs/topic/" + topic 1110 headers = {"User-Agent": _user_agent} 1111 resp = requests.get(base, auth=self.auth, headers=headers) 1112 if resp.status_code != 200: 1113 raise EntityNotFoundError( 1114 f"Could not get partitions {resp.text}", {"url": base} 1115 ) 1116 data = resp.json() 1117 partitions = data["partitions"] 1118 1119 if verbose: 1120 print(f"Got partitions {partitions}") 1121 1122 params: Dict[str, Any] = {"limit": limit} 1123 if start is not None: 1124 start_str = start.astimezone(tz=timezone.utc).isoformat() 1125 params["time.start"] = start_str 1126 if end is not None: 1127 end_str = end.astimezone(tz=timezone.utc).isoformat() 1128 params["time.end"] = end_str 1129 1130 next: Union[Any, None] = { 1131 k: max(0, span["end"] - math.floor(limit / len(partitions))) 1132 for k, span in partitions.items() 1133 } 1134 1135 if verbose: 1136 print("Using params: ", params) 1137 1138 records = [] 1139 while next is not None: 1140 response = requests.post( 1141 base + "/records", 1142 params=params, 1143 json=next, 1144 auth=self.auth, 1145 headers=headers, 1146 ) 1147 if response.status_code != 200: 1148 raise EntityNotFoundError( 1149 f"Could not get records {resp.text}", 1150 {"url": base, "params": str(params), "iterator": str(next)}, 1151 ) 1152 1153 if verbose: 1154 print("response: ", response) 1155 1156 result = response.json() 1157 result_records = result["records"] 1158 if len(result_records) > 0: 1159 records.extend(result_records) 1160 next = result["next"] 1161 else: 1162 next = None 1163 if parse: 1164 return [json.loads(r) for r in records] 1165 return records 1166 1167 def get_raw_pipeline_inference_logs( 1168 self, 1169 topic: str, 1170 start: datetime, 1171 end: datetime, 1172 model_name: Optional[str] = None, 1173 limit: int = 100_000, 1174 verbose: bool = False, 1175 ) -> List[Dict[str, Any]]: 1176 """Gets logs from Plateau for a particular time window and filters them for 1177 the model specified. 1178 :param pipeline_name str: The name/pipeline_id of the pipeline to query 1179 :param topic str: The name of the topic to query 1180 :param start Optional[datetime]: The start of the time window 1181 :param end Optional[datetime]: The end of the time window 1182 :param model_id: The name of the specific model to filter if any 1183 :param limit int: The number of records to retrieve. Note retrieving many 1184 records may be a performance bottleneck. 1185 :param verbose bool: Prints out info to help diagnose issues. 
1186 """ 1187 logs = self.get_raw_logs( 1188 topic, 1189 start=start, 1190 end=end, 1191 limit=limit, 1192 parse=True, 1193 verbose=verbose, 1194 ) 1195 1196 if verbose: 1197 print(f"Got {len(logs)} initial logs") 1198 1199 if model_name: 1200 logs = [l for l in logs if l["model_name"] == model_name] 1201 1202 # inference results are a unix timestamp in millis - filter by that 1203 start_ts = int(start.timestamp() * 1000) 1204 end_ts = int(end.timestamp() * 1000) 1205 logs = [l for l in logs if start_ts <= l["time"] < end_ts] 1206 1207 return logs 1208 1209 def get_pipeline_inference_dataframe( 1210 self, 1211 topic: str, 1212 start: datetime, 1213 end: datetime, 1214 model_name: Optional[str] = None, 1215 limit: int = 100_000, 1216 verbose=False, 1217 ) -> pd.DataFrame: 1218 logs = self.get_raw_pipeline_inference_logs( 1219 topic, start, end, model_name, limit, verbose 1220 ) 1221 return inference_logs_to_dataframe(logs) 1222 1223 def get_assay_results( 1224 self, 1225 assay_id: int, 1226 start: datetime, 1227 end: datetime, 1228 ) -> AssayAnalysisList: 1229 """Gets the assay results for a particular time window, parses them, and returns an 1230 AssayAnalysisList of AssayAnalysis. 1231 :param assay_id int: The id of the assay we are looking for. 1232 :param start datetime: The start of the time window 1233 :param end datetime: The end of the time window 1234 """ 1235 res = assays_get_assay_results.sync( 1236 client=self.mlops(), 1237 json_body=assays_get_assay_results_json_body.AssaysGetAssayResultsJsonBody( 1238 assay_id, start, end 1239 ), 1240 ) 1241 1242 if res is None: 1243 raise Exception("Failed to list models") 1244 1245 if not isinstance(res, List): 1246 raise Exception(res.msg) 1247 1248 if len(res) != 0 and not isinstance( 1249 res[0], AssaysGetAssayResultsResponse200Item 1250 ): 1251 raise Exception("invalid response") 1252 1253 return AssayAnalysisList([AssayAnalysis(v.to_dict()) for v in res]) 1254 1255 def build_assay( 1256 self, 1257 assay_name: str, 1258 pipeline: Pipeline, 1259 model_name: str, 1260 baseline_start: datetime, 1261 baseline_end: datetime, 1262 ) -> AssayBuilder: 1263 """Creates an AssayBuilder that can be used to configure and create 1264 Assays. 1265 :param assay_name str: Human friendly name for the assay 1266 :param pipeline Pipeline: The pipeline this assay will work on 1267 :param model_name str: The model that this assay will monitor 1268 :param baseline_start datetime: The start time for the inferences to 1269 use as the baseline 1270 :param baseline_end datetime: The end time of the baseline window. 1271 the baseline. Windows start immediately after the baseline window and 1272 are run at regular intervals continously until the assay is deactivated 1273 or deleted. 1274 """ 1275 assay_builder = AssayBuilder( 1276 self, 1277 assay_name, 1278 pipeline.id(), 1279 pipeline.name(), 1280 model_name, 1281 baseline_start, 1282 baseline_end, 1283 ) 1284 1285 return assay_builder 1286 1287 def upload_assay(self, config: AssayConfig) -> int: 1288 """Creates an assay in the database. 1289 :param config AssayConfig: The configuration for the assay to create. 1290 :return assay_id: The identifier for the assay that was created. 

    def upload_assay(self, config: AssayConfig) -> int:
        """Creates an assay in the database.

        :param AssayConfig config: The configuration for the assay to create.
        :return: The identifier for the assay that was created.
        :rtype: int
        """
        data = assays_create.sync(
            client=self.mlops(),
            json_body=AssaysCreateJsonBody.from_dict(
                {
                    **json.loads(config.to_json()),
                    "created_at": datetime.now(timezone.utc).isoformat(),
                }
            ),
        )

        if data is None:
            raise Exception("Failed to create assay")

        if not isinstance(data, AssaysCreateResponse200):
            raise Exception(data.msg)

        return data.assay_id

    def list_assays(self) -> List[Assay]:
        """List all assays on the platform.

        :return: A list of all assays on the platform.
        :rtype: List[Assay]
        """
        res = assays_list.sync(client=self.mlops(), json_body=AssaysListJsonBody(UNSET))

        if res is None:
            raise Exception("Failed to get assays")

        if not isinstance(res, List):
            raise Exception(res.msg)

        return Assays([Assay(client=self, data=v.to_dict()) for v in res])

    def create_tag(self, tag_text: str) -> Tag:
        """Create a new tag with the given text."""
        assert tag_text is not None
        return Tag._create_tag(client=self, tag_text=tag_text)

    def create_workspace(self, workspace_name: str) -> Workspace:
        """Create a new workspace with the current user as its first owner.

        :param str workspace_name: Name of the workspace, must be composed of
            ASCII alpha-numeric characters plus dash (-)"""
        assert workspace_name is not None
        require_dns_compliance(workspace_name)
        return Workspace._create_workspace(client=self, name=workspace_name)

    def list_workspaces(self) -> List[Workspace]:
        """List all workspaces on the platform which this user has permission
        to see.

        :return: A list of all workspaces on the platform.
        :rtype: List[Workspace]
        """
        res = workspaces_list.sync(
            client=self.mlops(),
            json_body=workspaces_list_json_body.WorkspacesListJsonBody(UNSET),
        )

        if res is None:
            raise Exception("Failed to get workspaces")

        if not isinstance(res, WorkspacesListResponse200):
            raise Exception(res.msg)

        return Workspaces(
            [Workspace(client=self, data=d.to_dict()) for d in res.workspaces]
        )

    def set_current_workspace(self, workspace: Workspace) -> Workspace:
        """Any calls involving pipelines or models will use the given
        workspace from then on."""
        assert workspace is not None
        if not isinstance(workspace, Workspace):
            raise TypeError("Workspace type was expected")

        self._current_workspace = workspace
        return cast("Workspace", self._current_workspace)
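
    # A hedged workspace sketch, assuming a connected client `wl` (the
    # workspace name is hypothetical):
    #
    #     ws = wl.create_workspace("fraud-team")
    #     wl.set_current_workspace(ws)
    #     wl.list_models()  # now scoped to "fraud-team"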

    def get_current_workspace(self) -> Workspace:
        """Return the current workspace. See also `set_current_workspace`."""
        if self._current_workspace is None:
            # Is there a default? Use that or make one.
            default_ws = Workspace._get_user_default_workspace(self)
            if default_ws is not None:
                self._current_workspace = default_ws
            else:
                self._current_workspace = Workspace._create_user_default_workspace(self)

        return cast("Workspace", self._current_workspace)

    def invite_user(self, email, password=None):
        return User.invite_user(
            email, password, self.auth, self.api_endpoint, self.auth_endpoint
        )

    def get_topic_name(self, pipeline_pk_id: int) -> str:
        return self._post_rest_api(
            "v1/api/plateau/get_topic_name",
            {
                "pipeline_pk_id": pipeline_pk_id,
            },
        ).json()["topic_name"]

    def shim_token(self, token_data: auth.TokenData):
        fetcher = auth._RawTokenFetcher(token_data)
        self.auth = auth._PlatformAuth(fetcher)

    def convert_model(
        self,
        path: Union[str, pathlib.Path],
        source_type: ModelConversionSource,
        conversion_arguments: ModelConversionArguments,
    ) -> Model:
        """
        Given an inbound source model, a model type (keras, sklearn, or
        xgboost), and conversion arguments, converts the model to ONNX and
        adds it to the available models for a pipeline.

        :param Union[str, pathlib.Path] path: The path to the model to
            convert, i.e. the source model.
        :param ModelConversionSource source_type: The origin model type, i.e.
            keras, sklearn, or xgboost.
        :param ModelConversionArguments conversion_arguments: A structure
            representing the arguments for converting a specific model type.
        :return: An instance of the Model being converted to ONNX.
        :raises ModelConversionGenericException: On a generic failure; please
            contact our support for further assistance.
        :raises ModelConversionFailure: Failure in converting the model type.
        :raises ModelConversionUnsupportedType: Raised when the source type
            passed is not supported.
        :raises ModelConversionSourceFileNotPresent: Raised when the passed
            source file does not exist.
        """
        if isinstance(path, str):
            path = pathlib.Path(path)
        if not os.path.exists(path):
            raise ModelConversionSourceFileNotPresent(
                f"The provided source file: {path} can not be found."
            )
        file_handle = open(path, "rb")
        files = [("files", file_handle)]
        base_url = self.api_endpoint
        workspace_id = self.get_current_workspace().id()
        common_headers = {
            "user_id": self.auth.user_id(),
            "user_email": self.auth.user_email(),
            "User-Agent": _user_agent,
        }
        curry_post = partial(
            requests.post,
            auth=self.auth,
            headers=common_headers,
            files=files,
        )
        model_id: Union[int, None] = None

        def _handle_response(http_response) -> int:
            http_response.raise_for_status()
            response_record = http_response.json()
            if response_record is not None and "model_id" in response_record:
                return int(response_record["model_id"])
            else:
                raise ModelConversionFailure("Failed to convert model")

        try:
            if source_type == ModelConversionSource.KERAS:
                assert (
                    type(conversion_arguments).__name__
                    == ConvertKerasArguments.__name__
                )
                ## This is not optimal, but api-lb (envoy) -> python-api
                ## gave a 50(3|4|2|0) on several other options:
                ## dimensions being of type List[Any],
                ## dimensions being a base64-encoded JSON array,
                ## dimensions being a comma-separated string that was
                ## processed server side.
                ## This casts the ConvertKerasArguments dict structure to a
                ## JSON string, base64-encoded via a utf-8 binary conversion,
                ## and passes it as the url parameter `config`, which is
                ## handled server side.
                data = {
                    **conversion_arguments.to_dict(),
                    "workspace_id": workspace_id,
                }
                params = {
                    "config": base64.b64encode(json.dumps(data).encode("utf-8")).decode(
                        "utf-8"
                    )
                }
                response = curry_post(url=f"{base_url}/v1/convert/keras", params=params)
                model_id = _handle_response(response)
            elif source_type == ModelConversionSource.SKLEARN:
                assert (
                    type(conversion_arguments).__name__
                    == ConvertSKLearnArguments.__name__
                )
                response = curry_post(
                    url=f"{base_url}/v1/convert/sklearn",
                    params={
                        **conversion_arguments.to_dict(),
                        "workspace_id": workspace_id,
                    },
                )
                model_id = _handle_response(response)

            elif source_type == ModelConversionSource.XGBOOST:
                assert (
                    type(conversion_arguments).__name__ == ConvertXGBoostArgs.__name__
                )
                response = curry_post(
                    url=f"{base_url}/v1/convert/xgboost",
                    params={
                        **conversion_arguments.to_dict(),
                        "workspace_id": workspace_id,
                    },
                )
                model_id = _handle_response(response)
            else:
                raise ModelConversionUnsupportedType(
                    f"Unsupported model source type of {source_type} passed."
                )
            if model_id is None:
                raise ModelConversionFailure("Failed to retrieve final model id")
            return Model(self, {"id": model_id})
        except Exception:
            raise ModelConversionGenericException(
                "This model type could not be deployed successfully. Please contact your Wallaroo support team at community@wallaroo.ai"
            )
        finally:
            file_handle.close()

    def _post_rest_api_json(self, uri: str, body: dict):
        result = self._post_rest_api(uri, body)
        if result.status_code == 200:
            return result.json()
        else:
            raise Exception(f"{result.status_code}: {result.text}")

    def _setup_mlops_client(self) -> "AuthenticatedClient":
        self._mlops = AuthenticatedClient(
            base_url=self.api_endpoint, token=self.auth._access_token().token
        )
        return self._mlops

    def mlops(self) -> "AuthenticatedClient":
        return self._setup_mlops_client()
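
# A hedged conversion sketch, assuming a connected client `wl` and an
# sklearn pickle on disk (the path and argument fields are hypothetical and
# elided; see ConvertSKLearnArguments for the real fields):
#
#     from wallaroo.ModelConversion import (
#         ConvertSKLearnArguments,
#         ModelConversionSource,
#     )
#     args = ConvertSKLearnArguments(...)
#     model = wl.convert_model(
#         "./models/ridge.pkl", ModelConversionSource.SKLEARN, args
#     )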
714 :rtype: List[User] 715 """ 716 resp = User.list_users(auth=self.auth) 717 return [User(client=self, data=u) for u in resp] 718 719 def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> Model: 720 """Upload a model defined by a file as a new model variant. 721 722 :param str model_name: The name of the model of which this is a variant. 723 Names must be ASCII alpha-numeric characters or dash (-) only. 724 :param Union[str, pathlib.Path] path: Path of the model file to upload. 725 :return: The created Model. 726 :rtype: Model 727 """ 728 729 _Visibility.PRIVATE 730 if isinstance(path, str): 731 path = pathlib.Path(path) 732 with path.open("rb") as f: 733 return self._upload_model_stream(name, {"filename": path.name}, f) 734 735 def _upload_model_stream(self, name: str, data: Dict[str, Any], file: Any): 736 require_dns_compliance(name) 737 endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload_stream") 738 data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()} 739 headers = {"User-Agent": _user_agent} 740 741 res = requests.post( 742 endpoint, auth=self.auth, params=data, data=file, headers=headers 743 ) 744 if res.status_code != 200: 745 raise ModelUploadError(res.text) 746 747 res_dict = json.loads(res.text) 748 return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0]) 749 750 def register_model_image(self, name: str, image: str) -> Model: 751 """Registers an MLFlow model as a new model. 752 753 :param str model_name: The name of the model of which this is a variant. 754 Names must be ASCII alpha-numeric characters or dash (-) only. 755 :param str image: Image name of the MLFlow model to register. 756 :return: The created Model. 757 :rtype: Model 758 """ 759 data = { 760 "image_path": image, 761 } 762 return self._upload_model(name, data) 763 764 def _upload_model( 765 self, name: str, data: Dict[str, Any], files: Dict[str, Tuple[str, bytes]] = {} 766 ): 767 require_dns_compliance(name) 768 endpoint = posixpath.join(self.api_endpoint, "v1/api/models/upload") 769 data = {**data, "name": name, "workspace_id": self.get_current_workspace().id()} 770 if len(files) == 0: 771 files = {"dummy": ("none", b"")} 772 773 headers = {"User-Agent": _user_agent} 774 775 res = requests.post( 776 endpoint, files=files, auth=self.auth, data=data, headers=headers 777 ) 778 if res.status_code != 200: 779 raise ModelUploadError(res.text) 780 781 res_dict = json.loads(res.text) 782 return Model(self, data=res_dict["insert_models"]["returning"][0]["models"][0]) 783 784 def model_by_name(self, model_class: str, model_name: str) -> Model: 785 """Fetch a Model by name. 786 787 :param str model_class: Name of the model class. 788 :param str model_name: Name of the variant within the specified model class. 789 :return: The Model with the corresponding model and variant name. 790 :rtype: Model 791 """ 792 res = self._gql_client.execute( 793 gql.gql( 794 """ 795 query ModelByName($model_id: String!, $model_version: String!) 
{ 796 model(where: {_and: [{model_id: {_eq: $model_id}}, {model_version: {_eq: $model_version}}]}) { 797 id 798 model_id 799 model_version 800 } 801 } 802 """ 803 ), 804 variable_values={ 805 "model_id": model_class, 806 "model_version": model_name, 807 }, 808 ) 809 if not res["model"]: 810 raise EntityNotFoundError( 811 "Model", {"model_class": model_class, "model_name": model_name} 812 ) 813 return Model(client=self, data={"id": res["model"][0]["id"]}) 814 815 def deployment_by_name(self, deployment_name: str) -> Deployment: 816 """Fetch a Deployment by name. 817 818 :param str deployment_name: Name of the deployment. 819 :return: The Deployment with the corresponding name. 820 :rtype: Deployment 821 """ 822 res = self._gql_client.execute( 823 gql.gql( 824 """ 825 query DeploymentByName($deployment_name: String!) { 826 deployment(where: {deploy_id: {_eq: $deployment_name}}) { 827 id 828 } 829 } 830 """ 831 ), 832 variable_values={ 833 "deployment_name": deployment_name, 834 }, 835 ) 836 if not res["deployment"]: 837 raise EntityNotFoundError( 838 "Deployment", {"deployment_name": deployment_name} 839 ) 840 return Deployment(client=self, data={"id": res["deployment"][0]["id"]}) 841 842 def pipelines_by_name(self, pipeline_name: str) -> List[Pipeline]: 843 """Fetch Pipelines by name. 844 845 :param str pipeline_name: Name of the pipeline. 846 :return: The Pipeline with the corresponding name. 847 :rtype: Pipeline 848 """ 849 res = self._gql_client.execute( 850 gql.gql( 851 """ 852 query PipelineByName($pipeline_name: String!) { 853 pipeline(where: {pipeline_id: {_eq: $pipeline_name}}, order_by: {created_at: desc}) { 854 id 855 } 856 } 857 """ 858 ), 859 variable_values={ 860 "pipeline_name": pipeline_name, 861 }, 862 ) 863 assert "pipeline" in res 864 length = len(res["pipeline"]) 865 if length < 1: 866 raise EntityNotFoundError("Pipeline", {"pipeline_name": pipeline_name}) 867 return [Pipeline(client=self, data={"id": p["id"]}) for p in res["pipeline"]] 868 869 def list_pipelines(self) -> List[Pipeline]: 870 """List all pipelines on the platform. 871 872 :return: A list of all pipelines on the platform. 873 :rtype: List[Pipeline] 874 """ 875 res = self._gql_client.execute( 876 gql.gql( 877 """ 878 query ListPipelines { 879 pipeline(order_by: {id: desc}) { 880 id 881 pipeline_tags { 882 tag { 883 id 884 tag 885 } 886 } 887 } 888 } 889 """ 890 ) 891 ) 892 return Pipelines([Pipeline(client=self, data=d) for d in res["pipeline"]]) 893 894 def build_pipeline(self, pipeline_name: str) -> "Pipeline": 895 """Starts building a pipeline with the given `pipeline_name`, 896 returning a :py:PipelineConfigBuilder: 897 898 When completed, the pipeline can be uploaded with `.upload()` 899 900 :param pipeline_name string: Name of the pipeline, must be composed of ASCII 901 alpha-numeric characters plus dash (-). 902 """ 903 904 require_dns_compliance(pipeline_name) 905 906 _Visibility.PRIVATE 907 908 # TODO: Needs to handle visibility? 
909 data = pipelines_create.sync( 910 client=self.mlops(), 911 json_body=pipelines_create_json_body.PipelinesCreateJsonBody( 912 pipeline_name, 913 self.get_current_workspace().id(), 914 pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict( 915 {} 916 ), 917 ), 918 ) 919 920 if data is None: 921 raise Exception("Failed to create pipeline") 922 923 if not isinstance(data, PipelinesCreateResponse200): 924 raise Exception(data.msg) 925 926 return Pipeline(client=self, data={"id": data.pipeline_pk_id}) 927 928 def _upload_pipeline_variant( 929 self, 930 name: str, 931 config: PipelineConfig, 932 ) -> Pipeline: 933 """Creates a new PipelineVariant with the specified configuration. 934 935 :param str name: Name of the Pipeline. Must be unique across all Pipelines. 936 :param config PipelineConfig: Pipeline configuration. 937 """ 938 definition = config.to_json() 939 _Visibility.PRIVATE 940 941 data = pipelines_create.sync( 942 client=self.mlops(), 943 json_body=pipelines_create_json_body.PipelinesCreateJsonBody( 944 name, 945 self.get_current_workspace().id(), 946 pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict( 947 definition 948 ), 949 ), 950 ) 951 952 if data is None: 953 # TODO: Generalize 954 raise Exception("Failed to create pipeline") 955 956 if not isinstance(data, PipelinesCreateResponse200): 957 raise Exception(data.msg) 958 959 for alert_config in config.alert_configurations: 960 self._gql_client.execute( 961 gql.gql( 962 """ 963 mutation CreateAlertConfiguration( 964 $pipeline_version_id: bigint, 965 $name: String, 966 $expression: String, 967 $notifications: jsonb 968 ) { 969 insert_alert_configuration(objects: { 970 name: $name, 971 expression: $expression, 972 notifications: $notifications, 973 pipeline_version_pk_id: $pipeline_version_id 974 }) { 975 returning { id } 976 } 977 } 978 """ 979 ), 980 variable_values={ 981 **alert_config.to_json(), 982 "pipeline_version_id": data.pipeline_pk_id, 983 }, 984 ) 985 986 pipeline_data = data.to_dict() 987 pipeline_data["id"] = data.pipeline_pk_id 988 989 return Pipeline( 990 client=self, 991 data=pipeline_data, 992 ) 993 994 def create_value_split_experiment( 995 self, 996 name: str, 997 meta_key: str, 998 default_model: ModelConfig, 999 challenger_models: List[Tuple[Any, ModelConfig]], 1000 ) -> Pipeline: 1001 """Creates a new PipelineVariant of a "value-split experiment" type. 1002 :param str name: Name of the Pipeline 1003 :param meta_key str: Inference input key on which to redirect inputs to 1004 experiment models. 1005 :param default_model ModelConfig: Model to send inferences by default. 1006 :param challenger_models List[Tuple[Any, ModelConfig]]: A list of 1007 meta_key values -> Models to send inferences. If the inference data 1008 referred to by meta_key is equal to one of the keys in this tuple, 1009 that inference is redirected to the corresponding model instead of 1010 the default model. 1011 """ 1012 args = [meta_key, default_model.model().name()] 1013 for v, m in challenger_models: 1014 args.append(v) 1015 args.append(m.model().name()) 1016 step = { 1017 "id": "metavalue_split", 1018 "operation": "map", 1019 "args": args, 1020 } 1021 definition = {"id": name, "steps": [step]} 1022 # TODO: This seems like a one-to-one replace, find appropriate test. 
1023 data = self._gql_client.execute( 1024 gql.gql( 1025 """ 1026 mutation CreatePipeline( 1027 $pipeline_id: String, 1028 $version: String, 1029 $definition: jsonb, 1030 $workspace_id: bigint 1031 ) { 1032 insert_pipeline( 1033 objects: { 1034 pipeline_versions: { 1035 data: { definition: $definition } 1036 } 1037 pipeline_id: $pipeline_id 1038 } 1039 ) { 1040 returning { 1041 id 1042 } 1043 } 1044 } 1045 """ 1046 ), 1047 variable_values={ 1048 "pipeline_id": name, 1049 "definition": definition, 1050 "workspace_id": self.get_current_workspace().id(), 1051 }, 1052 ) 1053 return Pipeline( 1054 client=self, 1055 data=data["insert_pipeline"]["returning"][0], 1056 ) 1057 1058 def get_logs(self, topic: str, limit: int = 100) -> Tuple[LogEntries, str]: 1059 base = self.api_endpoint + f"/v1/logs/topic/" + topic 1060 headers = {"User-Agent": _user_agent} 1061 partitions = requests.get(base, auth=self.auth, headers=headers).json()[ 1062 "partitions" 1063 ] 1064 1065 iterator = { 1066 k: max(0, span["end"] - math.floor(limit / len(partitions))) 1067 for k, span in partitions.items() 1068 } 1069 1070 response = requests.post( 1071 base + "/records", 1072 params={"limit": limit}, 1073 json=iterator, 1074 auth=self.auth, 1075 headers=headers, 1076 ).json() 1077 1078 return ( 1079 LogEntries([LogEntry(json.loads(l)) for l in response["records"]]), 1080 response.get("status", "None"), 1081 ) 1082 1083 def security_logs(self, limit: int) -> List[dict]: 1084 """This function is not available in this release""" 1085 raise RuntimeError("security_log() is not available in this release.") 1086 1087 def get_raw_logs( 1088 self, 1089 topic: str, 1090 start: Optional[datetime] = None, 1091 end: Optional[datetime] = None, 1092 limit: int = 100_000, 1093 parse: bool = False, 1094 verbose: bool = False, 1095 ) -> List[Dict[str, Any]]: 1096 """Gets logs from Plateau for a particular time window without attempting 1097 to convert them to Inference LogEntries. Logs can be returned as strings 1098 or the json parsed into lists and dicts. 1099 :param topic str: The name of the topic to query 1100 :param start Optional[datetime]: The start of the time window 1101 :param end Optional[datetime]: The end of the time window 1102 :param limit int: The number of records to retrieve. Note retrieving many 1103 records may be a performance bottleneck. 1104 :param parse bool: Wether to attempt to parse the string as a json object. 1105 :param verbose bool: Prints out info to help diagnose issues. 
1106 """ 1107 1108 assert limit <= 1_000_000 1109 1110 base = self.api_endpoint + f"/v1/logs/topic/" + topic 1111 headers = {"User-Agent": _user_agent} 1112 resp = requests.get(base, auth=self.auth, headers=headers) 1113 if resp.status_code != 200: 1114 raise EntityNotFoundError( 1115 f"Could not get partitions {resp.text}", {"url": base} 1116 ) 1117 data = resp.json() 1118 partitions = data["partitions"] 1119 1120 if verbose: 1121 print(f"Got partitions {partitions}") 1122 1123 params: Dict[str, Any] = {"limit": limit} 1124 if start is not None: 1125 start_str = start.astimezone(tz=timezone.utc).isoformat() 1126 params["time.start"] = start_str 1127 if end is not None: 1128 end_str = end.astimezone(tz=timezone.utc).isoformat() 1129 params["time.end"] = end_str 1130 1131 next: Union[Any, None] = { 1132 k: max(0, span["end"] - math.floor(limit / len(partitions))) 1133 for k, span in partitions.items() 1134 } 1135 1136 if verbose: 1137 print("Using params: ", params) 1138 1139 records = [] 1140 while next is not None: 1141 response = requests.post( 1142 base + "/records", 1143 params=params, 1144 json=next, 1145 auth=self.auth, 1146 headers=headers, 1147 ) 1148 if response.status_code != 200: 1149 raise EntityNotFoundError( 1150 f"Could not get records {resp.text}", 1151 {"url": base, "params": str(params), "iterator": str(next)}, 1152 ) 1153 1154 if verbose: 1155 print("response: ", response) 1156 1157 result = response.json() 1158 result_records = result["records"] 1159 if len(result_records) > 0: 1160 records.extend(result_records) 1161 next = result["next"] 1162 else: 1163 next = None 1164 if parse: 1165 return [json.loads(r) for r in records] 1166 return records 1167 1168 def get_raw_pipeline_inference_logs( 1169 self, 1170 topic: str, 1171 start: datetime, 1172 end: datetime, 1173 model_name: Optional[str] = None, 1174 limit: int = 100_000, 1175 verbose: bool = False, 1176 ) -> List[Dict[str, Any]]: 1177 """Gets logs from Plateau for a particular time window and filters them for 1178 the model specified. 1179 :param pipeline_name str: The name/pipeline_id of the pipeline to query 1180 :param topic str: The name of the topic to query 1181 :param start Optional[datetime]: The start of the time window 1182 :param end Optional[datetime]: The end of the time window 1183 :param model_id: The name of the specific model to filter if any 1184 :param limit int: The number of records to retrieve. Note retrieving many 1185 records may be a performance bottleneck. 1186 :param verbose bool: Prints out info to help diagnose issues. 
1187 """ 1188 logs = self.get_raw_logs( 1189 topic, 1190 start=start, 1191 end=end, 1192 limit=limit, 1193 parse=True, 1194 verbose=verbose, 1195 ) 1196 1197 if verbose: 1198 print(f"Got {len(logs)} initial logs") 1199 1200 if model_name: 1201 logs = [l for l in logs if l["model_name"] == model_name] 1202 1203 # inference results are a unix timestamp in millis - filter by that 1204 start_ts = int(start.timestamp() * 1000) 1205 end_ts = int(end.timestamp() * 1000) 1206 logs = [l for l in logs if start_ts <= l["time"] < end_ts] 1207 1208 return logs 1209 1210 def get_pipeline_inference_dataframe( 1211 self, 1212 topic: str, 1213 start: datetime, 1214 end: datetime, 1215 model_name: Optional[str] = None, 1216 limit: int = 100_000, 1217 verbose=False, 1218 ) -> pd.DataFrame: 1219 logs = self.get_raw_pipeline_inference_logs( 1220 topic, start, end, model_name, limit, verbose 1221 ) 1222 return inference_logs_to_dataframe(logs) 1223 1224 def get_assay_results( 1225 self, 1226 assay_id: int, 1227 start: datetime, 1228 end: datetime, 1229 ) -> AssayAnalysisList: 1230 """Gets the assay results for a particular time window, parses them, and returns an 1231 AssayAnalysisList of AssayAnalysis. 1232 :param assay_id int: The id of the assay we are looking for. 1233 :param start datetime: The start of the time window 1234 :param end datetime: The end of the time window 1235 """ 1236 res = assays_get_assay_results.sync( 1237 client=self.mlops(), 1238 json_body=assays_get_assay_results_json_body.AssaysGetAssayResultsJsonBody( 1239 assay_id, start, end 1240 ), 1241 ) 1242 1243 if res is None: 1244 raise Exception("Failed to list models") 1245 1246 if not isinstance(res, List): 1247 raise Exception(res.msg) 1248 1249 if len(res) != 0 and not isinstance( 1250 res[0], AssaysGetAssayResultsResponse200Item 1251 ): 1252 raise Exception("invalid response") 1253 1254 return AssayAnalysisList([AssayAnalysis(v.to_dict()) for v in res]) 1255 1256 def build_assay( 1257 self, 1258 assay_name: str, 1259 pipeline: Pipeline, 1260 model_name: str, 1261 baseline_start: datetime, 1262 baseline_end: datetime, 1263 ) -> AssayBuilder: 1264 """Creates an AssayBuilder that can be used to configure and create 1265 Assays. 1266 :param assay_name str: Human friendly name for the assay 1267 :param pipeline Pipeline: The pipeline this assay will work on 1268 :param model_name str: The model that this assay will monitor 1269 :param baseline_start datetime: The start time for the inferences to 1270 use as the baseline 1271 :param baseline_end datetime: The end time of the baseline window. 1272 the baseline. Windows start immediately after the baseline window and 1273 are run at regular intervals continously until the assay is deactivated 1274 or deleted. 1275 """ 1276 assay_builder = AssayBuilder( 1277 self, 1278 assay_name, 1279 pipeline.id(), 1280 pipeline.name(), 1281 model_name, 1282 baseline_start, 1283 baseline_end, 1284 ) 1285 1286 return assay_builder 1287 1288 def upload_assay(self, config: AssayConfig) -> int: 1289 """Creates an assay in the database. 1290 :param config AssayConfig: The configuration for the assay to create. 1291 :return assay_id: The identifier for the assay that was created. 
1292 :rtype int 1293 """ 1294 data = assays_create.sync( 1295 client=self.mlops(), 1296 json_body=AssaysCreateJsonBody.from_dict( 1297 { 1298 **json.loads(config.to_json()), 1299 "created_at": datetime.now(timezone.utc).isoformat(), 1300 } 1301 ), 1302 ) 1303 1304 if data is None: 1305 raise Exception("Failed to create assay") 1306 1307 if not isinstance(data, AssaysCreateResponse200): 1308 raise Exception(data.msg) 1309 1310 return data.assay_id 1311 1312 def list_assays(self) -> List[Assay]: 1313 """List all assays on the platform. 1314 1315 :return: A list of all assays on the platform. 1316 :rtype: List[Assay] 1317 """ 1318 res = assays_list.sync(client=self.mlops(), json_body=AssaysListJsonBody(UNSET)) 1319 1320 if res is None: 1321 raise Exception("Failed to get assays") 1322 1323 if not isinstance(res, List): 1324 raise Exception(res.msg) 1325 1326 return Assays([Assay(client=self, data=v.to_dict()) for v in res]) 1327 1328 def create_tag(self, tag_text: str) -> Tag: 1329 """Create a new tag with the given text.""" 1330 assert tag_text is not None 1331 return Tag._create_tag(client=self, tag_text=tag_text) 1332 1333 def create_workspace(self, workspace_name: str) -> Workspace: 1334 """Create a new workspace with the current user as its first owner. 1335 1336 :param str workspace_name: Name of the workspace, must be composed of ASCII 1337 alpha-numeric characters plus dash (-)""" 1338 assert workspace_name is not None 1339 require_dns_compliance(workspace_name) 1340 return Workspace._create_workspace(client=self, name=workspace_name) 1341 1342 def list_workspaces(self) -> List[Workspace]: 1343 """List all workspaces on the platform which this user has permission see. 1344 1345 :return: A list of all workspaces on the platform. 1346 :rtype: List[Workspace] 1347 """ 1348 res = workspaces_list.sync( 1349 client=self.mlops(), 1350 json_body=workspaces_list_json_body.WorkspacesListJsonBody(UNSET), 1351 ) 1352 1353 if res is None: 1354 raise Exception("Failed to get workspaces") 1355 1356 if not isinstance(res, WorkspacesListResponse200): 1357 raise Exception(res.msg) 1358 1359 return Workspaces( 1360 [Workspace(client=self, data=d.to_dict()) for d in res.workspaces] 1361 ) 1362 1363 def set_current_workspace(self, workspace: Workspace) -> Workspace: 1364 """Any calls involving pipelines or models will use the given workspace from then on.""" 1365 assert workspace is not None 1366 if not isinstance(workspace, Workspace): 1367 raise TypeError("Workspace type was expected") 1368 1369 self._current_workspace = workspace 1370 return cast("Workspace", self._current_workspace) 1371 1372 def get_current_workspace(self) -> Workspace: 1373 """Return the current workspace. See also `set_current_workspace`.""" 1374 if self._current_workspace is None: 1375 # Is there a default? Use that or make one. 
1376 default_ws = Workspace._get_user_default_workspace(self) 1377 if default_ws is not None: 1378 self._current_workspace = default_ws 1379 else: 1380 self._current_workspace = Workspace._create_user_default_workspace(self) 1381 1382 return cast("Workspace", self._current_workspace) 1383 1384 def invite_user(self, email, password=None): 1385 return User.invite_user( 1386 email, password, self.auth, self.api_endpoint, self.auth_endpoint 1387 ) 1388 1389 def get_topic_name(self, pipeline_pk_id: int) -> str: 1390 return self._post_rest_api( 1391 "v1/api/plateau/get_topic_name", 1392 { 1393 "pipeline_pk_id": pipeline_pk_id, 1394 }, 1395 ).json()["topic_name"] 1396 1397 def shim_token(self, token_data: auth.TokenData): 1398 fetcher = auth._RawTokenFetcher(token_data) 1399 self.auth = auth._PlatformAuth(fetcher) 1400 1401 def convert_model( 1402 self, 1403 path: Union[str, pathlib.Path], 1404 source_type: ModelConversionSource, 1405 conversion_arguments: ModelConversionArguments, 1406 ) -> Model: 1407 """ 1408 Given an inbound source model, a model type (xgboost, keras, sklearn), and conversion arguments. 1409 Convert the model to onnx, and add to available models for a pipeline. 1410 1411 :param Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model. 1412 :param ModelConversionSource source: The origin model type i.e. keras, sklearn or xgboost. 1413 :param ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type. 1414 :return: An instance of the Model being converted to Onnx. 1415 :raises ModelConversionGenericException: On a generic failure, please contact our support for further assistance. 1416 :raises ModelConversionFailure: Failure in converting the model type. 1417 :raises ModelConversionUnsupportedType: Raised when the source type passed is not supported. 1418 :raises ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist. 1419 """ 1420 if isinstance(path, str): 1421 path = pathlib.Path(path) 1422 if not os.path.exists(path): 1423 raise ModelConversionSourceFileNotPresent( 1424 f"The provided source file: {path} can not be found." 1425 ) 1426 file_handle = open(path, "rb") 1427 files = [("files", file_handle)] 1428 base_url = self.api_endpoint 1429 workspace_id = self.get_current_workspace().id() 1430 common_headers = { 1431 "user_id": self.auth.user_id(), 1432 "user_email": self.auth.user_email(), 1433 "User-Agent": _user_agent, 1434 } 1435 curry_post = partial( 1436 requests.post, 1437 auth=self.auth, 1438 headers=common_headers, 1439 files=files, 1440 ) 1441 model_id: Union[int, None] = None 1442 1443 def _handle_response(http_response) -> int: 1444 http_response.raise_for_status() 1445 response_record = http_response.json() 1446 if response_record is not None and "model_id" in response_record: 1447 return int(response_record["model_id"]) 1448 else: 1449 raise ModelConversionFailure("Failed to convert keras model") 1450 1451 try: 1452 if source_type == ModelConversionSource.KERAS: 1453 assert ( 1454 type(conversion_arguments).__name__ 1455 == ConvertKerasArguments.__name__ 1456 ) 1457 ## This is not optimal but api-lb(envoy) -> python-api 1458 ## Gave a 50(3|4|2|0) on several other options 1459 ## dimensions being of type List[Any] 1460 ## dimensions being a base64 encoded json array 1461 ## dimensions being a comma seperated string that was proccessed server side. 1462 ## This casts the ConvertKerasArguments dict structure to a json string. 
1463 ## Base64 encoding via a utf-8 binary conversion, passed as the url 1464 ## parameter config. Which is handlded server side. 1465 data = { 1466 **conversion_arguments.to_dict(), 1467 "workspace_id": workspace_id, 1468 } 1469 params = { 1470 "config": base64.b64encode(json.dumps(data).encode("utf-8")).decode( 1471 "utf-8" 1472 ) 1473 } 1474 response = curry_post(url=f"{base_url}/v1/convert/keras", params=params) 1475 model_id = _handle_response(response) 1476 elif source_type == ModelConversionSource.SKLEARN: 1477 assert ( 1478 type(conversion_arguments).__name__ 1479 == ConvertSKLearnArguments.__name__ 1480 ) 1481 response = curry_post( 1482 url=f"{base_url}/v1/convert/sklearn", 1483 params={ 1484 **conversion_arguments.to_dict(), 1485 "workspace_id": workspace_id, 1486 }, 1487 ) 1488 model_id = _handle_response(response) 1489 1490 elif source_type == ModelConversionSource.XGBOOST: 1491 assert ( 1492 type(conversion_arguments).__name__ == ConvertXGBoostArgs.__name__ 1493 ) 1494 response = curry_post( 1495 url=f"{base_url}/v1/convert/xgboost", 1496 params={ 1497 **conversion_arguments.to_dict(), 1498 "workspace_id": workspace_id, 1499 }, 1500 ) 1501 model_id = _handle_response(response) 1502 else: 1503 raise ModelConversionUnsupportedType( 1504 f"Unsupported model source type of {source_type} passed." 1505 ) 1506 if model_id is None: 1507 raise ModelConversionFailure("Failed to retrieve final model id") 1508 return Model(self, {"id": model_id}) 1509 except Exception: 1510 raise ModelConversionGenericException( 1511 "This model type could not be deployed successfully. Please contact your Wallaroo support team at community@wallaroo.ai" 1512 ) 1513 finally: 1514 file_handle.close() 1515 1516 def _post_rest_api_json(self, uri: str, body: dict): 1517 result = self._post_rest_api(uri, body) 1518 if result.status_code == 200: 1519 return result.json() 1520 else: 1521 raise Exception(f"{result.status_code}: {result.text}") 1522 1523 def _setup_mlops_client(self) -> "AuthenticatedClient": 1524 self._mlops = AuthenticatedClient( 1525 base_url=self.api_endpoint, token=self.auth._access_token().token 1526 ) 1527 return self._mlops 1528 1529 def mlops(self) -> "AuthenticatedClient": 1530 return self._setup_mlops_client()
Client handle to a Wallaroo platform instance.
Objects of this class serve as the entrypoint to Wallaroo platform functionality.
127 def __init__( 128 self, 129 api_endpoint: str = "http://api-lb:8080", 130 auth_endpoint: str = "", 131 request_timeout: int = 45, 132 auth_type: Optional[str] = None, 133 gql_client: Optional[gql.Client] = None, 134 pg_connection_string: str = "dbname=postgres user=postgres password=password host=postgres port=5432", 135 interactive: Optional[bool] = None, 136 time_format: str = "%Y-%d-%b %H:%M:%S", 137 ): 138 """Create a Client handle. 139 140 :param str api_endpoint: Host/port of the platform API endpoint 141 :param str auth_endpoint: Host/port of the platform Keycloak instance 142 :param int timeout: Max timeout of web requests, in seconds 143 :param str auth_type: Authentication type to use. Can be one of: "none", 144 "sso", "user_password". 145 :param str pg_connection_string: Postgres connection string 146 :param bool interactive: If provided and True, some calls will print additional human information, or won't when False. If not provided, interactive defaults to True if running inside Jupyter and False otherwise. 147 :param str time_format: Preferred `strftime` format string for displaying timestamps in a human context. 148 """ 149 150 auth_type, api_endpoint, auth_endpoint = Client.get_urls( 151 auth_type, api_endpoint, auth_endpoint 152 ) 153 154 self.auth = auth.create(auth_endpoint, auth_type) 155 156 if gql_client: 157 self._gql_client = gql_client 158 else: 159 gql_transport = RequestsHTTPTransport( 160 url=posixpath.join(api_endpoint, "v1/graphql"), 161 auth=self.auth, 162 timeout=request_timeout, 163 ) 164 self._gql_client = gql.Client( 165 transport=gql_transport, fetch_schema_from_transport=True 166 ) 167 168 self.api_endpoint = api_endpoint 169 170 self.auth_endpoint = auth_endpoint 171 172 self.timeout = request_timeout 173 174 self._setup_mlops_client() 175 176 self.pg_connection_string = pg_connection_string 177 178 self._current_workspace: Optional[Workspace] = None 179 180 # TODO: debate the names of these things 181 self._default_ws_name: Optional[str] = None 182 183 user_email = self.auth.user_email() 184 if user_email is not None: 185 self._default_ws_name = user_email + "_ws" 186 187 if interactive is not None: 188 self._interactive = interactive 189 elif ( 190 "JUPYTER_SVC_SERVICE_HOST" in os.environ or "JUPYTERHUB_HOST" in os.environ 191 ): 192 self._interactive = True 193 else: 194 self._interactive = False 195 196 self._time_format = time_format
Create a Client handle.
Parameters
- str api_endpoint: Host/port of the platform API endpoint
- str auth_endpoint: Host/port of the platform Keycloak instance
- int request_timeout: Max timeout of web requests, in seconds
- str auth_type: Authentication type to use. Can be one of: "none", "sso", "user_password".
- str pg_connection_string: Postgres connection string
- bool interactive: If True, some calls print additional human-readable information; if False, they do not. If not provided, interactive defaults to True when running inside Jupyter and False otherwise.
- str time_format: Preferred strftime format string for displaying timestamps in a human context.
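Example (a minimal sketch; the endpoint URLs and auth type below are placeholders for your installation's values, not module defaults):

    from wallaroo.client import Client

    # Hypothetical endpoints; substitute your own values, or rely on the
    # WALLAROO_URL / WALLAROO_AUTH_URL environment variables instead.
    wl = Client(
        api_endpoint="https://api.wallaroo.example.com",
        auth_endpoint="https://keycloak.wallaroo.example.com",
        auth_type="sso",
    )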
101 @staticmethod 102 def get_urls( 103 auth_type: Optional[str], api_endpoint: str, auth_endpoint: str 104 ) -> Tuple[Optional[str], str, str]: 105 """Method to calculate the auth values specified as defaults, 106 as params or in ENV vars. 107 Made static to be testable without reaching out to SSO, etc.""" 108 109 if auth_type is None: 110 auth_type = os.environ.get(WALLAROO_SDK_AUTH_TYPE, None) 111 112 # ideally we'd set auth_endpoint to None default value but that would 113 # make the auth_endpoint type to be Optiona[str] which messes up 114 # a lot| of type hinting and I wanted to make minimal changes without a 115 # lot of 'warnings'. 116 if len(auth_endpoint.strip()) == 0: 117 auth_endpoint = ( 118 os.environ.get(WALLAROO_AUTH_URL) 119 or os.environ.get(WALLAROO_SDK_AUTH_ENDPOINT) 120 or "http://api-lb:8080" 121 ) 122 123 api_endpoint = os.environ.get(WALLAROO_URL, api_endpoint) 124 125 return auth_type, api_endpoint, auth_endpoint
Resolves the auth values specified as defaults, as parameters, or in environment variables. Made static so it can be tested without reaching out to SSO, etc.
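Example (runnable offline, since get_urls only inspects its arguments and environment variables; the URL below is a made-up value):

    import os
    from wallaroo.client import Client

    os.environ["WALLAROO_URL"] = "https://api.wallaroo.example.com"  # example only

    auth_type, api_endpoint, auth_endpoint = Client.get_urls(
        None, "http://api-lb:8080", ""
    )
    # WALLAROO_URL overrides the api_endpoint argument; the blank auth_endpoint
    # falls back to WALLAROO_AUTH_URL, then WALLAROO_SDK_AUTH_ENDPOINT, then
    # the default "http://api-lb:8080".
    print(api_endpoint)  # -> https://api.wallaroo.example.com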
258 def list_models(self) -> ModelsList: 259 """List all models on the platform. 260 261 :return: A list of all models on the platform. 262 :rtype: List[Model] 263 """ 264 id = self.get_current_workspace().id() 265 res = models_list.sync( 266 client=self.mlops(), 267 json_body=models_list_json_body.ModelsListJsonBody(id), 268 ) 269 270 if res is None: 271 raise Exception("Failed to list models") 272 273 if not isinstance(res, ModelsListResponse200): 274 raise Exception(res.msg) 275 276 return ModelsList([Models(client=self, data=v.to_dict()) for v in res.models])
List all models in the current workspace.
Returns
A list of all models in the current workspace.
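Example (assumes wl is an already-authenticated Client; the result is scoped to the current workspace):

    for model in wl.list_models():
        print(model)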
278 def list_deployments(self) -> List[Deployment]: 279 """List all deployments (active or not) on the platform. 280 281 :return: A list of all deployments on the platform. 282 :rtype: List[Deployment] 283 """ 284 res = self._gql_client.execute( 285 gql.gql( 286 """ 287 query ListDeployments { 288 deployment { 289 id 290 deploy_id 291 deployed 292 deployment_model_configs { 293 model_config { 294 id 295 } 296 } 297 } 298 } 299 """ 300 ) 301 ) 302 return [Deployment(client=self, data=d) for d in res["deployment"]]
List all deployments (active or not) on the platform.
Returns
A list of all deployments on the platform.
330 def search_pipelines( 331 self, 332 search_term: Optional[str] = None, 333 deployed: Optional[bool] = None, 334 created_start: Optional["Datetime"] = None, 335 created_end: Optional["Datetime"] = None, 336 updated_start: Optional["Datetime"] = None, 337 updated_end: Optional["Datetime"] = None, 338 ) -> PipelineVariants: 339 """Search for pipelines. All parameters are optional, in which case the result is the same as 340 `list_pipelines()`. All times are strings to be parsed by `datetime.isoformat`. Example: 341 342 myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo") 343 344 :param str search_term: Will be matched against tags and model names. Example: "footag123". 345 :param bool deployed: Pipeline was deployed or not 346 :param str created_start: Pipeline was created at or after this time 347 :param str created_end: Pipeline was created at or before this time 348 :param str updated_start: Pipeline was updated at or before this time 349 :param str updated_end: Pipeline was updated at or before this time 350 351 :return: A list of all pipelines on the platform. 352 :rtype: List[Pipeline] 353 """ 354 return PipelineVariants( 355 self._search_pipelines( 356 search_term, 357 deployed, 358 None, 359 created_start, 360 created_end, 361 updated_start, 362 updated_end, 363 ) 364 )
Search for pipelines. All parameters are optional, in which case the result is the same as list_pipelines(). All times are strings to be parsed by datetime.isoformat. Example:
myclient.search_pipelines(created_end='2022-04-19 13:17:59+00:00', search_term="foo")
Parameters
- str search_term: Will be matched against tags and model names. Example: "footag123".
- bool deployed: Pipeline was deployed or not
- str created_start: Pipeline was created at or after this time
- str created_end: Pipeline was created at or before this time
- str updated_start: Pipeline was updated at or after this time
- str updated_end: Pipeline was updated at or before this time
Returns
A list of pipelines matching the given criteria.
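Example (assumes wl is an authenticated Client; the search term and timestamps are illustrative):

    results = wl.search_pipelines(
        search_term="footag123",
        deployed=True,
        created_start="2022-04-01 00:00:00+00:00",
        created_end="2022-04-19 13:17:59+00:00",
    )
    for variant in results:
        print(variant)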
502 def search_my_models( 503 self, 504 search_term: Optional[str] = None, 505 uploaded_time_start: Optional["Datetime"] = None, 506 uploaded_time_end: Optional["Datetime"] = None, 507 ) -> ModelVersions: 508 """Search models owned by you 509 params: 510 search_term: Searches the following metadata: names, shas, versions, file names, and tags 511 uploaded_time_start: Inclusive time of upload 512 uploaded_time_end: Inclusive time of upload 513 """ 514 user_id = self.auth.user_id() 515 return ModelVersions( 516 self._search_models( 517 search_term=search_term, 518 user_id=user_id, 519 start=uploaded_time_start, 520 end=uploaded_time_end, 521 ) 522 )
Search models owned by you.
Parameters
- search_term: Searches the following metadata: names, shas, versions, file names, and tags
- uploaded_time_start: Inclusive start of the upload time window
- uploaded_time_end: Inclusive end of the upload time window
524 def search_models( 525 self, 526 search_term: Optional[str] = None, 527 uploaded_time_start: Optional["Datetime"] = None, 528 uploaded_time_end: Optional["Datetime"] = None, 529 ) -> ModelVersions: 530 """Search all models you have access to. 531 params: 532 search_term: Searches the following metadata: names, shas, versions, file names, and tags 533 uploaded_time_start: Inclusive time of upload 534 uploaded_time_end: Inclusive time of upload 535 """ 536 return ModelVersions( 537 self._search_models( 538 search_term=search_term, 539 start=uploaded_time_start, 540 end=uploaded_time_end, 541 ) 542 )
Search all models you have access to.
Parameters
- search_term: Searches the following metadata: names, shas, versions, file names, and tags
- uploaded_time_start: Inclusive start of the upload time window
- uploaded_time_end: Inclusive end of the upload time window
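Example for both search methods (assumes wl is an authenticated Client; "resnet" is a hypothetical search term):

    from datetime import datetime, timedelta, timezone

    end = datetime.now(timezone.utc)
    start = end - timedelta(days=7)

    # Models you own, uploaded in the last week:
    mine = wl.search_my_models(
        search_term="resnet", uploaded_time_start=start, uploaded_time_end=end
    )
    # All models you have access to, same window:
    shared = wl.search_models(uploaded_time_start=start, uploaded_time_end=end)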
603 def get_user_by_email(self, email: str) -> Optional[User]: 604 """Find a user by email""" 605 assert email is not None 606 escaped_email = quote_plus(email) 607 url = ( 608 f"{self.auth_endpoint}/auth/admin/realms/master/users?email={escaped_email}" 609 ) 610 headers = { 611 "Content-Type": "application/json", 612 "Authorization": self.auth._bearer_token_str(), 613 "User-Agent": _user_agent, 614 } 615 resp = requests.get(url, headers=headers, data={}) 616 jresp = resp.json() 617 return None if jresp == [] else User(client=self, data=jresp[0])
Find a user by email
619 def deactivate_user(self, email: str) -> None: 620 """Deactivates an existing user of the platform 621 622 Deactivated users cannot log into the platform. 623 Deactivated users do not count towards the number of allotted user seats from the license. 624 625 The Models and Pipelines owned by the deactivated user are not removed from the platform. 626 627 :param str email: The email address of the user to deactivate. 628 629 :return: None 630 :rtype: None 631 """ 632 633 if self.auth.user_email() == email: 634 raise Exception("A user may not deactive themselves.") 635 636 user = self.get_user_by_email(email) 637 638 if user is None: 639 raise EntityNotFoundError("User", {"email": email}) 640 641 if user.username() == "admin": 642 raise Exception("Admin user may not be deactivated.") 643 644 url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}" 645 headers = { 646 "Content-Type": "application/json", 647 "Authorization": self.auth._bearer_token_str(), 648 "User-Agent": _user_agent, 649 } 650 651 # Get the current full user representation to return in the mutation due to keycloak bug 652 get_user_response = requests.get(url, headers=headers, data={}) 653 654 cur_user_rep = get_user_response.json() 655 cur_user_rep["enabled"] = False 656 657 resp = requests.put(url, headers=headers, json=cur_user_rep) 658 659 if resp.status_code != 204: 660 raise EntityNotFoundError("User", {"email": email}) 661 return None
Deactivates an existing user of the platform.
Deactivated users cannot log into the platform. Deactivated users do not count towards the number of allotted user seats from the license.
The Models and Pipelines owned by the deactivated user are not removed from the platform.
Parameters
- str email: The email address of the user to deactivate.
Returns
None
663 def activate_user(self, email: str) -> None: 664 """Activates an existing user of the platform that had been previously deactivated. 665 666 Activated users can log into the platform. 667 668 :param str email: The email address of the user to activate. 669 670 :return: None 671 :rtype: None 672 """ 673 user = self.get_user_by_email(email) 674 675 if user is None: 676 raise EntityNotFoundError("User", {"email": email}) 677 678 url = f"{self.auth_endpoint}/auth/admin/realms/master/users/{user._id}" 679 680 headers = { 681 "Content-Type": "application/json", 682 "Authorization": self.auth._bearer_token_str(), 683 "User-Agent": _user_agent, 684 } 685 686 # Get the current full user representation to return in the mutation due to keycloak bug 687 get_user_response = requests.get(url, headers=headers, data={}) 688 689 cur_user_rep = get_user_response.json() 690 cur_user_rep["enabled"] = True 691 692 resp = requests.put(url, headers=headers, json=cur_user_rep) 693 694 if resp.status_code != 204: 695 raise EntityNotFoundError("User", {"email": email}) 696 return None
Activates an existing user of the platform that had been previously deactivated.
Activated users can log into the platform.
Parameters
- str email: The email address of the user to activate.
Returns
None
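Example for the deactivate/activate pair (assumes wl is authenticated with sufficient privileges; the email address is hypothetical):

    # Frees the user's license seat; their models and pipelines remain.
    wl.deactivate_user("temp.user@example.com")

    # Later, restore login access for the same account.
    wl.activate_user("temp.user@example.com")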
710 def list_users(self) -> List[User]: 711 """List of all Users on the platform 712 713 :return: A list of all Users on the platform. 714 :rtype: List[User] 715 """ 716 resp = User.list_users(auth=self.auth) 717 return [User(client=self, data=u) for u in resp]
List all Users on the platform.
Returns
A list of all Users on the platform.
719 def upload_model(self, name: str, path: Union[str, pathlib.Path]) -> Model: 720 """Upload a model defined by a file as a new model variant. 721 722 :param str model_name: The name of the model of which this is a variant. 723 Names must be ASCII alpha-numeric characters or dash (-) only. 724 :param Union[str, pathlib.Path] path: Path of the model file to upload. 725 :return: The created Model. 726 :rtype: Model 727 """ 728 729 _Visibility.PRIVATE 730 if isinstance(path, str): 731 path = pathlib.Path(path) 732 with path.open("rb") as f: 733 return self._upload_model_stream(name, {"filename": path.name}, f)
Upload a model defined by a file as a new model variant.
Parameters
- str name: The name of the model of which this is a variant. Names must be ASCII alpha-numeric characters or dash (-) only.
- Union[str, pathlib.Path] path: Path of the model file to upload.
Returns
The created Model.
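Example (assumes wl is an authenticated Client and that a model file exists at the hypothetical local path):

    model = wl.upload_model("my-model", "models/my_model.onnx")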
750 def register_model_image(self, name: str, image: str) -> Model: 751 """Registers an MLFlow model as a new model. 752 753 :param str model_name: The name of the model of which this is a variant. 754 Names must be ASCII alpha-numeric characters or dash (-) only. 755 :param str image: Image name of the MLFlow model to register. 756 :return: The created Model. 757 :rtype: Model 758 """ 759 data = { 760 "image_path": image, 761 } 762 return self._upload_model(name, data)
Registers an MLFlow model as a new model.
Parameters
- str name: The name of the model of which this is a variant. Names must be ASCII alpha-numeric characters or dash (-) only.
- str image: Image name of the MLFlow model to register.
Returns
The created Model.
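Example (assumes wl is an authenticated Client; the image reference is a hypothetical registry path):

    model = wl.register_model_image(
        name="mlflow-model",
        image="registry.example.com/mlflow/my-model:latest",
    )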
784 def model_by_name(self, model_class: str, model_name: str) -> Model: 785 """Fetch a Model by name. 786 787 :param str model_class: Name of the model class. 788 :param str model_name: Name of the variant within the specified model class. 789 :return: The Model with the corresponding model and variant name. 790 :rtype: Model 791 """ 792 res = self._gql_client.execute( 793 gql.gql( 794 """ 795 query ModelByName($model_id: String!, $model_version: String!) { 796 model(where: {_and: [{model_id: {_eq: $model_id}}, {model_version: {_eq: $model_version}}]}) { 797 id 798 model_id 799 model_version 800 } 801 } 802 """ 803 ), 804 variable_values={ 805 "model_id": model_class, 806 "model_version": model_name, 807 }, 808 ) 809 if not res["model"]: 810 raise EntityNotFoundError( 811 "Model", {"model_class": model_class, "model_name": model_name} 812 ) 813 return Model(client=self, data={"id": res["model"][0]["id"]})
Fetch a Model by name.
Parameters
- str model_class: Name of the model class.
- str model_name: Name of the variant within the specified model class.
Returns
The Model with the corresponding model and variant name.
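Example (assumes wl is an authenticated Client; the class and variant names are hypothetical, and EntityNotFoundError is raised when nothing matches):

    model = wl.model_by_name(model_class="my-model", model_name="v1")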
815 def deployment_by_name(self, deployment_name: str) -> Deployment: 816 """Fetch a Deployment by name. 817 818 :param str deployment_name: Name of the deployment. 819 :return: The Deployment with the corresponding name. 820 :rtype: Deployment 821 """ 822 res = self._gql_client.execute( 823 gql.gql( 824 """ 825 query DeploymentByName($deployment_name: String!) { 826 deployment(where: {deploy_id: {_eq: $deployment_name}}) { 827 id 828 } 829 } 830 """ 831 ), 832 variable_values={ 833 "deployment_name": deployment_name, 834 }, 835 ) 836 if not res["deployment"]: 837 raise EntityNotFoundError( 838 "Deployment", {"deployment_name": deployment_name} 839 ) 840 return Deployment(client=self, data={"id": res["deployment"][0]["id"]})
Fetch a Deployment by name.
Parameters
- str deployment_name: Name of the deployment.
Returns
The Deployment with the corresponding name.
842 def pipelines_by_name(self, pipeline_name: str) -> List[Pipeline]: 843 """Fetch Pipelines by name. 844 845 :param str pipeline_name: Name of the pipeline. 846 :return: The Pipeline with the corresponding name. 847 :rtype: Pipeline 848 """ 849 res = self._gql_client.execute( 850 gql.gql( 851 """ 852 query PipelineByName($pipeline_name: String!) { 853 pipeline(where: {pipeline_id: {_eq: $pipeline_name}}, order_by: {created_at: desc}) { 854 id 855 } 856 } 857 """ 858 ), 859 variable_values={ 860 "pipeline_name": pipeline_name, 861 }, 862 ) 863 assert "pipeline" in res 864 length = len(res["pipeline"]) 865 if length < 1: 866 raise EntityNotFoundError("Pipeline", {"pipeline_name": pipeline_name}) 867 return [Pipeline(client=self, data={"id": p["id"]}) for p in res["pipeline"]]
Fetch Pipelines by name.
Parameters
- str pipeline_name: Name of the pipeline.
Returns
The Pipeline with the corresponding name.
869 def list_pipelines(self) -> List[Pipeline]: 870 """List all pipelines on the platform. 871 872 :return: A list of all pipelines on the platform. 873 :rtype: List[Pipeline] 874 """ 875 res = self._gql_client.execute( 876 gql.gql( 877 """ 878 query ListPipelines { 879 pipeline(order_by: {id: desc}) { 880 id 881 pipeline_tags { 882 tag { 883 id 884 tag 885 } 886 } 887 } 888 } 889 """ 890 ) 891 ) 892 return Pipelines([Pipeline(client=self, data=d) for d in res["pipeline"]])
List all pipelines on the platform.
Returns
A list of all pipelines on the platform.
894 def build_pipeline(self, pipeline_name: str) -> "Pipeline": 895 """Starts building a pipeline with the given `pipeline_name`, 896 returning a :py:PipelineConfigBuilder: 897 898 When completed, the pipeline can be uploaded with `.upload()` 899 900 :param pipeline_name string: Name of the pipeline, must be composed of ASCII 901 alpha-numeric characters plus dash (-). 902 """ 903 904 require_dns_compliance(pipeline_name) 905 906 _Visibility.PRIVATE 907 908 # TODO: Needs to handle visibility? 909 data = pipelines_create.sync( 910 client=self.mlops(), 911 json_body=pipelines_create_json_body.PipelinesCreateJsonBody( 912 pipeline_name, 913 self.get_current_workspace().id(), 914 pipelines_create_json_body_definition.PipelinesCreateJsonBodyDefinition.from_dict( 915 {} 916 ), 917 ), 918 ) 919 920 if data is None: 921 raise Exception("Failed to create pipeline") 922 923 if not isinstance(data, PipelinesCreateResponse200): 924 raise Exception(data.msg) 925 926 return Pipeline(client=self, data={"id": data.pipeline_pk_id})
Starts building a pipeline with the given pipeline_name, returning a PipelineConfigBuilder.
When completed, the pipeline can be uploaded with .upload()
Parameters
- str pipeline_name: Name of the pipeline, must be composed of ASCII alpha-numeric characters plus dash (-).
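A sketch of the intended flow (note the signature above is annotated to return a Pipeline while the docstring mentions a PipelineConfigBuilder, so the .upload() step is illustrative only; the pipeline name is hypothetical):

    builder = wl.build_pipeline("my-pipeline")  # name must be DNS-compliant
    # ...configure steps on the returned object as needed...
    pipeline = builder.upload()  # per the docstring above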
994 def create_value_split_experiment( 995 self, 996 name: str, 997 meta_key: str, 998 default_model: ModelConfig, 999 challenger_models: List[Tuple[Any, ModelConfig]], 1000 ) -> Pipeline: 1001 """Creates a new PipelineVariant of a "value-split experiment" type. 1002 :param str name: Name of the Pipeline 1003 :param meta_key str: Inference input key on which to redirect inputs to 1004 experiment models. 1005 :param default_model ModelConfig: Model to send inferences by default. 1006 :param challenger_models List[Tuple[Any, ModelConfig]]: A list of 1007 meta_key values -> Models to send inferences. If the inference data 1008 referred to by meta_key is equal to one of the keys in this tuple, 1009 that inference is redirected to the corresponding model instead of 1010 the default model. 1011 """ 1012 args = [meta_key, default_model.model().name()] 1013 for v, m in challenger_models: 1014 args.append(v) 1015 args.append(m.model().name()) 1016 step = { 1017 "id": "metavalue_split", 1018 "operation": "map", 1019 "args": args, 1020 } 1021 definition = {"id": name, "steps": [step]} 1022 # TODO: This seems like a one-to-one replace, find appropriate test. 1023 data = self._gql_client.execute( 1024 gql.gql( 1025 """ 1026 mutation CreatePipeline( 1027 $pipeline_id: String, 1028 $version: String, 1029 $definition: jsonb, 1030 $workspace_id: bigint 1031 ) { 1032 insert_pipeline( 1033 objects: { 1034 pipeline_versions: { 1035 data: { definition: $definition } 1036 } 1037 pipeline_id: $pipeline_id 1038 } 1039 ) { 1040 returning { 1041 id 1042 } 1043 } 1044 } 1045 """ 1046 ), 1047 variable_values={ 1048 "pipeline_id": name, 1049 "definition": definition, 1050 "workspace_id": self.get_current_workspace().id(), 1051 }, 1052 ) 1053 return Pipeline( 1054 client=self, 1055 data=data["insert_pipeline"]["returning"][0], 1056 )
Creates a new PipelineVariant of a "value-split experiment" type.
Parameters
- str name: Name of the Pipeline
- meta_key str: Inference input key on which to redirect inputs to experiment models.
- default_model ModelConfig: Model to send inferences by default.
- challenger_models List[Tuple[Any, ModelConfig]]: A list of meta_key values -> Models to send inferences. If the inference data referred to by meta_key is equal to one of the keys in this tuple, that inference is redirected to the corresponding model instead of the default model.
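Example (assumes wl is an authenticated Client and that default_cfg and challenger_cfg are ModelConfig objects obtained elsewhere; the names and meta_key are hypothetical):

    pipeline = wl.create_value_split_experiment(
        name="card-split-experiment",
        meta_key="card_type",
        default_model=default_cfg,
        # Inferences whose "card_type" equals "gold" go to the challenger.
        challenger_models=[("gold", challenger_cfg)],
    )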
1058 def get_logs(self, topic: str, limit: int = 100) -> Tuple[LogEntries, str]: 1059 base = self.api_endpoint + f"/v1/logs/topic/" + topic 1060 headers = {"User-Agent": _user_agent} 1061 partitions = requests.get(base, auth=self.auth, headers=headers).json()[ 1062 "partitions" 1063 ] 1064 1065 iterator = { 1066 k: max(0, span["end"] - math.floor(limit / len(partitions))) 1067 for k, span in partitions.items() 1068 } 1069 1070 response = requests.post( 1071 base + "/records", 1072 params={"limit": limit}, 1073 json=iterator, 1074 auth=self.auth, 1075 headers=headers, 1076 ).json() 1077 1078 return ( 1079 LogEntries([LogEntry(json.loads(l)) for l in response["records"]]), 1080 response.get("status", "None"), 1081 )
1083 def security_logs(self, limit: int) -> List[dict]: 1084 """This function is not available in this release""" 1085 raise RuntimeError("security_log() is not available in this release.")
This function is not available in this release
1087 def get_raw_logs( 1088 self, 1089 topic: str, 1090 start: Optional[datetime] = None, 1091 end: Optional[datetime] = None, 1092 limit: int = 100_000, 1093 parse: bool = False, 1094 verbose: bool = False, 1095 ) -> List[Dict[str, Any]]: 1096 """Gets logs from Plateau for a particular time window without attempting 1097 to convert them to Inference LogEntries. Logs can be returned as strings 1098 or the json parsed into lists and dicts. 1099 :param topic str: The name of the topic to query 1100 :param start Optional[datetime]: The start of the time window 1101 :param end Optional[datetime]: The end of the time window 1102 :param limit int: The number of records to retrieve. Note retrieving many 1103 records may be a performance bottleneck. 1104 :param parse bool: Wether to attempt to parse the string as a json object. 1105 :param verbose bool: Prints out info to help diagnose issues. 1106 """ 1107 1108 assert limit <= 1_000_000 1109 1110 base = self.api_endpoint + f"/v1/logs/topic/" + topic 1111 headers = {"User-Agent": _user_agent} 1112 resp = requests.get(base, auth=self.auth, headers=headers) 1113 if resp.status_code != 200: 1114 raise EntityNotFoundError( 1115 f"Could not get partitions {resp.text}", {"url": base} 1116 ) 1117 data = resp.json() 1118 partitions = data["partitions"] 1119 1120 if verbose: 1121 print(f"Got partitions {partitions}") 1122 1123 params: Dict[str, Any] = {"limit": limit} 1124 if start is not None: 1125 start_str = start.astimezone(tz=timezone.utc).isoformat() 1126 params["time.start"] = start_str 1127 if end is not None: 1128 end_str = end.astimezone(tz=timezone.utc).isoformat() 1129 params["time.end"] = end_str 1130 1131 next: Union[Any, None] = { 1132 k: max(0, span["end"] - math.floor(limit / len(partitions))) 1133 for k, span in partitions.items() 1134 } 1135 1136 if verbose: 1137 print("Using params: ", params) 1138 1139 records = [] 1140 while next is not None: 1141 response = requests.post( 1142 base + "/records", 1143 params=params, 1144 json=next, 1145 auth=self.auth, 1146 headers=headers, 1147 ) 1148 if response.status_code != 200: 1149 raise EntityNotFoundError( 1150 f"Could not get records {resp.text}", 1151 {"url": base, "params": str(params), "iterator": str(next)}, 1152 ) 1153 1154 if verbose: 1155 print("response: ", response) 1156 1157 result = response.json() 1158 result_records = result["records"] 1159 if len(result_records) > 0: 1160 records.extend(result_records) 1161 next = result["next"] 1162 else: 1163 next = None 1164 if parse: 1165 return [json.loads(r) for r in records] 1166 return records
Gets logs from Plateau for a particular time window without attempting to convert them to Inference LogEntries. Logs can be returned as raw strings or parsed from JSON into lists and dicts. A usage sketch follows the parameter list.
Parameters
- topic str: The name of the topic to query
- start Optional[datetime]: The start of the time window
- end Optional[datetime]: The end of the time window
- limit int: The number of records to retrieve. Note that retrieving many records may be a performance bottleneck.
- parse bool: Whether to attempt to parse each record string as a JSON object.
- verbose bool: Prints out info to help diagnose issues.
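A short sketch of pulling the last hour of raw records; the topic name is a placeholder:

from datetime import datetime, timedelta, timezone

import wallaroo

client = wallaroo.Client()
end = datetime.now(timezone.utc)
start = end - timedelta(hours=1)
# Retrieve raw records for the last hour, parsed from JSON into dicts.
records = client.get_raw_logs(
    "my-pipeline-logs",  # hypothetical topic name
    start=start,
    end=end,
    limit=10_000,
    parse=True,
)
print(f"retrieved {len(records)} records")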
def get_raw_pipeline_inference_logs(
    self,
    topic: str,
    start: datetime,
    end: datetime,
    model_name: Optional[str] = None,
    limit: int = 100_000,
    verbose: bool = False,
) -> List[Dict[str, Any]]:
    """Gets logs from Plateau for a particular time window and filters them for
    the model specified.
    :param topic str: The name of the topic to query
    :param start datetime: The start of the time window
    :param end datetime: The end of the time window
    :param model_name Optional[str]: The name of the specific model to filter on, if any
    :param limit int: The number of records to retrieve. Note that retrieving
        many records may be a performance bottleneck.
    :param verbose bool: Prints out info to help diagnose issues.
    """
    logs = self.get_raw_logs(
        topic,
        start=start,
        end=end,
        limit=limit,
        parse=True,
        verbose=verbose,
    )

    if verbose:
        print(f"Got {len(logs)} initial logs")

    if model_name:
        logs = [l for l in logs if l["model_name"] == model_name]

    # inference results carry a unix timestamp in millis - filter by that
    start_ts = int(start.timestamp() * 1000)
    end_ts = int(end.timestamp() * 1000)
    logs = [l for l in logs if start_ts <= l["time"] < end_ts]

    return logs
Gets logs from Plateau for a particular time window and filters them for the model specified.
Parameters
- topic str: The name of the topic to query
- start datetime: The start of the time window
- end datetime: The end of the time window
- model_name Optional[str]: The name of the specific model to filter on, if any
- limit int: The number of records to retrieve. Note that retrieving many records may be a performance bottleneck.
- verbose bool: Prints out info to help diagnose issues.
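For example, filtering a one-hour window down to a single model's inferences (topic and model names are placeholders):

from datetime import datetime, timedelta, timezone

import wallaroo

client = wallaroo.Client()
end = datetime.now(timezone.utc)
start = end - timedelta(hours=1)
logs = client.get_raw_pipeline_inference_logs(
    "my-pipeline-logs",    # hypothetical topic name
    start,
    end,
    model_name="my-model",  # only keep records produced by this model
    verbose=True,
)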
def get_pipeline_inference_dataframe(
    self,
    topic: str,
    start: datetime,
    end: datetime,
    model_name: Optional[str] = None,
    limit: int = 100_000,
    verbose=False,
) -> pd.DataFrame:
    logs = self.get_raw_pipeline_inference_logs(
        topic, start, end, model_name, limit, verbose
    )
    return inference_logs_to_dataframe(logs)
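get_pipeline_inference_dataframe is a thin convenience wrapper: it fetches the same filtered logs and flattens them into a pandas DataFrame via inference_logs_to_dataframe. A sketch, reusing the placeholder names from the example above:

df = client.get_pipeline_inference_dataframe(
    "my-pipeline-logs", start, end, model_name="my-model"
)
print(df.head())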
def get_assay_results(
    self,
    assay_id: int,
    start: datetime,
    end: datetime,
) -> AssayAnalysisList:
    """Gets the assay results for a particular time window, parses them, and returns an
    AssayAnalysisList of AssayAnalysis.
    :param assay_id int: The id of the assay we are looking for.
    :param start datetime: The start of the time window
    :param end datetime: The end of the time window
    """
    res = assays_get_assay_results.sync(
        client=self.mlops(),
        json_body=assays_get_assay_results_json_body.AssaysGetAssayResultsJsonBody(
            assay_id, start, end
        ),
    )

    if res is None:
        raise Exception("Failed to get assay results")

    if not isinstance(res, List):
        raise Exception(res.msg)

    if len(res) != 0 and not isinstance(
        res[0], AssaysGetAssayResultsResponse200Item
    ):
        raise Exception("invalid response")

    return AssayAnalysisList([AssayAnalysis(v.to_dict()) for v in res])
Gets the assay results for a particular time window, parses them, and returns an AssayAnalysisList of AssayAnalysis.
Parameters
- assay_id int: The id of the assay we are looking for.
- start datetime: The start of the time window
- end datetime: The end of the time window
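A short sketch, assuming an existing assay whose id you already know (the id below is a placeholder):

from datetime import datetime, timedelta, timezone

import wallaroo

client = wallaroo.Client()
end = datetime.now(timezone.utc)
start = end - timedelta(days=1)
results = client.get_assay_results(assay_id=1, start=start, end=end)  # hypothetical id
for analysis in results:  # AssayAnalysisList is expected to be iterable
    print(analysis)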
def build_assay(
    self,
    assay_name: str,
    pipeline: Pipeline,
    model_name: str,
    baseline_start: datetime,
    baseline_end: datetime,
) -> AssayBuilder:
    """Creates an AssayBuilder that can be used to configure and create
    Assays.
    :param assay_name str: Human friendly name for the assay
    :param pipeline Pipeline: The pipeline this assay will work on
    :param model_name str: The model that this assay will monitor
    :param baseline_start datetime: The start time for the inferences to
        use as the baseline
    :param baseline_end datetime: The end time of the baseline window.
        Windows start immediately after the baseline window and are run
        at regular intervals continuously until the assay is deactivated
        or deleted.
    """
    assay_builder = AssayBuilder(
        self,
        assay_name,
        pipeline.id(),
        pipeline.name(),
        model_name,
        baseline_start,
        baseline_end,
    )

    return assay_builder
Creates an AssayBuilder that can be used to configure and create Assays.
Parameters
- assay_name str: Human friendly name for the assay
- pipeline Pipeline: The pipeline this assay will work on
- model_name str: The model that this assay will monitor
- baseline_start datetime: The start time for the inferences to use as the baseline
- baseline_end datetime: The end time of the baseline window. Windows start immediately after the baseline window and are run at regular intervals continuously until the assay is deactivated or deleted.
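A sketch of building an assay configuration. The `pipeline` handle and the builder's build() step are assumptions about the surrounding SDK surface, so check your SDK version:

from datetime import datetime, timedelta, timezone

import wallaroo

client = wallaroo.Client()
baseline_end = datetime.now(timezone.utc)
baseline_start = baseline_end - timedelta(days=7)
# `pipeline` must be a wallaroo.pipeline.Pipeline handle obtained elsewhere.
builder = client.build_assay(
    "my-assay", pipeline, "my-model", baseline_start, baseline_end
)
config = builder.build()  # assumed to produce an AssayConfig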
def upload_assay(self, config: AssayConfig) -> int:
    """Creates an assay in the database.
    :param config AssayConfig: The configuration for the assay to create.
    :return: The identifier for the assay that was created.
    :rtype: int
    """
    data = assays_create.sync(
        client=self.mlops(),
        json_body=AssaysCreateJsonBody.from_dict(
            {
                **json.loads(config.to_json()),
                "created_at": datetime.now(timezone.utc).isoformat(),
            }
        ),
    )

    if data is None:
        raise Exception("Failed to create assay")

    if not isinstance(data, AssaysCreateResponse200):
        raise Exception(data.msg)

    return data.assay_id
Creates an assay in the database.
Parameters
- config AssayConfig: The configuration for the assay to create.
Returns
The identifier (int) for the assay that was created.
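Continuing the sketch above, the resulting AssayConfig can then be uploaded; the returned id can be fed back into get_assay_results:

assay_id = client.upload_assay(config)
print("created assay", assay_id)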
def list_assays(self) -> List[Assay]:
    """List all assays on the platform.

    :return: A list of all assays on the platform.
    :rtype: List[Assay]
    """
    res = assays_list.sync(client=self.mlops(), json_body=AssaysListJsonBody(UNSET))

    if res is None:
        raise Exception("Failed to get assays")

    if not isinstance(res, List):
        raise Exception(res.msg)

    return Assays([Assay(client=self, data=v.to_dict()) for v in res])
List all assays on the platform.
Returns
A list of all assays on the platform.
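For example:

for assay in client.list_assays():
    print(assay)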
def create_tag(self, tag_text: str) -> Tag:
    """Create a new tag with the given text."""
    assert tag_text is not None
    return Tag._create_tag(client=self, tag_text=tag_text)
Create a new tag with the given text.
def create_workspace(self, workspace_name: str) -> Workspace:
    """Create a new workspace with the current user as its first owner.

    :param str workspace_name: Name of the workspace, must be composed of ASCII
        alpha-numeric characters plus dash (-)"""
    assert workspace_name is not None
    require_dns_compliance(workspace_name)
    return Workspace._create_workspace(client=self, name=workspace_name)
Create a new workspace with the current user as its first owner.
Parameters
- str workspace_name: Name of the workspace, must be composed of ASCII alpha-numeric characters plus dash (-)
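A sketch pairing workspace creation with set_current_workspace; the workspace name is a placeholder and must be DNS-compliant:

import wallaroo

client = wallaroo.Client()
ws = client.create_workspace("fraud-detection-dev")  # hypothetical name
client.set_current_workspace(ws)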
def list_workspaces(self) -> List[Workspace]:
    """List all workspaces on the platform which this user has permission to see.

    :return: A list of all workspaces on the platform.
    :rtype: List[Workspace]
    """
    res = workspaces_list.sync(
        client=self.mlops(),
        json_body=workspaces_list_json_body.WorkspacesListJsonBody(UNSET),
    )

    if res is None:
        raise Exception("Failed to get workspaces")

    if not isinstance(res, WorkspacesListResponse200):
        raise Exception(res.msg)

    return Workspaces(
        [Workspace(client=self, data=d.to_dict()) for d in res.workspaces]
    )
List all workspaces on the platform which this user has permission to see.
Returns
A list of all workspaces on the platform.
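For example, listing workspaces and confirming the current one:

for ws in client.list_workspaces():
    print(ws)
print("current:", client.get_current_workspace())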
def set_current_workspace(self, workspace: Workspace) -> Workspace:
    """Any calls involving pipelines or models will use the given workspace from then on."""
    assert workspace is not None
    if not isinstance(workspace, Workspace):
        raise TypeError("Workspace type was expected")

    self._current_workspace = workspace
    return cast("Workspace", self._current_workspace)
Any calls involving pipelines or models will use the given workspace from then on.
def get_current_workspace(self) -> Workspace:
    """Return the current workspace. See also `set_current_workspace`."""
    if self._current_workspace is None:
        # Is there a default? Use that or make one.
        default_ws = Workspace._get_user_default_workspace(self)
        if default_ws is not None:
            self._current_workspace = default_ws
        else:
            self._current_workspace = Workspace._create_user_default_workspace(self)

    return cast("Workspace", self._current_workspace)
Return the current workspace. See also set_current_workspace.
def convert_model(
    self,
    path: Union[str, pathlib.Path],
    source_type: ModelConversionSource,
    conversion_arguments: ModelConversionArguments,
) -> Model:
    """
    Given an inbound source model, a model type (xgboost, keras, sklearn), and
    conversion arguments, convert the model to ONNX and add it to the models
    available to a pipeline.

    :param Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model.
    :param ModelConversionSource source_type: The origin model type, i.e. keras, sklearn or xgboost.
    :param ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type.
    :return: An instance of the Model being converted to ONNX.
    :raises ModelConversionGenericException: On a generic failure; please contact our support for further assistance.
    :raises ModelConversionFailure: Failure in converting the model type.
    :raises ModelConversionUnsupportedType: Raised when the source type passed is not supported.
    :raises ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist.
    """
    if isinstance(path, str):
        path = pathlib.Path(path)
    if not os.path.exists(path):
        raise ModelConversionSourceFileNotPresent(
            f"The provided source file: {path} can not be found."
        )
    file_handle = open(path, "rb")
    files = [("files", file_handle)]
    base_url = self.api_endpoint
    workspace_id = self.get_current_workspace().id()
    common_headers = {
        "user_id": self.auth.user_id(),
        "user_email": self.auth.user_email(),
        "User-Agent": _user_agent,
    }
    curry_post = partial(
        requests.post,
        auth=self.auth,
        headers=common_headers,
        files=files,
    )
    model_id: Union[int, None] = None

    def _handle_response(http_response) -> int:
        http_response.raise_for_status()
        response_record = http_response.json()
        if response_record is not None and "model_id" in response_record:
            return int(response_record["model_id"])
        else:
            raise ModelConversionFailure("Failed to convert model")

    try:
        if source_type == ModelConversionSource.KERAS:
            assert (
                type(conversion_arguments).__name__
                == ConvertKerasArguments.__name__
            )
            ## This is not optimal, but api-lb (envoy) -> python-api
            ## gave a 50(3|4|2|0) on several other options:
            ## dimensions being of type List[Any],
            ## dimensions being a base64 encoded json array,
            ## dimensions being a comma separated string processed server side.
            ## This casts the ConvertKerasArguments dict structure to a json
            ## string, base64 encoded via a utf-8 binary conversion, and passes
            ## it as the url parameter `config`, which is handled server side.
            data = {
                **conversion_arguments.to_dict(),
                "workspace_id": workspace_id,
            }
            params = {
                "config": base64.b64encode(json.dumps(data).encode("utf-8")).decode(
                    "utf-8"
                )
            }
            response = curry_post(url=f"{base_url}/v1/convert/keras", params=params)
            model_id = _handle_response(response)
        elif source_type == ModelConversionSource.SKLEARN:
            assert (
                type(conversion_arguments).__name__
                == ConvertSKLearnArguments.__name__
            )
            response = curry_post(
                url=f"{base_url}/v1/convert/sklearn",
                params={
                    **conversion_arguments.to_dict(),
                    "workspace_id": workspace_id,
                },
            )
            model_id = _handle_response(response)

        elif source_type == ModelConversionSource.XGBOOST:
            assert (
                type(conversion_arguments).__name__ == ConvertXGBoostArgs.__name__
            )
            response = curry_post(
                url=f"{base_url}/v1/convert/xgboost",
                params={
                    **conversion_arguments.to_dict(),
                    "workspace_id": workspace_id,
                },
            )
            model_id = _handle_response(response)
        else:
            raise ModelConversionUnsupportedType(
                f"Unsupported model source type of {source_type} passed."
            )
        if model_id is None:
            raise ModelConversionFailure("Failed to retrieve final model id")
        return Model(self, {"id": model_id})
    except Exception:
        raise ModelConversionGenericException(
            "This model type could not be deployed successfully. Please contact your Wallaroo support team at community@wallaroo.ai"
        )
    finally:
        file_handle.close()
Given an inbound source model, a model type (xgboost, keras, sklearn), and conversion arguments, convert the model to ONNX and add it to the models available to a pipeline.
Parameters
- Union[str, pathlib.Path] path: The path to the model to convert, i.e. the source model.
- ModelConversionSource source_type: The origin model type, i.e. keras, sklearn or xgboost.
- ModelConversionArguments conversion_arguments: A structure representing the arguments for converting a specific model type.
Returns
An instance of the Model being converted to ONNX.
Raises
- ModelConversionGenericException: On a generic failure; please contact our support for further assistance.
- ModelConversionFailure: Failure in converting the model type.
- ModelConversionUnsupportedType: Raised when the source type passed is not supported.
- ModelConversionSourceFileNotPresent: Raised when the passed source file does not exist.
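A hedged end-to-end sketch for an XGBoost conversion. The constructor fields on ConvertXGBoostArgs shown here are illustrative assumptions, not a confirmed signature; consult the ModelConversion module in your SDK version:

import wallaroo
from wallaroo.ModelConversion import ConvertXGBoostArgs, ModelConversionSource

client = wallaroo.Client()
# Field names below are hypothetical illustrations of the conversion arguments.
args = ConvertXGBoostArgs(
    name="my-xgboost-model",
    comment="converted via the SDK",
    number_of_columns=4,
    input_type="float32",
)
model = client.convert_model("xgb_model.pkl", ModelConversionSource.XGBOOST, args)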