wallaroo.explainability
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from uuid import UUID

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import norm

from wallaroo.unwrap import unwrap

if TYPE_CHECKING:
    # Imports that happen below in methods to fix circular import dependency
    # issues need to also be specified here to satisfy mypy type checking.
    pass

NDArray = np.ndarray

# Column names shared between feature_effects() and the plotting helpers so
# the two always agree on spelling.
REF_MEAN_ABS = "ref_mean_abs"
REF_MAX_ABS = "ref_max_abs"
WINDOW_MEAN_ABS = "window_mean_abs"
WINDOW_MAX_ABS = "window_max_abs"


# To keep colors consistent across plots explicitly list the ones we use.
colors = [
    "red",
    "green",
    "blue",
    "gold",
    "fuchsia",
    "tan",
    "yellowgreen",
    "darkkhaki",
    "brown",
    "azure",
    "darkorchid",
    "darkslategrey",
    "moccasin",
    "sandybrown",
    "aquamarine",
    "springgreen",
    "sienna",
    "cyan",
    "mediumpurple",
    "midnightblue",
    "plum",
    "deeppink",
    "crimson",
    "lightgoldenrodyellow",
    "lightgreen",
    "lawngreen",
    "bisque",
    "steelblue",
    "cornflowerblue",
    "lightblue",
    "darksalmon",
    "lightslategrey",
]


def truncate(v: Any, num_char: int = 256) -> str:
    """Return ``str(v)``, shortened to at most ``num_char`` characters.

    Values that fit are returned unchanged; longer values are cut and
    suffixed with " ..." so the result is still exactly ``num_char`` long.
    """
    s = str(v)
    # Use <= so a string of exactly num_char characters is NOT truncated
    # (the original `<` needlessly chopped the boundary case).
    if len(s) <= num_char:
        return s
    postfix = " ..."
    return f"{s[:num_char - len(postfix)]}{postfix}"


@dataclass
class WindowRequestInput:
    """Date window / sample-count selector sent with an explainability request."""

    start: Optional[str]  # These are strings in our request but
    end: Optional[str]  # should be interpreted as RFC 3339 formatted datetimes
    num_samples: Optional[int]


class ExplainabilityConfigList(list):
    """Adds a _repr_html_ to a list of explainability configs."""

    def _repr_html_(self):
        """Generate an HTML table of config ids and statuses for Jupyter."""
        rows = [
            f"""
            <tr>
            <td>{exp_config.id}</td>
            <td>{exp_config.status}</td>
            </tr>
            """
            for exp_config in self
        ]
        table = """<table>
            <tr>
            <th>Id</th>
            <th>Status</th>
            </tr>
            {0}
            </table>""".format(
            "\n".join(rows)
        )

        return table


class ExplainabilityRequestList(list):
    """Adds a _repr_html_ to a list of explainability requests."""

    def _repr_html_(self):
        """Assuming all items are explainability requests generates an HTML
        table for jupyter."""
        rows = [
            f"""
            <tr>
            <td>{exp_req.id}</td>
            <td>{exp_req.status}</td>
            <td>{exp_req.reference_config}</td>
            <td>{exp_req.window_config}</td>
            <td>{exp_req.use_adhoc_data}</td>
            </tr>
            """
            for exp_req in self
        ]
        table = """<table>
            <tr>
            <th>Id</th>
            <th>Status</th>
            <th>Reference Data</th>
            <th>Window Data</th>
            <th>Adhoc</th>
            </tr>
            {0}
            </table>""".format(
            "\n".join(rows)
        )

        return table


@dataclass
class FeatureBounds:
    """The [min, max] range of a feature and the sample points taken in it."""

    min: float
    max: float
    xs: List[float]


@dataclass
class ExplainabilityConfig:
    """This class specifies an explainability configuration that can be used
    to later submit explainability requests which cause the server to do the
    analysis and create explainability results.

    ExplainabilityConfig are necessary to ensure the explainability pipeline
    is created and is deployed and so that various requests are processed in
    the same manner and can be compared.

    id, status, feature_bounds and reference_pipeline_version are optional
    and will be filled out when processed and saved to the database.

    workspace id must match the users/pipelines workspace and
    reference_pipeline_version must refer to a valid pipeline version that the
    user has access to.

    num_points specifies how many samples to take when varying the values of a
    feature for the PDP/ICE analysis through the feature_bounds.

    feature_names are a convenience for the user. output_names is not
    currently used.
    """

    id: Optional[UUID]
    workspace_id: int
    status: Optional[Dict[str, Any]]
    reference_pipeline_version: str
    explainability_pipeline_version: Optional[str]
    num_points: int = 10
    feature_names: Optional[List[str]] = None
    feature_bounds: Optional[Dict[str, FeatureBounds]] = None
    output_names: Optional[List[str]] = None

    def _repr_html_(self):
        """Render the config fields as an HTML table (values truncated)."""
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def list_explainability_requests(self) -> List["ExplainabilityRequest"]:
        """List the explainability requests we've created."""

        # self.client is attached after construction by the wallaroo client.
        client = self.client  # type: ignore
        result = client._post_rest_api_json(
            "v1/api/explainability/list_requests",
            {"explainability_config_id": self.id},
        )
        erl = [ExplainabilityRequest(**ec) for ec in result]
        for er in erl:
            er.client = client  # type: ignore
        return ExplainabilityRequestList(erl)

    def get_explainability_request(
        self, expr: Union[str, "ExplainabilityConfig"]
    ) -> Optional["ExplainabilityRequest"]:
        """Get the full explainability request whether completed or not.

        ``expr`` may be the request id itself or an object carrying it.
        """

        if isinstance(expr, str):
            explain_id = expr
        else:
            explain_id = str(expr.id)

        client = self.client  # type: ignore
        result = client._post_rest_api_json(
            "v1/api/explainability/get_request",
            {"explainability_request_id": explain_id},
        )

        exp_cfg = ExplainabilityRequest(**result)
        exp_cfg.client = client  # type: ignore
        return exp_cfg

    def submit_explainability_request(
        self,
        reference_start: Optional[datetime] = None,
        reference_end: Optional[datetime] = None,
        reference_num_samples: Optional[int] = None,
        use_reference: Optional[bool] = True,
        window_start: Optional[datetime] = None,
        window_end: Optional[datetime] = None,
        window_num_samples: Optional[int] = None,
        adhoc_data: Optional[Union[List[List[float]], np.ndarray, pd.DataFrame]] = None,
    ):
        """Submit an analysis on reference or adhoc data using a particular config.

        At least one of a reference config (``use_reference=True``), a window
        (any of the ``window_*`` arguments) or ``adhoc_data`` must be given.
        Returns the newly created ExplainabilityRequest.
        """

        reference_config = None
        if use_reference:
            reference_config = WindowRequestInput(
                start=maybe_format_date(reference_start),
                end=maybe_format_date(reference_end),
                num_samples=reference_num_samples,
            )

        window_config = None
        if window_start or window_end or window_num_samples:
            window_config = WindowRequestInput(
                start=maybe_format_date(window_start),
                end=maybe_format_date(window_end),
                num_samples=window_num_samples,
            )

        use_adhoc_data = adhoc_data is not None

        # Normalize whatever the user passed into a plain list-of-lists for JSON.
        adhoc_data_list = None
        if use_adhoc_data:
            if isinstance(adhoc_data, list):
                adhoc_data_list = adhoc_data
            elif isinstance(adhoc_data, np.ndarray):
                adhoc_data_list = adhoc_data.tolist()
            elif isinstance(adhoc_data, pd.DataFrame):
                adhoc_data_list = adhoc_data.values.tolist()
            else:
                raise Exception(f"Unknown adhoc data type {type(adhoc_data)}")

        if reference_config is None and window_config is None and not use_adhoc_data:
            raise Exception(
                "You must specify a reference config, a window config or adhoc data"
            )

        client = self.client  # type: ignore
        exp_config_id = self.id
        workspace_id = client.get_current_workspace().id()

        ear = ExplainabilityRequest(
            id=None,
            explainability_config_id=exp_config_id,
            workspace_id=workspace_id,
            reference_config=reference_config,
            window_config=window_config,
            use_adhoc_data=use_adhoc_data,
            adhoc_data=adhoc_data_list,
        )

        result = client._post_rest_api_json(
            "v1/api/explainability/create_request", asdict(ear)
        )
        expr_id = result["id"]
        return self.get_explainability_request(expr_id)


