wallaroo.tag
import html
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

from . import queries
from .model import Model
from .object import *
from .pipeline import Pipeline

if TYPE_CHECKING:
    # Imports that happen below in methods to fix circular import dependency
    # issues need to also be specified here to satisfy mypy type checking.
    from .client import Client


class Tag(Object):
    """Tags that may be attached to models and pipelines."""

    def __init__(
        self, client: Optional["Client"], data: Dict[str, Any], standalone=False
    ) -> None:
        """Create a Tag bound to ``client`` and populated from ``data``.

        :param client: Client whose ``_gql_client`` is passed to the base
            ``Object``; when ``None``, no GraphQL client is passed.
        :param data: Response dictionary; must contain at least ``id``.
        :param standalone: Forwarded unchanged to the base constructor.
        """
        # Assigned first: _fill builds Model/Pipeline objects from self.client,
        # and presumably runs from within the base constructor.
        self.client = client
        super().__init__(
            gql_client=client._gql_client if client is not None else None,
            data=data,
            standalone=standalone,
        )

    def __repr__(self) -> str:
        return str(
            {
                "id": self.id(),
                "tag": self.tag(),
                "models": self.model_tags(),
                "pipelines": self.pipeline_tags(),
            }
        )

    @staticmethod
    def _create_tag(client, tag_text: str):
        """Create a tag with text ``tag_text`` and return it as a Tag.

        Runs the named ``CreateTag`` query and builds the Tag from the first
        row of ``insert_tag.returning``.
        """
        res = client._gql_client.execute(
            gql.gql(queries.named("CreateTag")),
            variable_values={"tag_text": tag_text},
        )
        return Tag(client, res["insert_tag"]["returning"][0])

    def _fill(self, data: Dict[str, Any]) -> None:
        """Fills an object given a response dictionary from the GraphQL API.

        Only the primary key member must be present; other members will be
        filled in via rehydration if their corresponding member function is
        called.

        :raises RequiredAttributeMissing: if ``data`` has no ``id`` key.
        """
        for required_attribute in ["id"]:
            if required_attribute not in data:
                raise RequiredAttributeMissing(
                    self.__class__.__name__, required_attribute
                )
        # Required
        self._id = data["id"]

        # Optional
        self._tag = value_if_present(data, "tag")

        # Optional: each model_tags entry wraps the linked model record.
        self._model_tags = (
            [Model(self.client, model_tag["model"]) for model_tag in data["model_tags"]]
            if "model_tags" in data
            else DehydratedValue()
        )

        # Optional: each pipeline_tags entry wraps the linked pipeline record.
        self._pipeline_tags = (
            [
                Pipeline(self.client, pipeline_tag["pipeline"])
                for pipeline_tag in data["pipeline_tags"]
            ]
            if "pipeline_tags" in data
            else DehydratedValue()
        )

    def _fetch_attributes(self) -> Dict[str, Any]:
        """Fetches all member data from the GraphQL API.

        The tag id is sent as a GraphQL variable (like every other query in
        this class) instead of being interpolated into the query text.
        """
        return self._gql_client.execute(
            gql.gql(
                """
                query TagById($tag_id: bigint!) {
                    tag_by_pk(id: $tag_id) {
                        id
                        tag
                        model_tags {
                            model {
                                id
                                model_id
                                models_pk_id
                                model_version
                            }
                        }
                        pipeline_tags {
                            pipeline {
                                id
                                pipeline_id
                                pipeline_versions {
                                    id
                                    version
                                }
                            }
                        }
                    }
                }
                """
            ),
            variable_values={"tag_id": self._id},
        )["tag_by_pk"]

    def id(self) -> int:
        return self._id

    @rehydrate("_tag")
    def tag(self) -> str:
        return cast(str, self._tag)

    @rehydrate("_model_tags")
    def model_tags(self) -> List[Model]:
        return cast(List[Model], self._model_tags)

    @rehydrate("_pipeline_tags")
    def pipeline_tags(self) -> List[Pipeline]:
        return cast(List[Pipeline], self._pipeline_tags)

    def list_models(self) -> List[Model]:
        """Lists the models this tag is on.

        Returns an empty list when ``tag_by_pk`` yields nothing.
        """
        res = self._gql_client.execute(
            gql.gql(
                """
                query ModelsByTagId($tag_id: bigint!){
                    tag_by_pk(id:$tag_id){
                        model_tags {
                            model {
                                id
                                model_id
                                model_version
                                sha
                                file_name
                                updated_at
                                visibility
                            }
                        }
                    }
                }
                """
            ),
            variable_values={"tag_id": self._id},
        )
        if not res["tag_by_pk"]:
            return []
        return [
            Model(client=self.client, data=v["model"])
            for v in res["tag_by_pk"]["model_tags"]
        ]

    def add_to_model(self, model_id: int):
        """Attach this tag to the model with id ``model_id``.

        Returns the first ``insert_model_tag.returning`` row
        (``model_id`` and ``tag_id``).
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation AddTagToModel($model_id: bigint!, $tag_id: bigint!) {
                    insert_model_tag(objects: {
                        model_id : $model_id,
                        tag_id: $tag_id
                    }) {
                        returning {
                            model_id
                            tag_id
                        }
                    }
                }
                """
            ),
            variable_values={"model_id": model_id, "tag_id": self._id},
        )
        return data["insert_model_tag"]["returning"][0]

    def remove_from_model(self, model_id: int):
        """Detach this tag from the model with id ``model_id``.

        Returns the first ``delete_model_tag.returning`` row. If no row was
        deleted, the ``[0]`` lookup raises ``IndexError``.
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation RemoveTagFromModel($model_id: bigint!, $tag_id: bigint!) {
                    delete_model_tag(
                        where: {
                            _and: [
                                { model_id: {_eq: $model_id} }
                                { tag_id: {_eq: $tag_id} }
                            ]
                        }
                    )
                    {
                        returning {
                            model_id
                            tag_id
                        }
                    }
                }
                """
            ),
            variable_values={"model_id": model_id, "tag_id": self._id},
        )
        return data["delete_model_tag"]["returning"][0]

    def list_pipelines(self) -> List[Pipeline]:
        """Lists the pipelines this tag is on.

        Returns an empty list when ``tag_by_pk`` yields nothing.
        """
        res = self._gql_client.execute(
            gql.gql(
                """
                query PipelinesByTagId($tag_id: bigint!){
                    tag_by_pk(id:$tag_id){
                        pipeline_tags {
                            pipeline {
                                id
                                pipeline_id
                                created_at
                                updated_at
                                owner_id
                            }
                        }
                    }
                }
                """
            ),
            variable_values={"tag_id": self._id},
        )
        if not res["tag_by_pk"]:
            return []
        return [
            Pipeline(client=self.client, data=v["pipeline"])
            for v in res["tag_by_pk"]["pipeline_tags"]
        ]

    def add_to_pipeline(self, pipeline_id: int):
        """Attach this tag to the pipeline whose primary key is ``pipeline_id``.

        Returns the first ``insert_pipeline_tag.returning`` row
        (``pipeline_pk_id`` and ``tag_pk_id``).
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation AddTagToPipeline($pipeline_id: bigint!, $tag_id: bigint!) {
                    insert_pipeline_tag(objects: {
                        pipeline_pk_id : $pipeline_id,
                        tag_pk_id: $tag_id
                    }) {
                        returning {
                            pipeline_pk_id
                            tag_pk_id
                        }
                    }
                }
                """
            ),
            variable_values={"pipeline_id": pipeline_id, "tag_id": self._id},
        )
        return data["insert_pipeline_tag"]["returning"][0]

    def remove_from_pipeline(self, pipeline_id: int):
        """Detach this tag from the pipeline whose primary key is ``pipeline_id``.

        Returns the first ``delete_pipeline_tag.returning`` row. If no row was
        deleted, the ``[0]`` lookup raises ``IndexError``.
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation RemoveTagFromPipeline($pipeline_id: bigint!, $tag_id: bigint!) {
                    delete_pipeline_tag(
                        where: {
                            _and: [
                                { pipeline_pk_id: {_eq: $pipeline_id} }
                                { tag_pk_id: {_eq: $tag_id} }
                            ]
                        }
                    )
                    {
                        returning {
                            pipeline_pk_id
                            tag_pk_id
                        }
                    }
                }
                """
            ),
            variable_values={"pipeline_id": pipeline_id, "tag_id": self._id},
        )
        return data["delete_pipeline_tag"]["returning"][0]


class Tags(List[Tag]):
    """Wraps a list of tags for display in a display-aware environment like Jupyter."""

    def _repr_html_(self) -> str:
        """Render the tags as an HTML table with one row per tag.

        Returns the plain string ``"(no tags)"`` when the list is empty.
        """

        def cell(value: Any) -> str:
            # Tag text and model/pipeline names are user-supplied; escape them
            # so characters such as ``<`` or ``&`` cannot break the markup.
            return f"<td>{html.escape(str(value), quote=False)}</td>"

        def row(tag: Tag) -> str:
            models = tag.model_tags()
            pipelines = tag.pipeline_tags()
            # Group model versions under their model name (first-seen order).
            models_dict: Dict[Any, List[Any]] = {}
            for m in models:
                models_dict.setdefault(m.name(), []).append(m.version())

            return (
                "<tr>"
                + cell(tag.id())
                + cell(tag.tag())
                + cell([(key, value) for key, value in models_dict.items()])
                + cell(
                    [
                        (p.name(), [pv.name() for pv in p.variants()])
                        for p in pipelines
                    ]
                )
                + "</tr>"
            )

        fields = ["id", "tag", "models", "pipelines"]

        if not self:
            return "(no tags)"
        return (
            "<table>"
            + "<tr><th>"
            + "</th><th>".join(fields)
            + "</th></tr>"
            + "".join(row(t) for t in self)
            + "</table>"
        )