def maybe_format_date(d: Optional[datetime]) -> Optional[str]:
    """Format a datetime as a UTC ISO-8601 string, passing None through."""
    if d:
        return d.astimezone(tz=timezone.utc).isoformat()
    return None


@dataclass
class ExplainabilityRequest:
    """This class outlines what should be submitted to start the explainability
    analysis with a particular config.

    The request can be to analyze reference data, historical data from the ref
    pipeline, or new adhoc data submitted with the request or both.

    id and status are optional and are filled in by the processing steps.

    If the request has use_reference_data = True, num_sample inference logs are
    sampled from between the start and end dates or the entire (last 100_000)
    inferences.
    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: Optional[UUID] = None
    status: Optional[Dict[str, Any]] = None
    reference_config: Optional[WindowRequestInput] = None
    window_config: Optional[WindowRequestInput] = None
    use_adhoc_data: bool = False
    adhoc_data: Optional[List[List[float]]] = None

    def _repr_html_(self):
        """Render the request fields as an HTML table."""
        fields = [f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in asdict(self).items()]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def get_explainability_result(self) -> Optional["ExplainabilityResult"]:
        """Get the full explainability result whether completed or not."""

        client = self.client  # type: ignore

        result = client._post_rest_api_json(
            "v1/api/explainability/get_result",
            {"explainability_result_id": self.id},
        )
        return build_explainability_result(result)


@dataclass
class PDPResult:
    """This class holds the PDP/ICE part of the results.
    PDP/ICE results are generated for each observation by holding
    all but one feature constant, varying that feature and analyzing
    that prediction. Thus the results are per inference per feature.

    feature_name is the feature that this result is for.
    xs is the list of x values that the feature was varied through.

    pdp_vals is the list of resulting values.
    model, shap and feature expected values are the mean/expected values
    for that model, shap and feature.
    """

    feature_name: str
    ice_vals: np.ndarray
    pdp_vals: List[float]
    model_expected_value: List[float]
    shap_expected_value: List[float]
    feature_expected_value: List[float]


@dataclass
class WindowResult:
    """Analysis results for one slice (reference, window or adhoc) of data."""

    data: NDArray  # the original data used for the analysis
    shap_values: NDArray  # the calculated shap values
    base_values: NDArray  # The expected value (mean) for each prediction.
    pdp_results: List[PDPResult]  # Partial dependence plot data for each feature


def find_pdp_result(window_result: WindowResult, feature_name: str) -> PDPResult:
    """Gets the pdp result object for the specified feature."""

    for p in window_result.pdp_results:
        if p.feature_name == feature_name:
            return p
    raise Exception(f"Did not find feature {feature_name} in pdp_results.")


@dataclass
class ExplainabilityResult:
    """This class holds the explainability result created by processing an
    explainability request.

    id and status are optional and will be filled in by processing. The id
    will be the same as the request id since the results are stored in minio.

    num_inferences and num_batches are nice to know status information and
    could be brought into the status object in the future.

    reference and adhoc data are the actual inferences used in the analysis.

    reference and adhoc shap values are the shap values for each feature for
    each prediction.

    base_values are the expected value for each prediction. These values will
    all be the same so may be changed to a single float in the future.

    pdp results are a list of pdp/ice results for each feature.
    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: UUID
    num_inferences: int
    num_batches: int
    compute_time: float
    status: Dict[str, Any]
    feature_names: List[str]  # The names of the columns
    feature_bounds: Dict[str, FeatureBounds]
    reference_result: Optional[WindowResult]
    window_result: Optional[WindowResult]
    adhoc_result: Optional[WindowResult]

    # Fixed plot colors (RGB in [0, 1]) so reference/window data always look
    # the same across the different chart types.
    reference_color = np.array([127.0, 196, 252]) / 255
    window_color = np.array([252, 127.0, 196]) / 255

    def _repr_html_(self):
        """Render the result fields as an HTML table (values truncated)."""
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def feature_effects(self) -> pd.DataFrame:
        """Returns a dataframe summarizing the mean feature effects of the reference
        data as well as the feature effects for each adhoc inference."""

        df = pd.DataFrame(index=self.feature_names)

        if self.reference_result is not None:
            vals = np.array(self.reference_result.shap_values)
            if vals.shape[0] > 0:
                df[REF_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["ref_std_dev"] = vals.std(axis=0)
                df[REF_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.window_result is not None:
            vals = np.array(self.window_result.shap_values)
            if vals.shape[0] > 0:
                df[WINDOW_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["window_std_dev"] = vals.std(axis=0)
                df[WINDOW_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.adhoc_result is not None:
            vals = np.array(self.adhoc_result.shap_values)
            if vals.shape[0] > 0:
                # One column per adhoc inference, 1-based to match plot labels.
                for idx in range(vals.shape[0]):
                    df[f"input_{idx+1}"] = vals[idx, :]

        if REF_MEAN_ABS in df.columns:
            return df.sort_values(by=REF_MEAN_ABS, ascending=False)
        return df

    def effect_summary(self) -> pd.DataFrame:
        """Returns a dataframe with the expected/mean values and the shap adjustments."""

        effects = self.feature_effects()
        base_value = unwrap(self.reference_result).base_values[0]
        data = {}
        data["base_value"] = [0, base_value]

        # TODO: we should rename inputs_X to 'house_X' or similar
        input_cols = [c for c in effects.columns if "input_" in c]
        effect_sums = effects[input_cols].sum(axis=0)
        for c, v in zip(input_cols, effect_sums):
            data[c] = [v, v + base_value]
        return pd.DataFrame(data, index=["adjustment", "total"])

    def check_status(self) -> bool:
        """Ensure we've completed before trying to plot anything."""

        if self.status["status"] != "COMPLETED":
            raise Exception(f"Analysis has not (yet) completed: {self.status}")
        return True

    def plot_feature_effects(
        self,
        mode: str = "mean",
        top_n=0,
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
    ):
        """Creates a bar plot of the mean or max abs feature effects.

        mode must be 'mean', 'max' or 'individual'; top_n > 0 limits the
        plot to the first top_n features of feature_effects().
        """

        if mode not in ["mean", "max", "individual"]:
            raise Exception("Mode must be one of: 'mean', 'max', 'individual'")

        self.check_status()

        df = self.feature_effects()
        if top_n > 0:
            df = df.head(top_n)

        cols = df.columns
        if not plot_reference:
            cols = [c for c in cols if not c.startswith("ref_")]
        if not plot_window:
            cols = [c for c in cols if not c.startswith("window_")]
        if not plot_adhoc:
            cols = [c for c in cols if c.startswith("ref_") or c.startswith("window_")]

        df = df[cols]

        if REF_MEAN_ABS in df.columns and mode != "individual":
            if mode == "max":
                title = "Max Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MAX_ABS])
            else:
                title = "Mean Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MEAN_ABS])
        else:
            # Per-inference bars; the std-dev columns would just add noise.
            title = "Feature Effects"
            ax = plt.subplot(1, 1, 1)
            _ = df.drop(
                ["ref_std_dev", "window_std_dev"], axis=1, errors="ignore"
            ).plot(kind="bar", ax=ax)
            plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

        plt.grid()
        plt.xticks(rotation=90)
        plt.title(title)

    def plot_ice_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a combination ICE plot for the adhoc data if any
        in custom colors and the reference data if any in translucent
        blue."""

        self.check_status()

        ice_alpha = 0.2
        ice_width = 5

        xs = self.feature_bounds[feature_name].xs

        plt.title(f"ICE for {feature_name}")
        plt.xlabel(feature_name)
        plt.ylabel("Prediction")

        if plot_reference and self.reference_result is not None:
            pdp_result = find_pdp_result(self.reference_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.reference_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            _ = plt.plot(
                xs, mean_vals, color="black", zorder=10, label="Reference Mean"
            )

        if plot_window and self.window_result is not None:
            pdp_result = find_pdp_result(self.window_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.window_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            _ = plt.plot(
                xs, mean_vals, linestyle="dashed", color="black", label="Window Mean"
            )

        if plot_adhoc and self.adhoc_result is not None:
            pdp_result = find_pdp_result(self.adhoc_result, feature_name)
            ice_vals = np.array(pdp_result.ice_vals)
            mean_vals = pdp_result.pdp_vals
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.plot(
                    xs,
                    ice_vals[:, idx : idx + 1],
                    linewidth=3,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )
            plt.legend()

        plt.ylim(0)
        _ = plt.grid()
        plt.show()

    def plot_all_features(
        self,
        title="Feature Effect per Inference",
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
        top_n: int = 0,
    ):
        """Creates a 'bee swarm' plot of all/each feature effect."""
        self.check_status()
        # Fixed seed so the jitter (and thus the plot) is reproducible.
        np.random.seed(42)

        # Our custom palette will go from black to a weird blue green.
        # Not sure of the best palette to use. We need lighter to be
        # higher values. Started with a red to green ramp but we also
        # need to be aware to color blindness issues.
        rgb = [(0.10, 0.10, 0.10), (0.0, 0.90, 0.90)]

        # Gather the data we'll need
        shap_values_list = []
        feature_values_list = []
        if plot_reference and self.reference_result:
            shap_values_list.append(self.reference_result.shap_values)
            feature_values_list.append(self.reference_result.data)
        if plot_window and self.window_result:
            shap_values_list.append(self.window_result.shap_values)
            feature_values_list.append(self.window_result.data)
        if plot_adhoc and self.adhoc_result:
            shap_values_list.append(self.adhoc_result.shap_values)
            feature_values_list.append(self.adhoc_result.data)

        if not shap_values_list:
            raise Exception("Some data must be specified to create a plot.")

        shap_values = np.vstack(shap_values_list)
        feature_values = np.vstack(feature_values_list)

        # create a df of the shap values / contributions
        df = pd.DataFrame(shap_values)
        df.columns = self.feature_names
        # create a df of the original feature values
        feature_df = pd.DataFrame(feature_values)
        feature_df.columns = self.feature_names

        # We plot in strips from the bottom so put the most important at the top.
        # first taking the top_n if specified.
        feature_names = list(self.feature_effects().index)
        if top_n > 0:
            feature_names = feature_names[:top_n]
        feature_names = list(reversed(feature_names))

        num_features = len(feature_names)
        num_obs = df.shape[0]

        fig = plt.figure()
        fig.patch.set_facecolor("white")
        ax = fig.get_axes()

        for i, col in enumerate(feature_names):
            # create a temp df where the y is the level + jitter
            dfc = pd.DataFrame(df[col])
            dfc["y"] = i + 1 + norm.rvs(loc=0, scale=0.1, size=num_obs)
            # vals is the original feature values and we create a custom palette
            dfc["vals"] = feature_df[col]
            unique_vals = dfc["vals"].unique()
            n_colors = len(unique_vals)
            palette = sns.blend_palette(rgb, n_colors=n_colors)

            # plot a scatter plot strip
            ax = sns.scatterplot(
                x=col,
                y="y",
                data=dfc,
                alpha=0.75,
                hue="vals",
                palette=palette,
                legend=None,
            )

        # change the tick labels from strip number to feature name
        ax.set_yticks(range(num_features + 1))
        ticks = [""]
        ticks.extend(feature_names)
        ax.set_yticklabels(ticks)

        plt.xlabel("Shap Value")
        plt.ylabel("")
        plt.title(title)
        plt.ylim(0, num_features + 1)
        plt.grid()
        plt.show()

    def plot_feature_values_vs_shap_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a scatter plot of the feature vs shap values.
        adhoc data if any is in custom colors. reference data in translucent
        blue."""

        self.check_status()

        alpha = 0.5

        i = self.feature_names.index(feature_name)

        plt.title(f"Shap Values for {feature_name} for Each Input")
        plt.xlabel(feature_name)
        plt.ylabel("Shap Values")
        plt.grid()

        if plot_reference and self.reference_result:
            _ = plt.scatter(
                self.reference_result.data[:, i],
                self.reference_result.shap_values[:, i],
                s=100,
                alpha=alpha,
                color=self.reference_color,
                label="Reference",
            )

        if plot_window and self.window_result:
            _ = plt.scatter(
                self.window_result.data[:, i],
                self.window_result.shap_values[:, i],
                marker="p",
                s=100,
                alpha=alpha,
                color=self.window_color,
                label="Window",
            )

        if plot_adhoc and self.adhoc_result:
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.scatter(
                    self.adhoc_result.data[idx, i],
                    self.adhoc_result.shap_values[idx, i],
                    marker="d",  # type: ignore
                    s=200,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )

        plt.legend()
        plt.show()


def build_window_result(data) -> Optional[WindowResult]:
    """Parse one JSON window-result dict into a WindowResult, or pass None through."""
    if data is not None:
        return WindowResult(
            data=np.array(data["data"]),
            shap_values=np.array(data["shap_values"]),
            base_values=np.array(data["base_values"]),
            pdp_results=[PDPResult(**d) for d in data["pdp_results"]],
        )
    else:
        return None


def build_explainability_result(data):
    """Convenience function to parse json into the full result object
    we want."""

    return ExplainabilityResult(
        id=data["id"],
        workspace_id=data["workspace_id"],
        explainability_config_id=data["explainability_config_id"],
        num_inferences=data["num_inferences"],
        num_batches=data["num_batches"],
        compute_time=data["compute_time"],
        status=data["status"],
        feature_names=data["feature_names"],
        feature_bounds={
            k: FeatureBounds(**v) for k, v in data["feature_bounds"].items()
        },
        reference_result=build_window_result(data["reference_result"]),
        window_result=build_window_result(data["window_result"]),
        adhoc_result=build_window_result(data["adhoc_result"]),
    )
74@dataclass 75class WindowRequestInput: 76 start: Optional[str] # These are strings in our request but 77 end: Optional[str] # should be interpreted as 3339 formated datetimes 78 num_samples: Optional[int]
WindowRequestInput(start: Optional[str], end: Optional[str], num_samples: Optional[int])
81class ExplainabilityConfigList(list): 82 def _repr_html_(self): 83 rows = [ 84 f""" 85 <tr"> 86 <td>{exp_config.id}</td> 87 <td>{exp_config.status}</td> 88 </tr> 89 """ 90 for exp_config in self 91 ] 92 table = """<table> 93 <tr> 94 <th>Id</th> 95 <th>Status</th> 96 </tr> 97 {0} 98 </table>""".format( 99 "\n".join(rows) 100 ) 101 102 return table
Built-in mutable sequence.
If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.
Inherited Members
- builtins.list
- list
- clear
- copy
- append
- insert
- extend
- pop
- remove
- index
- count
- reverse
- sort
105class ExplainabilityRequestList(list): 106 """Adds a _repr_html_ to a list of explainability requests.""" 107 108 def _repr_html_(self): 109 """Assuming all items are explainabilty requests generates an HTML 110 table for jupyter.""" 111 rows = [ 112 f""" 113 <tr"> 114 <td>{exp_req.id}</td> 115 <td>{exp_req.status}</td> 116 <td>{exp_req.reference_config}</td> 117 <td>{exp_req.window_config}</td> 118 <td>{exp_req.use_adhoc_data}</td> 119 </tr> 120 """ 121 for exp_req in self 122 ] 123 table = """<table> 124 <tr> 125 <th>Id</th> 126 <th>Status</th> 127 <th>Reference Data</th> 128 <th>Window Data</th> 129 <th>Adhoc</th> 130 </tr> 131 {0} 132 </table>""".format( 133 "\n".join(rows) 134 ) 135 136 return table
Adds a _repr_html_ to a list of explainability requests.
Inherited Members
- builtins.list
- list
- clear
- copy
- append
- insert
- extend
- pop
- remove
- index
- count
- reverse
- sort
FeatureBounds(min: float, max: float, xs: List[float])
146@dataclass 147class ExplainabilityConfig: 148 """This class specifies an explainability configuration that can be used 149 to later submit explainability requests which cause the server to do the 150 analysis and create explainability results. 151 152 ExplainabilityConfig are necessary to ensure the explainability pipeline 153 is created and is deployed and so that various requests are processed in 154 the same manner and can be compared. 155 156 id, status, feature_bounds and reference_pipeline_version are optional 157 and will be filled out when processed and saved to the database. 158 159 workspace id must match the users/pipelines workspace and 160 reference_pipeline_version must refer to a valid pipeline version that the 161 user has access too. 162 163 num_points specifies how many samples to take when varying the values of a 164 feature for the PDP/ICE analysis through the feature_bounds. 165 166 feature_names are convinince for the user. output_names is not currently used. 167 """ 168 169 id: Optional[UUID] 170 workspace_id: int 171 status: Optional[Dict[str, Any]] 172 reference_pipeline_version: str 173 explainability_pipeline_version: Optional[str] 174 num_points: int = 10 175 feature_names: Optional[List[str]] = None 176 feature_bounds: Optional[Dict[str, FeatureBounds]] = None 177 output_names: Optional[List[str]] = None 178 179 def _repr_html_(self): 180 fields = [ 181 f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>" 182 for k, v in asdict(self).items() 183 ] 184 return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>" 185 186 def list_explainability_requests(self) -> List["ExplainabilityRequest"]: 187 """List the explainability requests we've created.""" 188 189 client = self.client # type: ignore 190 result = client._post_rest_api_json( 191 f"v1/api/explainability/list_requests", 192 {"explainability_config_id": self.id}, 193 ) 194 erl = [ExplainabilityRequest(**ec) for ec in result] 195 for er in erl: 196 er.client = client # 
type: ignore 197 return ExplainabilityRequestList(erl) 198 199 def get_explainability_request( 200 self, expr: Union[str, "ExplainabilityConfig"] 201 ) -> Optional["ExplainabilityRequest"]: 202 """Get the full explainability result whether completed or not.""" 203 204 if isinstance(expr, str): 205 explain_id = expr 206 else: 207 explain_id = str(expr.id) 208 209 client = self.client # type: ignore 210 result = client._post_rest_api_json( 211 f"v1/api/explainability/get_request", 212 {"explainability_request_id": explain_id}, 213 ) 214 215 exp_cfg = ExplainabilityRequest(**result) 216 exp_cfg.client = client # type: ignore 217 return exp_cfg 218 219 def submit_explainability_request( 220 self, 221 reference_start: Optional[datetime] = None, 222 reference_end: Optional[datetime] = None, 223 reference_num_samples: Optional[int] = None, 224 use_reference: Optional[bool] = True, 225 window_start: Optional[datetime] = None, 226 window_end: Optional[datetime] = None, 227 window_num_samples: Optional[int] = None, 228 adhoc_data: Optional[Union[List[List[float]], np.ndarray, pd.DataFrame]] = None, 229 ): 230 """Submit an analysis on reference or adhoc data using a particular config""" 231 232 reference_config = None 233 if use_reference: 234 reference_config = WindowRequestInput( 235 start=maybe_format_date(reference_start), 236 end=maybe_format_date(reference_end), 237 num_samples=reference_num_samples, 238 ) 239 240 window_config = None 241 if window_start or window_end or window_num_samples: 242 window_config = WindowRequestInput( 243 start=maybe_format_date(window_start), 244 end=maybe_format_date(window_end), 245 num_samples=window_num_samples, 246 ) 247 248 use_adhoc_data = False 249 if adhoc_data is not None: 250 use_adhoc_data = True 251 252 adhoc_data_list = None 253 if use_adhoc_data: 254 if isinstance(adhoc_data, list): 255 adhoc_data_list = adhoc_data 256 elif isinstance(adhoc_data, np.ndarray): 257 adhoc_data_list = adhoc_data.tolist() 258 elif 
isinstance(adhoc_data, pd.DataFrame): 259 adhoc_data_list = adhoc_data.values.tolist() 260 else: 261 raise Exception(f"Unknown adhoc data type {type(adhoc_data)}") 262 263 if reference_config is None and window_config is None and not use_adhoc_data: 264 raise Exception( 265 "You must specify a reference config, a window config or adhoc data" 266 ) 267 268 client = self.client # type: ignore 269 exp_config_id = self.id 270 workspace_id = client.get_current_workspace().id() 271 272 ear = ExplainabilityRequest( 273 id=None, 274 explainability_config_id=exp_config_id, 275 workspace_id=workspace_id, 276 reference_config=reference_config, 277 window_config=window_config, 278 use_adhoc_data=use_adhoc_data, 279 adhoc_data=adhoc_data_list, 280 ) 281 282 result = client._post_rest_api_json( 283 f"v1/api/explainability/create_request", asdict(ear) 284 ) 285 expr_id = result["id"] 286 return self.get_explainability_request(expr_id)
This class specifies an explainability configuration that can be used to later submit explainability requests which cause the server to do the analysis and create explainability results.
ExplainabilityConfig are necessary to ensure the explainability pipeline is created and is deployed and so that various requests are processed in the same manner and can be compared.
id, status, feature_bounds and reference_pipeline_version are optional and will be filled out when processed and saved to the database.
workspace id must match the users/pipelines workspace and reference_pipeline_version must refer to a valid pipeline version that the user has access to.
num_points specifies how many samples to take when varying the values of a feature for the PDP/ICE analysis through the feature_bounds.
feature_names are a convenience for the user. output_names is not currently used.
186 def list_explainability_requests(self) -> List["ExplainabilityRequest"]: 187 """List the explainability requests we've created.""" 188 189 client = self.client # type: ignore 190 result = client._post_rest_api_json( 191 f"v1/api/explainability/list_requests", 192 {"explainability_config_id": self.id}, 193 ) 194 erl = [ExplainabilityRequest(**ec) for ec in result] 195 for er in erl: 196 er.client = client # type: ignore 197 return ExplainabilityRequestList(erl)
List the explainability requests we've created.
199 def get_explainability_request( 200 self, expr: Union[str, "ExplainabilityConfig"] 201 ) -> Optional["ExplainabilityRequest"]: 202 """Get the full explainability result whether completed or not.""" 203 204 if isinstance(expr, str): 205 explain_id = expr 206 else: 207 explain_id = str(expr.id) 208 209 client = self.client # type: ignore 210 result = client._post_rest_api_json( 211 f"v1/api/explainability/get_request", 212 {"explainability_request_id": explain_id}, 213 ) 214 215 exp_cfg = ExplainabilityRequest(**result) 216 exp_cfg.client = client # type: ignore 217 return exp_cfg
Get the full explainability result whether completed or not.
219 def submit_explainability_request( 220 self, 221 reference_start: Optional[datetime] = None, 222 reference_end: Optional[datetime] = None, 223 reference_num_samples: Optional[int] = None, 224 use_reference: Optional[bool] = True, 225 window_start: Optional[datetime] = None, 226 window_end: Optional[datetime] = None, 227 window_num_samples: Optional[int] = None, 228 adhoc_data: Optional[Union[List[List[float]], np.ndarray, pd.DataFrame]] = None, 229 ): 230 """Submit an analysis on reference or adhoc data using a particular config""" 231 232 reference_config = None 233 if use_reference: 234 reference_config = WindowRequestInput( 235 start=maybe_format_date(reference_start), 236 end=maybe_format_date(reference_end), 237 num_samples=reference_num_samples, 238 ) 239 240 window_config = None 241 if window_start or window_end or window_num_samples: 242 window_config = WindowRequestInput( 243 start=maybe_format_date(window_start), 244 end=maybe_format_date(window_end), 245 num_samples=window_num_samples, 246 ) 247 248 use_adhoc_data = False 249 if adhoc_data is not None: 250 use_adhoc_data = True 251 252 adhoc_data_list = None 253 if use_adhoc_data: 254 if isinstance(adhoc_data, list): 255 adhoc_data_list = adhoc_data 256 elif isinstance(adhoc_data, np.ndarray): 257 adhoc_data_list = adhoc_data.tolist() 258 elif isinstance(adhoc_data, pd.DataFrame): 259 adhoc_data_list = adhoc_data.values.tolist() 260 else: 261 raise Exception(f"Unknown adhoc data type {type(adhoc_data)}") 262 263 if reference_config is None and window_config is None and not use_adhoc_data: 264 raise Exception( 265 "You must specify a reference config, a window config or adhoc data" 266 ) 267 268 client = self.client # type: ignore 269 exp_config_id = self.id 270 workspace_id = client.get_current_workspace().id() 271 272 ear = ExplainabilityRequest( 273 id=None, 274 explainability_config_id=exp_config_id, 275 workspace_id=workspace_id, 276 reference_config=reference_config, 277 
window_config=window_config, 278 use_adhoc_data=use_adhoc_data, 279 adhoc_data=adhoc_data_list, 280 ) 281 282 result = client._post_rest_api_json( 283 f"v1/api/explainability/create_request", asdict(ear) 284 ) 285 expr_id = result["id"] 286 return self.get_explainability_request(expr_id)
Submit an analysis on reference or adhoc data using a particular config
@dataclass
class ExplainabilityRequest:
    """This class outlines what should be submitted to start the explainability
    analysis with a particular config.

    The request can be to analyze reference data, historical data from the ref
    pipeline, or new adhoc data submitted with the request or both.

    id and status are optional and are filled in by the processing steps.

    If the request has use_reference_data = True, num_sample inference logs are
    sampled from between the start and end dates or the entire (last 100_000)
    inferences.
    """

    # id is assigned by the server when the request is created.
    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: Optional[UUID] = None
    # Processing status dict, set server-side as the analysis progresses.
    status: Optional[Dict[str, Any]] = None
    reference_config: Optional[WindowRequestInput] = None
    window_config: Optional[WindowRequestInput] = None
    use_adhoc_data: bool = False
    # Rows submitted with the request, already normalized to list-of-lists.
    adhoc_data: Optional[List[List[float]]] = None

    def _repr_html_(self):
        """Render the request's fields as an HTML table for Jupyter."""
        fields = [f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in asdict(self).items()]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def get_explainability_result(self) -> Optional["ExplainabilityResult"]:
        """Get the full explainability result whether completed or not.

        The result is looked up by this request's id (per the
        ExplainabilityResult docstring, results share the request id).
        """

        client = self.client  # type: ignore

        result = client._post_rest_api_json(
            f"v1/api/explainability/get_result",
            {"explainability_result_id": self.id},
        )
        return build_explainability_result(result)
This class outlines what should be submitted to start the explainability analysis with a particular config.
The request can be to analyze reference data, historical data from the ref pipeline, or new adhoc data submitted with the request or both.
id and status are optional and are filled in by the processing steps.
If the request has use_reference_data = True, num_sample inference logs are sampled from between the start and end dates or the entire (last 100_000) inferences.
323 def get_explainability_result(self) -> Optional["ExplainabilityResult"]: 324 """Get the full explainability result whether completed or not.""" 325 326 client = self.client # type: ignore 327 328 result = client._post_rest_api_json( 329 f"v1/api/explainability/get_result", 330 {"explainability_result_id": self.id}, 331 ) 332 return build_explainability_result(result)
Get the full explainability result whether completed or not.
@dataclass
class PDPResult:
    """This class holds the PDP/ICE part of the results.
    PDP/ICE results are generated for each observation by holding
    all but one feature constant, varying that feature and analyzing
    that prediction. Thus the results are per inference per feature.

    feature_name is the feature that this result is for.
    xs is the list of x values that the feature was varied through.

    pdp_vals is the list of resulting values.
    model, shap and feature expected values are the mean/expected values
    for that model, shap and feature.
    """

    # Feature this result describes.
    feature_name: str
    # Per-observation curves; plotted one column per inference.
    ice_vals: np.ndarray
    # Mean curve across observations (the partial dependence plot values).
    pdp_vals: List[float]
    model_expected_value: List[float]
    shap_expected_value: List[float]
    feature_expected_value: List[float]
This class holds the PDP/ICE part of the results. PDP/ICE results are generated for each observation by holding all but one feature constant, varying that feature and analyzing that prediction. Thus the results are per inference per feature.
feature_name is the feature that this result is for. xs is the list of x values that the feature was varied through.
pdp_vals is the list of resulting values. model, shap and feature expected values are the mean/expected values for that model, shap and feature.
@dataclass
class WindowResult:
    """Holds the analysis outputs for one data window
    (reference, window or adhoc data)."""

    data: NDArray  # the original data used for the analysis
    shap_values: NDArray  # the calculated shap values
    base_values: NDArray  # The expected value (mean) for each prediction.
    pdp_results: List[PDPResult]  # Partial dependence plot data for each feature
WindowResult(data: numpy.ndarray, shap_values: numpy.ndarray, base_values: numpy.ndarray, pdp_results: List[wallaroo.explainability.PDPResult])
def find_pdp_result(window_result: WindowResult, feature_name: str) -> PDPResult:
    """Return the PDP result for ``feature_name`` from ``window_result``.

    Raises Exception when no entry in ``pdp_results`` carries that name.
    """
    match = next(
        (res for res in window_result.pdp_results if res.feature_name == feature_name),
        None,
    )
    if match is None:
        raise Exception(f"Did not find feature {feature_name} in pdp_results.")
    return match
Gets the pdp result object for the specified feature.
@dataclass
class ExplainabilityResult:
    """This class holds the explainability result created by processing an
    explainability request.

    id and status are optional and will be filled in by processing. The id
    will be the same as the request id since the results are stored in minio.

    num_inferences and num_batches are nice to know status information and
    could be brought into the status object in the future.

    reference and adhoc data are the actual inferences used in the analysis.

    reference and adhoc shap values are the shap values for each feature for
    each prediction.

    base_values are the expected value for each prediction. These values will
    all be the same so may be changed to a single float in the future.

    pdp results are a list of pdp/ice results for each feature.

    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: UUID
    num_inferences: int
    num_batches: int
    compute_time: float
    status: Dict[str, Any]
    feature_names: List[str]  # The names of the columns
    feature_bounds: Dict[str, FeatureBounds]
    reference_result: Optional[WindowResult]
    window_result: Optional[WindowResult]
    adhoc_result: Optional[WindowResult]

    # Fixed colors so reference (light blue) vs window (pink) data is
    # recognizable across all plots.
    reference_color = np.array([127.0, 196, 252]) / 255
    window_color = np.array([252, 127.0, 196]) / 255

    def _repr_html_(self):
        """Render all fields (values truncated) as an HTML table for Jupyter."""
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def feature_effects(self) -> pd.DataFrame:
        """Returns a dataframe summarizing the mean feature effects of the reference
        data as well as the feature effects for each adhoc inference."""

        df = pd.DataFrame(index=self.feature_names)

        if self.reference_result is not None:
            vals = np.array(self.reference_result.shap_values)
            if vals.shape[0] > 0:
                df[REF_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["ref_std_dev"] = vals.std(axis=0)
                df[REF_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.window_result is not None:
            vals = np.array(self.window_result.shap_values)
            if vals.shape[0] > 0:
                df[WINDOW_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["window_std_dev"] = vals.std(axis=0)
                df[WINDOW_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.adhoc_result is not None:
            vals = np.array(self.adhoc_result.shap_values)
            if vals.shape[0] > 0:
                # One column per adhoc inference: input_1, input_2, ...
                for idx in range(vals.shape[0]):
                    df[f"input_{idx+1}"] = vals[idx, :]

        # Most impactful features first when reference stats are present.
        if REF_MEAN_ABS in df.columns:
            return df.sort_values(by=REF_MEAN_ABS, ascending=False)
        return df

    def effect_summary(self) -> pd.DataFrame:
        """Returns a dataframe with the expected/mean values and the shap adjustments."""

        effects = self.feature_effects()
        base_value = unwrap(self.reference_result).base_values[0]
        data = {}
        data["base_value"] = [0, base_value]

        # we should rename inputs_X to 'house_X' or similar
        input_cols = [c for c in effects.columns if "input_" in c]
        effect_sums = effects[input_cols].sum(axis=0)
        for c, v in zip(input_cols, effect_sums):
            # row 0: the shap adjustment; row 1: adjustment + base value.
            data[c] = [v, v + base_value]
        return pd.DataFrame(data, index=["adjustment", "total"])

    def check_status(self) -> bool:
        """Ensure we've completed before trying to plot anything."""

        if self.status["status"] != "COMPLETED":
            raise Exception(f"Analysis has not (yet) completed: {self.status}")
        return True

    def plot_feature_effects(
        self,
        mode: str = "mean",
        top_n=0,
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
    ):
        """Creates a bar plot of the mean or max abs feature effects.

        mode 'mean'/'max' plot one bar per feature from the reference stats;
        'individual' plots grouped bars per remaining column. top_n > 0
        limits the plot to the first top_n features.
        """

        if mode not in ["mean", "max", "individual"]:
            raise Exception("Mode must be one of: 'mean', 'max', 'individual'")

        self.check_status()

        df = self.feature_effects()
        if top_n > 0:
            df = df.head(top_n)

        cols = df.columns
        if not plot_reference:
            cols = [c for c in cols if not c.startswith("ref_")]
        if not plot_window:
            cols = [c for c in cols if not c.startswith("window_")]
        if not plot_adhoc:
            # adhoc columns are the input_N ones: everything that is neither
            # a ref_ nor a window_ column.
            cols = [c for c in cols if c.startswith("ref_") or c.startswith("window_")]

        df = df[cols]

        if REF_MEAN_ABS in df.columns and mode != "individual":
            if mode == "max":
                title = "Max Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MAX_ABS])
            else:
                title = "Mean Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MEAN_ABS])
        else:
            title = "Feature Effects"
            ax = plt.subplot(1, 1, 1)
            # std-dev columns are excluded from the grouped bar plot.
            _ = df.drop(
                ["ref_std_dev", "window_std_dev"], axis=1, errors="ignore"
            ).plot(kind="bar", ax=ax)
            plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

        plt.grid()
        plt.xticks(rotation=90)
        plt.title(title)

    def plot_ice_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a combination ICE plot for the adhoc data if any
        in custom colors and the reference data if any in translucent
        blue."""

        self.check_status()

        # Styling for the per-observation (ICE) curves.
        ice_alpha = 0.2
        ice_width = 5

        # x values the feature was varied through come from feature_bounds.
        xs = self.feature_bounds[feature_name].xs

        # pdp_result = unwrap(self.pdp_result(feature_name))
        # xs = pdp_result.xs
        # mean_vals = np.array(pdp_result.pdp_vals)
        # ice_vals = np.array(pdp_result.ice_vals)

        plt.title(f"ICE for {feature_name}")
        plt.xlabel(feature_name)
        plt.ylabel("Prediction")

        if plot_reference and self.reference_result is not None:
            pdp_result = find_pdp_result(self.reference_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.reference_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            # Solid black line: mean (PDP) of the reference curves.
            _ = plt.plot(
                xs, mean_vals, color="black", zorder=10, label="Reference Mean"
            )

        if plot_window and self.window_result is not None:
            pdp_result = find_pdp_result(self.window_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.window_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            # Dashed black line: mean (PDP) of the window curves.
            _ = plt.plot(
                xs, mean_vals, linestyle="dashed", color="black", label="Window Mean"
            )

        if plot_adhoc and self.adhoc_result is not None:
            pdp_result = find_pdp_result(self.adhoc_result, feature_name)
            ice_vals = np.array(pdp_result.ice_vals)
            mean_vals = pdp_result.pdp_vals
            # One colored curve (and legend entry) per adhoc input.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.plot(
                    xs,
                    ice_vals[:, idx : idx + 1],
                    linewidth=3,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )
            plt.legend()

        plt.ylim(0)
        _ = plt.grid()
        plt.show()

    def plot_all_features(
        self,
        title="Feature Effect per Inference",
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
        top_n: int = 0,
    ):
        """Creates a 'bee swarm' plot of all/each feature effect."""
        self.check_status()
        # Fixed seed so the vertical jitter of the strips is reproducible.
        np.random.seed(42)

        # Our custom palette will go from black to a weird blue green.
        # Not sure of the best palette to use. We need lighter to be
        # higher values. Started with a red to green ramp but we also
        # need to be aware of color blindness issues.
        rgb = [(0.10, 0.10, 0.10), (0.0, 0.90, 0.90)]

        # Gather the data we'll need
        shap_values_list = []
        feature_values_list = []
        if plot_reference and self.reference_result:
            shap_values_list.append(self.reference_result.shap_values)
            feature_values_list.append(self.reference_result.data)
        if plot_window and self.window_result:
            shap_values_list.append(self.window_result.shap_values)
            feature_values_list.append(self.window_result.data)
        if plot_adhoc and self.adhoc_result:
            shap_values_list.append(self.adhoc_result.shap_values)
            feature_values_list.append(self.adhoc_result.data)

        if not shap_values_list:
            raise Exception("Some data must be specified to create a plot.")

        shap_values = np.vstack(shap_values_list)
        feature_values = np.vstack(feature_values_list)

        # create a df of the shap values / contributions
        df = pd.DataFrame(shap_values)
        df.columns = self.feature_names
        # create a df of the original feature values
        feature_df = pd.DataFrame(feature_values)
        feature_df.columns = self.feature_names

        # We plot in strips from the bottom so put the most important at the top.
        # first taking the top_n if specified.
        feature_names = list(self.feature_effects().index)
        if top_n > 0:
            feature_names = feature_names[:top_n]
        feature_names = list(reversed(feature_names))

        num_features = len(feature_names)
        num_obs = df.shape[0]

        fig = plt.figure()
        fig.patch.set_facecolor("white")
        ax = fig.get_axes()

        for i, col in enumerate(feature_names):
            # create a temp df where the y is the level + jitter
            dfc = pd.DataFrame(df[col])
            dfc["y"] = i + 1 + norm.rvs(loc=0, scale=0.1, size=num_obs)
            # vals is the original feature values and we create a custom palette
            dfc["vals"] = feature_df[col]
            unique_vals = dfc["vals"].unique()
            n_colors = len(unique_vals)
            palette = sns.blend_palette(rgb, n_colors=n_colors)

            # plot a scatter plot strip
            ax = sns.scatterplot(
                x=col,
                y="y",
                data=dfc,
                alpha=0.75,
                hue="vals",
                palette=palette,
                legend=None,
            )

        # change the tick labels from strip number to feature name
        ax.set_yticks(range(num_features + 1))
        ticks = [""]
        ticks.extend(feature_names)
        ax.set_yticklabels(ticks)

        plt.xlabel("Shap Value")
        plt.ylabel("")
        plt.title(title)
        plt.ylim(0, num_features + 1)
        plt.grid()
        plt.show()

    def plot_feature_values_vs_shap_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a scatter plot of the feature vs shap values.
        adhoc data if any is in custom colors. reference data in translucent
        blue."""

        self.check_status()

        alpha = 0.5

        # Column index of the requested feature.
        i = self.feature_names.index(feature_name)

        plt.title(f"Shap Values for {feature_name} for Each Input")
        plt.xlabel(feature_name)
        plt.ylabel("Shap Values")
        plt.grid()

        if plot_reference and self.reference_result:
            _ = plt.scatter(
                self.reference_result.data[:, i],
                self.reference_result.shap_values[:, i],
                s=100,
                alpha=alpha,
                color=self.reference_color,
                label="Reference",
            )

        if plot_window and self.window_result:
            _ = plt.scatter(
                self.window_result.data[:, i],
                self.window_result.shap_values[:, i],
                marker="p",
                s=100,
                alpha=alpha,
                color=self.window_color,
                label="Window",
            )

        if plot_adhoc and self.adhoc_result:
            # Larger diamond markers, one legend entry per adhoc input.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.scatter(
                    self.adhoc_result.data[idx, i],
                    self.adhoc_result.shap_values[idx, i],
                    marker="d",  # type: ignore
                    s=200,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )

        # expected_shap_values = (
        #     reference_shap_values
        #     if self.reference_result.shap_values.shape[0] > 0
        #     else adhoc_shap_values
        # )
        # expected_x = reference_x if reference_x.shape[0] > 0 else adhoc_x

        # mean_feature_value = expected_x[:, i].mean()
        # min_y = expected_shap_values[:, i].min()
        # max_y = expected_shap_values[:, i].max()
        # plt.vlines(
        #     mean_feature_value,
        #     min_y,
        #     max_y,
        #     colors="gray",
        #     linestyle="dotted",
        #     label=f"Mean {feature_name}",
        # )
        plt.legend()
        plt.show()
This class holds the explainability result created by processing an explainability request.
id and status are optional and will be filled in by processing. The id will be the same as the request id since the results are stored in minio.
num_inferences and num_batches are nice to know status information and could be brought into the status object in the future.
reference and adhoc data are the actual inferences used in the analysis.
reference and adhoc shap values are the shap values for each feature for each prediction.
base_values are the expected value for each prediction. These values will all be the same so may be changed to a single float in the future.
pdp results are a list of pdp/ice results for each feature.
421 def feature_effects(self) -> pd.DataFrame: 422 """Returns a dataframe summarizing the mean feature effects of the reference 423 data as well as the feature effects for each adhoc inference.""" 424 425 df = pd.DataFrame(index=self.feature_names) 426 427 if self.reference_result is not None: 428 vals = np.array(self.reference_result.shap_values) 429 if vals.shape[0] > 0: 430 df[REF_MEAN_ABS] = np.abs(vals).mean(axis=0) 431 df["ref_std_dev"] = vals.std(axis=0) 432 df[REF_MAX_ABS] = np.abs(vals).max(axis=0) 433 434 if self.window_result is not None: 435 vals = np.array(self.window_result.shap_values) 436 if vals.shape[0] > 0: 437 df[WINDOW_MEAN_ABS] = np.abs(vals).mean(axis=0) 438 df["window_std_dev"] = vals.std(axis=0) 439 df[WINDOW_MAX_ABS] = np.abs(vals).max(axis=0) 440 441 if self.adhoc_result is not None: 442 vals = np.array(self.adhoc_result.shap_values) 443 if vals.shape[0] > 0: 444 for idx in range(vals.shape[0]): 445 df[f"input_{idx+1}"] = vals[idx, :] 446 447 if REF_MEAN_ABS in df.columns: 448 return df.sort_values(by=REF_MEAN_ABS, ascending=False) 449 return df
Returns a dataframe summarizing the mean feature effects of the reference data as well as the feature effects for each adhoc inference.
451 def effect_summary(self) -> pd.DataFrame: 452 """Returns a dataframe with the expected/mean values and the shap adjustments.""" 453 454 effects = self.feature_effects() 455 base_value = unwrap(self.reference_result).base_values[0] 456 data = {} 457 data["base_value"] = [0, base_value] 458 459 # we should rename inputs_X to 'house_X' or similar 460 input_cols = [c for c in effects.columns if "input_" in c] 461 effect_sums = effects[input_cols].sum(axis=0) 462 for c, v in zip(input_cols, effect_sums): 463 data[c] = [v, v + base_value] 464 return pd.DataFrame(data, index=["adjustment", "total"])
Returns a dataframe with the expected/mean values and the shap adjustments.
466 def check_status(self) -> bool: 467 """Ensure we've completed before trying to plot anything.""" 468 469 if self.status["status"] != "COMPLETED": 470 raise Exception(f"Analysis has not (yet) completed: {self.status}") 471 return True
Ensure we've completed before trying to plot anything.
    def plot_feature_effects(
        self,
        mode: str = "mean",
        top_n=0,
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
    ):
        """Creates a bar plot of the mean or max abs feature effects.

        mode: 'mean' or 'max' plot one bar per feature from the reference
            stats; 'individual' plots grouped bars per remaining column.
        top_n: if > 0, plot only the first top_n features from
            feature_effects() (sorted most-impactful first when reference
            stats exist).
        plot_reference / plot_window / plot_adhoc: include or drop the
            corresponding columns.
        """

        if mode not in ["mean", "max", "individual"]:
            raise Exception("Mode must be one of: 'mean', 'max', 'individual'")

        self.check_status()

        df = self.feature_effects()
        if top_n > 0:
            df = df.head(top_n)

        cols = df.columns
        if not plot_reference:
            cols = [c for c in cols if not c.startswith("ref_")]
        if not plot_window:
            cols = [c for c in cols if not c.startswith("window_")]
        if not plot_adhoc:
            # adhoc columns are the input_N ones: everything that is neither
            # a ref_ nor a window_ column.
            cols = [c for c in cols if c.startswith("ref_") or c.startswith("window_")]

        df = df[cols]

        if REF_MEAN_ABS in df.columns and mode != "individual":
            if mode == "max":
                title = "Max Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MAX_ABS])
            else:
                title = "Mean Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MEAN_ABS])
        else:
            title = "Feature Effects"
            ax = plt.subplot(1, 1, 1)
            # std-dev columns are excluded from the grouped bar plot.
            _ = df.drop(
                ["ref_std_dev", "window_std_dev"], axis=1, errors="ignore"
            ).plot(kind="bar", ax=ax)
            plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

        plt.grid()
        plt.xticks(rotation=90)
        plt.title(title)
Creates a bar plot of the mean or max abs feature effects.
    def plot_ice_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a combination ICE plot for the adhoc data if any
        in custom colors and the reference data if any in translucent
        blue.

        feature_name: which feature to vary along the x axis; the x values
        come from this result's feature_bounds.
        """

        self.check_status()

        # Styling for the per-observation (ICE) curves.
        ice_alpha = 0.2
        ice_width = 5

        xs = self.feature_bounds[feature_name].xs

        # pdp_result = unwrap(self.pdp_result(feature_name))
        # xs = pdp_result.xs
        # mean_vals = np.array(pdp_result.pdp_vals)
        # ice_vals = np.array(pdp_result.ice_vals)

        plt.title(f"ICE for {feature_name}")
        plt.xlabel(feature_name)
        plt.ylabel("Prediction")

        if plot_reference and self.reference_result is not None:
            pdp_result = find_pdp_result(self.reference_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.reference_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            # Solid black line: mean (PDP) of the reference curves.
            _ = plt.plot(
                xs, mean_vals, color="black", zorder=10, label="Reference Mean"
            )

        if plot_window and self.window_result is not None:
            pdp_result = find_pdp_result(self.window_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.window_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            # Dashed black line: mean (PDP) of the window curves.
            _ = plt.plot(
                xs, mean_vals, linestyle="dashed", color="black", label="Window Mean"
            )

        if plot_adhoc and self.adhoc_result is not None:
            pdp_result = find_pdp_result(self.adhoc_result, feature_name)
            ice_vals = np.array(pdp_result.ice_vals)
            mean_vals = pdp_result.pdp_vals
            # One colored curve (and legend entry) per adhoc input.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.plot(
                    xs,
                    ice_vals[:, idx : idx + 1],
                    linewidth=3,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )
            plt.legend()

        plt.ylim(0)
        _ = plt.grid()
        plt.show()
Creates a combination ICE plot for the adhoc data if any in custom colors and the reference data if any in translucent blue.
    def plot_all_features(
        self,
        title="Feature Effect per Inference",
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
        top_n: int = 0,
    ):
        """Creates a 'bee swarm' plot of all/each feature effect.

        Each feature becomes a horizontal strip of jittered points (one per
        inference): x is the shap value, color encodes the original feature
        value. top_n > 0 limits the plot to the most impactful features.
        """
        self.check_status()
        # Fixed seed so the vertical jitter of the strips is reproducible.
        np.random.seed(42)

        # Our custom palette will go from black to a weird blue green.
        # Not sure of the best palette to use. We need lighter to be
        # higher values. Started with a red to green ramp but we also
        # need to be aware of color blindness issues.
        rgb = [(0.10, 0.10, 0.10), (0.0, 0.90, 0.90)]

        # Gather the data we'll need
        shap_values_list = []
        feature_values_list = []
        if plot_reference and self.reference_result:
            shap_values_list.append(self.reference_result.shap_values)
            feature_values_list.append(self.reference_result.data)
        if plot_window and self.window_result:
            shap_values_list.append(self.window_result.shap_values)
            feature_values_list.append(self.window_result.data)
        if plot_adhoc and self.adhoc_result:
            shap_values_list.append(self.adhoc_result.shap_values)
            feature_values_list.append(self.adhoc_result.data)

        if not shap_values_list:
            raise Exception("Some data must be specified to create a plot.")

        shap_values = np.vstack(shap_values_list)
        feature_values = np.vstack(feature_values_list)

        # create a df of the shap values / contributions
        df = pd.DataFrame(shap_values)
        df.columns = self.feature_names
        # create a df of the original feature values
        feature_df = pd.DataFrame(feature_values)
        feature_df.columns = self.feature_names

        # We plot in strips from the bottom so put the most important at the top.
        # first taking the top_n if specified.
        feature_names = list(self.feature_effects().index)
        if top_n > 0:
            feature_names = feature_names[:top_n]
        feature_names = list(reversed(feature_names))

        num_features = len(feature_names)
        num_obs = df.shape[0]

        fig = plt.figure()
        fig.patch.set_facecolor("white")
        ax = fig.get_axes()

        for i, col in enumerate(feature_names):
            # create a temp df where the y is the level + jitter
            dfc = pd.DataFrame(df[col])
            dfc["y"] = i + 1 + norm.rvs(loc=0, scale=0.1, size=num_obs)
            # vals is the original feature values and we create a custom palette
            dfc["vals"] = feature_df[col]
            unique_vals = dfc["vals"].unique()
            n_colors = len(unique_vals)
            palette = sns.blend_palette(rgb, n_colors=n_colors)

            # plot a scatter plot strip
            ax = sns.scatterplot(
                x=col,
                y="y",
                data=dfc,
                alpha=0.75,
                hue="vals",
                palette=palette,
                legend=None,
            )

        # change the tick labels from strip number to feature name
        ax.set_yticks(range(num_features + 1))
        ticks = [""]
        ticks.extend(feature_names)
        ax.set_yticklabels(ticks)

        plt.xlabel("Shap Value")
        plt.ylabel("")
        plt.title(title)
        plt.ylim(0, num_features + 1)
        plt.grid()
        plt.show()
Creates a 'bee swarm' plot of all/each feature effect.
    def plot_feature_values_vs_shap_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a scatter plot of the feature vs shap values.
        adhoc data if any is in custom colors. reference data in translucent
        blue.

        feature_name: which feature's column to plot on the x axis.
        """

        self.check_status()

        alpha = 0.5

        # Column index of the requested feature.
        i = self.feature_names.index(feature_name)

        plt.title(f"Shap Values for {feature_name} for Each Input")
        plt.xlabel(feature_name)
        plt.ylabel("Shap Values")
        plt.grid()

        if plot_reference and self.reference_result:
            _ = plt.scatter(
                self.reference_result.data[:, i],
                self.reference_result.shap_values[:, i],
                s=100,
                alpha=alpha,
                color=self.reference_color,
                label="Reference",
            )

        if plot_window and self.window_result:
            _ = plt.scatter(
                self.window_result.data[:, i],
                self.window_result.shap_values[:, i],
                marker="p",
                s=100,
                alpha=alpha,
                color=self.window_color,
                label="Window",
            )

        if plot_adhoc and self.adhoc_result:
            # Larger diamond markers, one legend entry per adhoc input.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.scatter(
                    self.adhoc_result.data[idx, i],
                    self.adhoc_result.shap_values[idx, i],
                    marker="d",  # type: ignore
                    s=200,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )

        # expected_shap_values = (
        #     reference_shap_values
        #     if self.reference_result.shap_values.shape[0] > 0
        #     else adhoc_shap_values
        # )
        # expected_x = reference_x if reference_x.shape[0] > 0 else adhoc_x

        # mean_feature_value = expected_x[:, i].mean()
        # min_y = expected_shap_values[:, i].min()
        # max_y = expected_shap_values[:, i].max()
        # plt.vlines(
        #     mean_feature_value,
        #     min_y,
        #     max_y,
        #     colors="gray",
        #     linestyle="dotted",
        #     label=f"Mean {feature_name}",
        # )
        plt.legend()
        plt.show()
Creates a scatter plot of the feature vs shap values. adhoc data if any is in custom colors. reference data in translucent blue.
def build_window_result(data) -> Optional[WindowResult]:
    """Deserialize one window's JSON dict into a WindowResult.

    A missing window (None) passes through unchanged.
    """
    if data is None:
        return None
    pdp_results = [PDPResult(**entry) for entry in data["pdp_results"]]
    return WindowResult(
        data=np.array(data["data"]),
        shap_values=np.array(data["shap_values"]),
        base_values=np.array(data["base_values"]),
        pdp_results=pdp_results,
    )
def build_explainability_result(data):
    """Convenience function that parses the server's JSON payload into a
    fully populated ExplainabilityResult."""
    feature_bounds = {
        k: FeatureBounds(**v) for k, v in data["feature_bounds"].items()
    }
    return ExplainabilityResult(
        id=data["id"],
        workspace_id=data["workspace_id"],
        explainability_config_id=data["explainability_config_id"],
        num_inferences=data["num_inferences"],
        num_batches=data["num_batches"],
        compute_time=data["compute_time"],
        status=data["status"],
        feature_names=data["feature_names"],
        feature_bounds=feature_bounds,
        reference_result=build_window_result(data["reference_result"]),
        window_result=build_window_result(data["window_result"]),
        adhoc_result=build_window_result(data["adhoc_result"]),
    )
Convenience function to parse JSON into the full result object we want.