class Tag(Object):
    """Tags that may be attached to models and pipelines."""

    def __init__(
        self, client: Optional["Client"], data: Dict[str, Any], standalone=False
    ) -> None:
        """Create a Tag bound to ``client`` and populated from ``data``.

        :param client: Client whose ``_gql_client`` is passed to the base
            ``Object``; when ``None``, no GraphQL client is passed.
        :param data: Response dictionary; must contain at least ``id``.
        :param standalone: Forwarded unchanged to the base constructor.
        """
        # Assigned first: _fill builds Model/Pipeline objects from self.client,
        # and presumably runs from within the base constructor.
        self.client = client
        super().__init__(
            gql_client=client._gql_client if client is not None else None,
            data=data,
            standalone=standalone,
        )

    def __repr__(self) -> str:
        return str(
            {
                "id": self.id(),
                "tag": self.tag(),
                "models": self.model_tags(),
                "pipelines": self.pipeline_tags(),
            }
        )

    @staticmethod
    def _create_tag(client, tag_text: str):
        """Create a tag with text ``tag_text`` and return it as a Tag.

        Runs the named ``CreateTag`` query and builds the Tag from the first
        row of ``insert_tag.returning``.
        """
        res = client._gql_client.execute(
            gql.gql(queries.named("CreateTag")),
            variable_values={"tag_text": tag_text},
        )
        return Tag(client, res["insert_tag"]["returning"][0])

    def _fill(self, data: Dict[str, Any]) -> None:
        """Fills an object given a response dictionary from the GraphQL API.

        Only the primary key member must be present; other members will be
        filled in via rehydration if their corresponding member function is
        called.

        :raises RequiredAttributeMissing: if ``data`` has no ``id`` key.
        """
        for required_attribute in ["id"]:
            if required_attribute not in data:
                raise RequiredAttributeMissing(
                    self.__class__.__name__, required_attribute
                )
        # Required
        self._id = data["id"]

        # Optional
        self._tag = value_if_present(data, "tag")

        # Optional: each model_tags entry wraps the linked model record.
        self._model_tags = (
            [Model(self.client, model_tag["model"]) for model_tag in data["model_tags"]]
            if "model_tags" in data
            else DehydratedValue()
        )

        # Optional: each pipeline_tags entry wraps the linked pipeline record.
        self._pipeline_tags = (
            [
                Pipeline(self.client, pipeline_tag["pipeline"])
                for pipeline_tag in data["pipeline_tags"]
            ]
            if "pipeline_tags" in data
            else DehydratedValue()
        )

    def _fetch_attributes(self) -> Dict[str, Any]:
        """Fetches all member data from the GraphQL API.

        The tag id is sent as a GraphQL variable (like every other query in
        this class) instead of being interpolated into the query text.
        """
        return self._gql_client.execute(
            gql.gql(
                """
                query TagById($tag_id: bigint!) {
                    tag_by_pk(id: $tag_id) {
                        id
                        tag
                        model_tags {
                            model {
                                id
                                model_id
                                models_pk_id
                                model_version
                            }
                        }
                        pipeline_tags {
                            pipeline {
                                id
                                pipeline_id
                                pipeline_versions {
                                    id
                                    version
                                }
                            }
                        }
                    }
                }
                """
            ),
            variable_values={"tag_id": self._id},
        )["tag_by_pk"]

    def id(self) -> int:
        return self._id

    @rehydrate("_tag")
    def tag(self) -> str:
        return cast(str, self._tag)

    @rehydrate("_model_tags")
    def model_tags(self) -> List[Model]:
        return cast(List[Model], self._model_tags)

    @rehydrate("_pipeline_tags")
    def pipeline_tags(self) -> List[Pipeline]:
        return cast(List[Pipeline], self._pipeline_tags)

    def list_models(self) -> List[Model]:
        """Lists the models this tag is on.

        Returns an empty list when ``tag_by_pk`` yields nothing.
        """
        res = self._gql_client.execute(
            gql.gql(
                """
                query ModelsByTagId($tag_id: bigint!){
                    tag_by_pk(id:$tag_id){
                        model_tags {
                            model {
                                id
                                model_id
                                model_version
                                sha
                                file_name
                                updated_at
                                visibility
                            }
                        }
                    }
                }
                """
            ),
            variable_values={"tag_id": self._id},
        )
        if not res["tag_by_pk"]:
            return []
        return [
            Model(client=self.client, data=v["model"])
            for v in res["tag_by_pk"]["model_tags"]
        ]

    def add_to_model(self, model_id: int):
        """Attach this tag to the model with id ``model_id``.

        Returns the first ``insert_model_tag.returning`` row
        (``model_id`` and ``tag_id``).
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation AddTagToModel($model_id: bigint!, $tag_id: bigint!) {
                    insert_model_tag(objects: {
                        model_id : $model_id,
                        tag_id: $tag_id
                    }) {
                        returning {
                            model_id
                            tag_id
                        }
                    }
                }
                """
            ),
            variable_values={"model_id": model_id, "tag_id": self._id},
        )
        return data["insert_model_tag"]["returning"][0]

    def remove_from_model(self, model_id: int):
        """Detach this tag from the model with id ``model_id``.

        Returns the first ``delete_model_tag.returning`` row. If no row was
        deleted, the ``[0]`` lookup raises ``IndexError``.
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation RemoveTagFromModel($model_id: bigint!, $tag_id: bigint!) {
                    delete_model_tag(
                        where: {
                            _and: [
                                { model_id: {_eq: $model_id} }
                                { tag_id: {_eq: $tag_id} }
                            ]
                        }
                    )
                    {
                        returning {
                            model_id
                            tag_id
                        }
                    }
                }
                """
            ),
            variable_values={"model_id": model_id, "tag_id": self._id},
        )
        return data["delete_model_tag"]["returning"][0]

    def list_pipelines(self) -> List[Pipeline]:
        """Lists the pipelines this tag is on.

        Returns an empty list when ``tag_by_pk`` yields nothing.
        """
        res = self._gql_client.execute(
            gql.gql(
                """
                query PipelinesByTagId($tag_id: bigint!){
                    tag_by_pk(id:$tag_id){
                        pipeline_tags {
                            pipeline {
                                id
                                pipeline_id
                                created_at
                                updated_at
                                owner_id
                            }
                        }
                    }
                }
                """
            ),
            variable_values={"tag_id": self._id},
        )
        if not res["tag_by_pk"]:
            return []
        return [
            Pipeline(client=self.client, data=v["pipeline"])
            for v in res["tag_by_pk"]["pipeline_tags"]
        ]

    def add_to_pipeline(self, pipeline_id: int):
        """Attach this tag to the pipeline whose primary key is ``pipeline_id``.

        Returns the first ``insert_pipeline_tag.returning`` row
        (``pipeline_pk_id`` and ``tag_pk_id``).
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation AddTagToPipeline($pipeline_id: bigint!, $tag_id: bigint!) {
                    insert_pipeline_tag(objects: {
                        pipeline_pk_id : $pipeline_id,
                        tag_pk_id: $tag_id
                    }) {
                        returning {
                            pipeline_pk_id
                            tag_pk_id
                        }
                    }
                }
                """
            ),
            variable_values={"pipeline_id": pipeline_id, "tag_id": self._id},
        )
        return data["insert_pipeline_tag"]["returning"][0]

    def remove_from_pipeline(self, pipeline_id: int):
        """Detach this tag from the pipeline whose primary key is ``pipeline_id``.

        Returns the first ``delete_pipeline_tag.returning`` row. If no row was
        deleted, the ``[0]`` lookup raises ``IndexError``.
        """
        data = self._gql_client.execute(
            gql.gql(
                """
                mutation RemoveTagFromPipeline($pipeline_id: bigint!, $tag_id: bigint!) {
                    delete_pipeline_tag(
                        where: {
                            _and: [
                                { pipeline_pk_id: {_eq: $pipeline_id} }
                                { tag_pk_id: {_eq: $tag_id} }
                            ]
                        }
                    )
                    {
                        returning {
                            pipeline_pk_id
                            tag_pk_id
                        }
                    }
                }
                """
            ),
            variable_values={"pipeline_id": pipeline_id, "tag_id": self._id},
        )
        return data["delete_pipeline_tag"]["returning"][0]
Tags that may be attached to models and pipelines.
Tag( client: Optional[wallaroo.client.Client], data: Dict[str, Any], standalone=False)
def __init__(
    self, client: Optional["Client"], data: Dict[str, Any], standalone=False
) -> None:
    """Create a Tag bound to ``client`` and populated from ``data``.

    :param client: Client whose ``_gql_client`` is handed to the base
        ``Object``; when ``None``, no GraphQL client is passed.
    :param data: Response dictionary used to fill the object.
    :param standalone: Forwarded unchanged to the base constructor.
    """
    # Stored before the base constructor runs, in case filling the object
    # (presumably triggered there) needs self.client.
    self.client = client
    gql_client = None if client is None else client._gql_client
    super().__init__(gql_client=gql_client, data=data, standalone=standalone)
Base constructor.
Each object requires:
- a GraphQL client, used to fill in its missing members dynamically;
- an initial data blob, typically deserialized JSON, containing at least the data for the required members (typically the object's primary key) and optionally other data members.
def
tag(*args, **kwargs):
def wrapper(*args, **kwargs):
    """Rehydrate the instance's ``attr`` member if needed, then call ``fn``.

    ``args[0]`` is the instance the decorated method was called on. Unless
    the instance's ``_standalone`` flag is truthy, a member still equal to
    ``DehydratedValue()`` triggers ``_rehydrate()`` before ``fn`` runs.
    """
    # NOTE(review): decorated methods render with the signature
    # ``tag(*args, **kwargs)``, which suggests the enclosing decorator does
    # not apply ``functools.wraps(fn)`` to this wrapper — consider adding it.
    instance = args[0]
    if not getattr(instance, "_standalone", None):
        has_value = getattr(instance, attr) != DehydratedValue()
        if not has_value:
            instance._rehydrate()
    return fn(*args, **kwargs)
def list_models(self) -> List[Model]:
    """Return the models this tag is attached to.

    Queries ``tag_by_pk`` for the tag's ``model_tags`` and wraps each linked
    model record in a ``Model``. Returns an empty list when the tag lookup
    yields nothing.
    """
    query = gql.gql(
        """
        query ModelsByTagId($tag_id: bigint!){
            tag_by_pk(id:$tag_id){
                model_tags {
                    model {
                        id
                        model_id
                        model_version
                        sha
                        file_name
                        updated_at
                        visibility
                    }
                }
            }
        }
        """
    )
    response = self._gql_client.execute(query, variable_values={"tag_id": self._id})
    tag_record = response["tag_by_pk"]
    if not tag_record:
        return []
    return [
        Model(client=self.client, data=link["model"])
        for link in tag_record["model_tags"]
    ]
Lists the models this tag is on.
def
add_to_model(self, model_id: int):
def add_to_model(self, model_id: int):
    """Attach this tag to the model with id ``model_id``.

    Returns the first row of the ``insert_model_tag`` mutation's
    ``returning`` list (fields ``model_id`` and ``tag_id``).
    """
    mutation = gql.gql(
        """
        mutation AddTagToModel($model_id: bigint!, $tag_id: bigint!) {
            insert_model_tag(objects: {
                model_id : $model_id,
                tag_id: $tag_id
            }) {
                returning {
                    model_id
                    tag_id
                }
            }
        }
        """
    )
    variables = {"model_id": model_id, "tag_id": self._id}
    result = self._gql_client.execute(mutation, variable_values=variables)
    inserted_rows = result["insert_model_tag"]["returning"]
    return inserted_rows[0]
def
remove_from_model(self, model_id: int):
def remove_from_model(self, model_id: int):
    """Detach this tag from the model with id ``model_id``.

    Returns the first row of the ``delete_model_tag`` mutation's
    ``returning`` list. If that list is empty, the ``[0]`` lookup raises
    ``IndexError``.
    """
    mutation = gql.gql(
        """
        mutation RemoveTagFromModel($model_id: bigint!, $tag_id: bigint!) {
            delete_model_tag(
                where: {
                    _and: [
                        { model_id: {_eq: $model_id} }
                        { tag_id: {_eq: $tag_id} }
                    ]
                }
            )
            {
                returning {
                    model_id
                    tag_id
                }
            }
        }
        """
    )
    variables = {"model_id": model_id, "tag_id": self._id}
    result = self._gql_client.execute(mutation, variable_values=variables)
    deleted_rows = result["delete_model_tag"]["returning"]
    return deleted_rows[0]
def list_pipelines(self) -> List[Pipeline]:
    """Return the pipelines this tag is attached to.

    Queries ``tag_by_pk`` for the tag's ``pipeline_tags`` and wraps each
    linked pipeline record in a ``Pipeline``. Returns an empty list when the
    tag lookup yields nothing.
    """
    query = gql.gql(
        """
        query PipelinesByTagId($tag_id: bigint!){
            tag_by_pk(id:$tag_id){
                pipeline_tags {
                    pipeline {
                        id
                        pipeline_id
                        created_at
                        updated_at
                        owner_id
                    }
                }
            }
        }
        """
    )
    response = self._gql_client.execute(query, variable_values={"tag_id": self._id})
    tag_record = response["tag_by_pk"]
    if not tag_record:
        return []
    return [
        Pipeline(client=self.client, data=link["pipeline"])
        for link in tag_record["pipeline_tags"]
    ]
Lists the pipelines this tag is on.
def
add_to_pipeline(self, pipeline_id: int):
def add_to_pipeline(self, pipeline_id: int):
    """Attach this tag to the pipeline whose primary key is ``pipeline_id``.

    Returns the first row of the ``insert_pipeline_tag`` mutation's
    ``returning`` list (fields ``pipeline_pk_id`` and ``tag_pk_id``).
    """
    mutation = gql.gql(
        """
        mutation AddTagToPipeline($pipeline_id: bigint!, $tag_id: bigint!) {
            insert_pipeline_tag(objects: {
                pipeline_pk_id : $pipeline_id,
                tag_pk_id: $tag_id
            }) {
                returning {
                    pipeline_pk_id
                    tag_pk_id
                }
            }
        }
        """
    )
    variables = {"pipeline_id": pipeline_id, "tag_id": self._id}
    result = self._gql_client.execute(mutation, variable_values=variables)
    inserted_rows = result["insert_pipeline_tag"]["returning"]
    return inserted_rows[0]
def
remove_from_pipeline(self, pipeline_id: int):
def remove_from_pipeline(self, pipeline_id: int):
    """Detach this tag from the pipeline whose primary key is ``pipeline_id``.

    Returns the first row of the ``delete_pipeline_tag`` mutation's
    ``returning`` list. If that list is empty, the ``[0]`` lookup raises
    ``IndexError``.
    """
    mutation = gql.gql(
        """
        mutation RemoveTagFromPipeline($pipeline_id: bigint!, $tag_id: bigint!) {
            delete_pipeline_tag(
                where: {
                    _and: [
                        { pipeline_pk_id: {_eq: $pipeline_id} }
                        { tag_pk_id: {_eq: $tag_id} }
                    ]
                }
            )
            {
                returning {
                    pipeline_pk_id
                    tag_pk_id
                }
            }
        }
        """
    )
    variables = {"pipeline_id": pipeline_id, "tag_id": self._id}
    result = self._gql_client.execute(mutation, variable_values=variables)
    deleted_rows = result["delete_pipeline_tag"]["returning"]
    return deleted_rows[0]
class
Tags(typing.List[wallaroo.tag.Tag]):
class Tags(List[Tag]):
    """Wraps a list of tags for display in a display-aware environment like Jupyter."""

    def _repr_html_(self) -> str:
        """Render the tags as an HTML table with one row per tag.

        Returns the plain string ``"(no tags)"`` when the list is empty.
        """
        import html

        def cell(value: Any) -> str:
            # Tag text and model/pipeline names are user-supplied; escape them
            # so characters such as ``<`` or ``&`` cannot break the markup.
            return f"<td>{html.escape(str(value), quote=False)}</td>"

        def row(tag: Tag) -> str:
            models = tag.model_tags()
            pipelines = tag.pipeline_tags()
            # Group model versions under their model name (first-seen order).
            models_dict: Dict[Any, List[Any]] = {}
            for m in models:
                models_dict.setdefault(m.name(), []).append(m.version())

            return (
                "<tr>"
                + cell(tag.id())
                + cell(tag.tag())
                + cell([(key, value) for key, value in models_dict.items()])
                + cell(
                    [
                        (p.name(), [pv.name() for pv in p.variants()])
                        for p in pipelines
                    ]
                )
                + "</tr>"
            )

        fields = ["id", "tag", "models", "pipelines"]

        if not self:
            return "(no tags)"
        return (
            "<table>"
            + "<tr><th>"
            + "</th><th>".join(fields)
            + "</th></tr>"
            + "".join(row(t) for t in self)
            + "</table>"
        )
Wraps a list of tags for display in a display-aware environment like Jupyter.
Inherited Members
- builtins.list
- list
- clear
- copy
- append
- insert
- extend
- pop
- remove
- index
- count
- reverse
- sort