wallaroo.explainability

  1from dataclasses import asdict, dataclass
  2from datetime import datetime, timezone
  3from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
  4from uuid import UUID
  5
  6import matplotlib.pyplot as plt
  7import numpy as np
  8import pandas as pd
  9import seaborn as sns
 10from scipy.stats import norm
 11
 12from wallaroo.unwrap import unwrap
 13
if TYPE_CHECKING:
    # Imports that happen below in methods to fix circular import dependency
    # issues need to also be specified here to satisfy mypy type checking.
    pass

# Readability alias used by the dataclasses below (e.g. WindowResult).
NDArray = np.ndarray

# Column names shared between feature_effects() and the plotting helpers so
# the summary dataframe and the plots always agree on spelling.
REF_MEAN_ABS = "ref_mean_abs"
REF_MAX_ABS = "ref_max_abs"
WINDOW_MEAN_ABS = "window_mean_abs"
WINDOW_MAX_ABS = "window_max_abs"


# To keep colors consistent across plots explicitly list the ones we use.
# Adhoc input N is always drawn with colors[N % len(colors)].
colors = [
    "red",
    "green",
    "blue",
    "gold",
    "fuchsia",
    "tan",
    "yellowgreen",
    "darkkhaki",
    "brown",
    "azure",
    "darkorchid",
    "darkslategrey",
    "moccasin",
    "sandybrown",
    "aquamarine",
    "springgreen",
    "sienna",
    "cyan",
    "mediumpurple",
    "midnightblue",
    "plum",
    "deeppink",
    "crimson",
    "lightgoldenrodyellow",
    "lightgreen",
    "lawngreen",
    "bisque",
    "steelblue",
    "cornflowerblue",
    "lightblue",
    "darksalmon",
    "lightslategrey",
]
 62
 63
 64def truncate(v: Any, num_char: int = 256) -> str:
 65    s = str(v)
 66    if len(s) < num_char:
 67        return s
 68    else:
 69        postfix = " ..."
 70        return f"{s[:num_char - len(postfix)]}{postfix}"
 71
 72
@dataclass
class WindowRequestInput:
    """Bounds for sampling inference logs in an explainability request."""

    start: Optional[str]  # These are strings in our request but
    end: Optional[str]  # should be interpreted as RFC 3339 formatted datetimes
    num_samples: Optional[int]  # how many inferences to sample from the window
 78
 79
 80class ExplainabilityConfigList(list):
 81    def _repr_html_(self):
 82        rows = [
 83            f"""
 84        <tr">
 85            <td>{exp_config.id}</td>
 86            <td>{exp_config.status}</td>
 87        </tr>
 88        """
 89            for exp_config in self
 90        ]
 91        table = """<table>
 92            <tr>
 93                <th>Id</th>
 94                <th>Status</th>
 95            </tr>
 96            {0}
 97        </table>""".format(
 98            "\n".join(rows)
 99        )
100
101        return table
102
103
class ExplainabilityRequestList(list):
    """Adds a _repr_html_ to a list of explainability requests."""

    def _repr_html_(self):
        """Assuming all items are explainability requests, generates an HTML
        table for jupyter."""
        # Bug fix: the row tag was emitted as `<tr">` (stray quote),
        # producing malformed HTML.
        rows = [
            f"""
        <tr>
            <td>{exp_req.id}</td>
            <td>{exp_req.status}</td>
            <td>{exp_req.reference_config}</td>
            <td>{exp_req.window_config}</td>
            <td>{exp_req.use_adhoc_data}</td>
        </tr>
        """
            for exp_req in self
        ]
        table = """<table>
            <tr>
                <th>Id</th>
                <th>Status</th>
                <th>Reference Data</th>
                <th>Window Data</th>
                <th>Adhoc</th>
            </tr>
            {0}
        </table>""".format(
            "\n".join(rows)
        )

        return table
136
137
@dataclass
class FeatureBounds:
    """Range over which one feature is varied for the PDP/ICE analysis."""

    min: float  # lower bound for the feature
    max: float  # upper bound for the feature
    xs: List[float]  # the sampled x values between min and max
143
144
@dataclass
class ExplainabilityConfig:
    """This class specifies an explainability configuration that can be used
    to later submit explainability requests which cause the server to do the
    analysis and create explainability results.

    ExplainabilityConfig are necessary to ensure the explainability pipeline
    is created and is deployed and so that various requests are processed in
    the same manner and can be compared.

    id, status, feature_bounds and reference_pipeline_version are optional
    and will be filled out when processed and saved to the database.

    workspace id must match the users/pipelines workspace and
    reference_pipeline_version must refer to a valid pipeline version that the
    user has access to.

    num_points specifies how many samples to take when varying the values of a
    feature for the PDP/ICE analysis through the feature_bounds.

    feature_names are a convenience for the user. output_names is not currently used.

    NOTE(review): methods assume a `client` attribute has been attached
    externally before use — confirm against the client layer.
    """

    id: Optional[UUID]
    workspace_id: int
    status: Optional[Dict[str, Any]]
    reference_pipeline_version: str
    explainability_pipeline_version: Optional[str]
    num_points: int = 10
    feature_names: Optional[List[str]] = None
    feature_bounds: Optional[Dict[str, FeatureBounds]] = None
    output_names: Optional[List[str]] = None

    def _repr_html_(self):
        """Render every field (value truncated) as an HTML table for jupyter."""
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def list_explainability_requests(self) -> List["ExplainabilityRequest"]:
        """List the explainability requests we've created."""

        client = self.client  # type: ignore
        result = client._post_rest_api_json(
            "v1/api/explainability/list_requests",
            {"explainability_config_id": self.id},
        )
        erl = [ExplainabilityRequest(**ec) for ec in result]
        for er in erl:
            # Attach the client so each request can make follow-up calls.
            er.client = client  # type: ignore
        return ExplainabilityRequestList(erl)

    def get_explainability_request(
        self, expr: Union[str, "ExplainabilityConfig"]
    ) -> Optional["ExplainabilityRequest"]:
        """Get the full explainability request whether completed or not.

        `expr` may be a request id string or an object with an `id` field.
        """

        if isinstance(expr, str):
            explain_id = expr
        else:
            explain_id = str(expr.id)

        client = self.client  # type: ignore
        result = client._post_rest_api_json(
            "v1/api/explainability/get_request",
            {"explainability_request_id": explain_id},
        )

        exp_cfg = ExplainabilityRequest(**result)
        exp_cfg.client = client  # type: ignore
        return exp_cfg

    def submit_explainability_request(
        self,
        reference_start: Optional[datetime] = None,
        reference_end: Optional[datetime] = None,
        reference_num_samples: Optional[int] = None,
        use_reference: Optional[bool] = True,
        window_start: Optional[datetime] = None,
        window_end: Optional[datetime] = None,
        window_num_samples: Optional[int] = None,
        adhoc_data: Optional[Union[List[List[float]], np.ndarray, pd.DataFrame]] = None,
    ):
        """Submit an analysis on reference or adhoc data using a particular config.

        At least one of a reference config (use_reference, the default), a
        window config (any of the window_* arguments) or adhoc_data must be
        supplied, otherwise an Exception is raised.
        """

        reference_config = None
        if use_reference:
            reference_config = WindowRequestInput(
                start=maybe_format_date(reference_start),
                end=maybe_format_date(reference_end),
                num_samples=reference_num_samples,
            )

        window_config = None
        if window_start or window_end or window_num_samples:
            window_config = WindowRequestInput(
                start=maybe_format_date(window_start),
                end=maybe_format_date(window_end),
                num_samples=window_num_samples,
            )

        use_adhoc_data = adhoc_data is not None

        # Normalize the adhoc data to a plain list-of-lists for JSON transport.
        adhoc_data_list = None
        if use_adhoc_data:
            if isinstance(adhoc_data, list):
                adhoc_data_list = adhoc_data
            elif isinstance(adhoc_data, np.ndarray):
                adhoc_data_list = adhoc_data.tolist()
            elif isinstance(adhoc_data, pd.DataFrame):
                adhoc_data_list = adhoc_data.values.tolist()
            else:
                raise Exception(f"Unknown adhoc data type {type(adhoc_data)}")

        if reference_config is None and window_config is None and not use_adhoc_data:
            raise Exception(
                "You must specify a reference config, a window config or adhoc data"
            )

        client = self.client  # type: ignore
        exp_config_id = self.id
        workspace_id = client.get_current_workspace().id()

        ear = ExplainabilityRequest(
            id=None,
            explainability_config_id=exp_config_id,
            workspace_id=workspace_id,
            reference_config=reference_config,
            window_config=window_config,
            use_adhoc_data=use_adhoc_data,
            adhoc_data=adhoc_data_list,
        )

        result = client._post_rest_api_json(
            "v1/api/explainability/create_request", asdict(ear)
        )
        # Re-fetch so the returned request carries server-assigned fields.
        expr_id = result["id"]
        return self.get_explainability_request(expr_id)
286
287
def maybe_format_date(d: Optional[datetime]) -> Optional[str]:
    """Convert a datetime to a UTC ISO-8601 string, passing None through."""
    return d.astimezone(tz=timezone.utc).isoformat() if d else None
292
293
@dataclass
class ExplainabilityRequest:
    """This class outlines what should be submitted to start the explainability
    analysis with a particular config.

    The request can be to analyze reference data, historical data from the ref
    pipeline, or new adhoc data submitted with the request or both.

    id and status are optional and are filled in by the processing steps.

    If the request has use_reference_data = True, num_sample inference logs are
    sampled from between the start and end dates or the entire (last 100_000)
    inferences.
    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: Optional[UUID] = None
    status: Optional[Dict[str, Any]] = None
    reference_config: Optional[WindowRequestInput] = None
    window_config: Optional[WindowRequestInput] = None
    use_adhoc_data: bool = False
    adhoc_data: Optional[List[List[float]]] = None

    def _repr_html_(self):
        """Render every field as an HTML table for jupyter."""
        fields = [f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in asdict(self).items()]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def get_explainability_result(self) -> Optional["ExplainabilityResult"]:
        """Get the full explainability result whether completed or not.

        NOTE(review): assumes `self.client` was attached externally and that
        the result id equals the request id — confirm against the server API.
        """

        client = self.client  # type: ignore

        result = client._post_rest_api_json(
            "v1/api/explainability/get_result",
            {"explainability_result_id": self.id},
        )
        return build_explainability_result(result)
332
333
@dataclass
class PDPResult:
    """This class holds the PDP/ICE part of the results.
    PDP/ICE results are generated for each observation by holding
    all but one feature constant, varying that feature and analyzing
    that prediction. Thus the results are per inference per feature.

    feature_name is the feature that this result is for.
    xs is the list of x values that the feature was varied through.

    pdp_vals is the list of resulting values.
    model, shap and feature expected values are the mean/expected values
    for that model, shap and feature.
    """

    feature_name: str  # feature this result describes
    ice_vals: np.ndarray  # per-inference curves; indexed [x, inference] (see plot_ice_values)
    pdp_vals: List[float]  # mean prediction at each x — the PDP curve
    model_expected_value: List[float]  # expected (mean) model output
    shap_expected_value: List[float]  # expected (mean) shap value
    feature_expected_value: List[float]  # expected (mean) feature value
355
356
@dataclass
class WindowResult:
    """Shap and PDP/ICE analysis output for one window of inference data."""

    data: NDArray  # the original data used for the analysis
    shap_values: NDArray  # the calculated shap values
    base_values: NDArray  # The expected value (mean) for each prediction.
    pdp_results: List[PDPResult]  # Partial dependence plot data for each feature
363
364
def find_pdp_result(window_result: WindowResult, feature_name: str) -> PDPResult:
    """Gets the pdp result object for the specified feature."""

    match = next(
        (r for r in window_result.pdp_results if r.feature_name == feature_name),
        None,
    )
    if match is None:
        raise Exception(f"Did not find feature {feature_name} in pdp_results.")
    return match
372
373
@dataclass
class ExplainabilityResult:
    """This class holds the explainability result created by processing an
    explainability request.

    id and status are optional and will be filled in by processing. The id
    will be the same as the request id since the results are stored in minio.

    num_inferences and num_batches are nice to know status information and
    could be brought into the status object in the future.

    reference and adhoc data are the actual inferences used in the analysis.

    reference and adhoc shap values are the shap values for each feature for
    each prediction.

    base_values are the expected value for each prediction. These values will
    all be the same so may be changed to a single float in the future.

    pdp results are a list of pdp/ice results for each feature.

    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: UUID
    num_inferences: int
    num_batches: int
    compute_time: float
    status: Dict[str, Any]
    feature_names: List[str]  # The names of the columns
    feature_bounds: Dict[str, FeatureBounds]
    reference_result: Optional[WindowResult]
    window_result: Optional[WindowResult]
    adhoc_result: Optional[WindowResult]

    # Class-level plot colors shared by all instances: RGB triples
    # normalized to 0..1 for matplotlib.
    reference_color = np.array([127.0, 196, 252]) / 255
    window_color = np.array([252, 127.0, 196]) / 255

    def _repr_html_(self):
        """Render every field (value truncated) as an HTML table for jupyter."""
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def feature_effects(self) -> pd.DataFrame:
        """Returns a dataframe summarizing the mean feature effects of the reference
        data as well as the feature effects for each adhoc inference.

        Rows are indexed by feature name. Columns are added only for the
        result sets that are present and non-empty.
        """

        df = pd.DataFrame(index=self.feature_names)

        if self.reference_result is not None:
            vals = np.array(self.reference_result.shap_values)
            if vals.shape[0] > 0:
                df[REF_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["ref_std_dev"] = vals.std(axis=0)
                df[REF_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.window_result is not None:
            vals = np.array(self.window_result.shap_values)
            if vals.shape[0] > 0:
                df[WINDOW_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["window_std_dev"] = vals.std(axis=0)
                df[WINDOW_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.adhoc_result is not None:
            vals = np.array(self.adhoc_result.shap_values)
            if vals.shape[0] > 0:
                # One column per adhoc inference, 1-based to match plot labels.
                for idx in range(vals.shape[0]):
                    df[f"input_{idx+1}"] = vals[idx, :]

        # Most important (by reference mean) first, when reference data exists.
        if REF_MEAN_ABS in df.columns:
            return df.sort_values(by=REF_MEAN_ABS, ascending=False)
        return df

    def effect_summary(self) -> pd.DataFrame:
        """Returns a dataframe with the expected/mean values and the shap adjustments.

        Requires a reference_result (for the base value); one column per
        adhoc input plus the base value, rows are 'adjustment' and 'total'.
        """

        effects = self.feature_effects()
        base_value = unwrap(self.reference_result).base_values[0]
        data = {}
        data["base_value"] = [0, base_value]

        # we should rename inputs_X to 'house_X' or similar
        input_cols = [c for c in effects.columns if "input_" in c]
        effect_sums = effects[input_cols].sum(axis=0)
        for c, v in zip(input_cols, effect_sums):
            data[c] = [v, v + base_value]
        return pd.DataFrame(data, index=["adjustment", "total"])

    def check_status(self) -> bool:
        """Ensure we've completed before trying to plot anything."""

        if self.status["status"] != "COMPLETED":
            raise Exception(f"Analysis has not (yet) completed: {self.status}")
        return True

    def plot_feature_effects(
        self,
        mode: str = "mean",
        top_n=0,
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
    ):
        """Creates a bar plot of the mean or max abs feature effects.

        mode is one of 'mean', 'max' or 'individual'; top_n limits the plot
        to the n most important features; the plot_* flags drop the
        corresponding columns from the plot.
        """

        if mode not in ["mean", "max", "individual"]:
            raise Exception("Mode must be one of: 'mean', 'max', 'individual'")

        self.check_status()

        df = self.feature_effects()
        if top_n > 0:
            df = df.head(top_n)

        # Filter columns by prefix; adhoc columns are the ones that are
        # neither ref_* nor window_*.
        cols = df.columns
        if not plot_reference:
            cols = [c for c in cols if not c.startswith("ref_")]
        if not plot_window:
            cols = [c for c in cols if not c.startswith("window_")]
        if not plot_adhoc:
            cols = [c for c in cols if c.startswith("ref_") or c.startswith("window_")]

        df = df[cols]

        if REF_MEAN_ABS in df.columns and mode != "individual":
            # Single aggregate bar per feature from the reference data.
            if mode == "max":
                title = "Max Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MAX_ABS])
            else:
                title = "Mean Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MEAN_ABS])
        else:
            # Grouped bars, one group per feature; std-dev columns are noise here.
            title = "Feature Effects"
            ax = plt.subplot(1, 1, 1)
            _ = df.drop(
                ["ref_std_dev", "window_std_dev"], axis=1, errors="ignore"
            ).plot(kind="bar", ax=ax)
            plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

        plt.grid()
        plt.xticks(rotation=90)
        plt.title(title)

    def plot_ice_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a combination ICE plot for the adhoc data if any
        in custom colors and the reference data if any in translucent
        blue."""

        self.check_status()

        ice_alpha = 0.2
        ice_width = 5

        # x grid the feature was varied through; shared by all result sets.
        xs = self.feature_bounds[feature_name].xs

        # pdp_result = unwrap(self.pdp_result(feature_name))
        # xs = pdp_result.xs
        # mean_vals = np.array(pdp_result.pdp_vals)
        # ice_vals = np.array(pdp_result.ice_vals)

        plt.title(f"ICE for {feature_name}")
        plt.xlabel(feature_name)
        plt.ylabel("Prediction")

        if plot_reference and self.reference_result is not None:
            pdp_result = find_pdp_result(self.reference_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            # Translucent curves for each inference plus a solid mean on top.
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.reference_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            _ = plt.plot(
                xs, mean_vals, color="black", zorder=10, label="Reference Mean"
            )

        if plot_window and self.window_result is not None:
            pdp_result = find_pdp_result(self.window_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.window_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            # Dashed so it is distinguishable from the reference mean.
            _ = plt.plot(
                xs, mean_vals, linestyle="dashed", color="black", label="Window Mean"
            )

        if plot_adhoc and self.adhoc_result is not None:
            pdp_result = find_pdp_result(self.adhoc_result, feature_name)
            ice_vals = np.array(pdp_result.ice_vals)
            mean_vals = pdp_result.pdp_vals
            # One labeled, solid curve per adhoc input in a stable color.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.plot(
                    xs,
                    ice_vals[:, idx : idx + 1],
                    linewidth=3,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )
            plt.legend()

        plt.ylim(0)
        _ = plt.grid()
        plt.show()

    def plot_all_features(
        self,
        title="Feature Effect per Inference",
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
        top_n: int = 0,
    ):
        """Creates a 'bee swarm' plot of all/each feature effect."""
        self.check_status()
        # Fixed seed so the jitter (and thus the plot) is reproducible.
        np.random.seed(42)

        # Our custom palette will go from black to a weird blue green.
        # Not sure of the best palette to use. We need lighter to be
        # higher values. Started with a red to green ramp but we also
        # need to be aware to color blindness issues.
        rgb = [(0.10, 0.10, 0.10), (0.0, 0.90, 0.90)]

        # Gather the data we'll need
        shap_values_list = []
        feature_values_list = []
        if plot_reference and self.reference_result:
            shap_values_list.append(self.reference_result.shap_values)
            feature_values_list.append(self.reference_result.data)
        if plot_window and self.window_result:
            shap_values_list.append(self.window_result.shap_values)
            feature_values_list.append(self.window_result.data)
        if plot_adhoc and self.adhoc_result:
            shap_values_list.append(self.adhoc_result.shap_values)
            feature_values_list.append(self.adhoc_result.data)

        if not shap_values_list:
            raise Exception("Some data must be specified to create a plot.")

        shap_values = np.vstack(shap_values_list)
        feature_values = np.vstack(feature_values_list)

        # create a df of the shap values / contributions
        df = pd.DataFrame(shap_values)
        df.columns = self.feature_names
        # create a df of the original feature values
        feature_df = pd.DataFrame(feature_values)
        feature_df.columns = self.feature_names

        # We plot in strips from the bottom so put the most important at the top.
        # first taking the top_n if specified.
        feature_names = list(self.feature_effects().index)
        if top_n > 0:
            feature_names = feature_names[:top_n]
        feature_names = list(reversed(feature_names))

        num_features = len(feature_names)
        num_obs = df.shape[0]

        fig = plt.figure()
        fig.patch.set_facecolor("white")
        # NOTE(review): get_axes() returns a list here (no axes yet); ax is
        # rebound by sns.scatterplot below, so this only matters if
        # feature_names is empty — confirm intended.
        ax = fig.get_axes()

        for i, col in enumerate(feature_names):
            # create a temp df where the y is the level + jitter
            dfc = pd.DataFrame(df[col])
            dfc["y"] = i + 1 + norm.rvs(loc=0, scale=0.1, size=num_obs)
            # vals is the original feature values and we create a custom palette
            dfc["vals"] = feature_df[col]
            unique_vals = dfc["vals"].unique()
            n_colors = len(unique_vals)
            palette = sns.blend_palette(rgb, n_colors=n_colors)

            # plot a scatter plot strip
            ax = sns.scatterplot(
                x=col,
                y="y",
                data=dfc,
                alpha=0.75,
                hue="vals",
                palette=palette,
                legend=None,
            )

        # change the tick labels from strip number to feature name
        ax.set_yticks(range(num_features + 1))
        ticks = [""]
        ticks.extend(feature_names)
        ax.set_yticklabels(ticks)

        plt.xlabel("Shap Value")
        plt.ylabel("")
        plt.title(title)
        plt.ylim(0, num_features + 1)
        plt.grid()
        plt.show()

    def plot_feature_values_vs_shap_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a scatter plot of the feature vs shap values.
        adhoc data if any is in custom colors. reference data in translucent
        blue."""

        self.check_status()

        alpha = 0.5

        # Column index of the requested feature.
        i = self.feature_names.index(feature_name)

        plt.title(f"Shap Values for {feature_name} for Each Input")
        plt.xlabel(feature_name)
        plt.ylabel("Shap Values")
        plt.grid()

        if plot_reference and self.reference_result:
            _ = plt.scatter(
                self.reference_result.data[:, i],
                self.reference_result.shap_values[:, i],
                s=100,
                alpha=alpha,
                color=self.reference_color,
                label="Reference",
            )

        if plot_window and self.window_result:
            # Pentagon markers to distinguish window points from reference.
            _ = plt.scatter(
                self.window_result.data[:, i],
                self.window_result.shap_values[:, i],
                marker="p",
                s=100,
                alpha=alpha,
                color=self.window_color,
                label="Window",
            )

        if plot_adhoc and self.adhoc_result:
            # Larger diamond markers, one color per adhoc input.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.scatter(
                    self.adhoc_result.data[idx, i],
                    self.adhoc_result.shap_values[idx, i],
                    marker="d",  # type: ignore
                    s=200,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )

        # expected_shap_values = (
        #     reference_shap_values
        #     if self.reference_result.shap_values.shape[0] > 0
        #     else adhoc_shap_values
        # )
        # expected_x = reference_x if reference_x.shape[0] > 0 else adhoc_x

        # mean_feature_value = expected_x[:, i].mean()
        # min_y = expected_shap_values[:, i].min()
        # max_y = expected_shap_values[:, i].max()
        # plt.vlines(
        #     mean_feature_value,
        #     min_y,
        #     max_y,
        #     colors="gray",
        #     linestyle="dotted",
        #     label=f"Mean {feature_name}",
        # )
        plt.legend()
        plt.show()
753
754
def build_window_result(data) -> Optional[WindowResult]:
    """Parse one raw window-result dict into a WindowResult, passing None through."""
    if data is None:
        return None
    return WindowResult(
        data=np.array(data["data"]),
        shap_values=np.array(data["shap_values"]),
        base_values=np.array(data["base_values"]),
        pdp_results=[PDPResult(**raw) for raw in data["pdp_results"]],
    )
765
766
def build_explainability_result(data):
    """Convenience function to parse json into the full result object we want."""

    bounds = {
        name: FeatureBounds(**raw) for name, raw in data["feature_bounds"].items()
    }
    return ExplainabilityResult(
        id=data["id"],
        workspace_id=data["workspace_id"],
        explainability_config_id=data["explainability_config_id"],
        num_inferences=data["num_inferences"],
        num_batches=data["num_batches"],
        compute_time=data["compute_time"],
        status=data["status"],
        feature_names=data["feature_names"],
        feature_bounds=bounds,
        reference_result=build_window_result(data["reference_result"]),
        window_result=build_window_result(data["window_result"]),
        adhoc_result=build_window_result(data["adhoc_result"]),
    )
def truncate(v: Any, num_char: int = 256) -> str:
65def truncate(v: Any, num_char: int = 256) -> str:
66    s = str(v)
67    if len(s) < num_char:
68        return s
69    else:
70        postfix = " ..."
71        return f"{s[:num_char - len(postfix)]}{postfix}"
@dataclass
class WindowRequestInput:
74@dataclass
75class WindowRequestInput:
76    start: Optional[str]  # These are strings in our request but
 77    end: Optional[str]  # should be interpreted as RFC 3339 formatted datetimes
78    num_samples: Optional[int]

WindowRequestInput(start: Optional[str], end: Optional[str], num_samples: Optional[int])

WindowRequestInput(start: Optional[str], end: Optional[str], num_samples: Optional[int])
class ExplainabilityConfigList(builtins.list):
 81class ExplainabilityConfigList(list):
 82    def _repr_html_(self):
 83        rows = [
 84            f"""
 85        <tr">
 86            <td>{exp_config.id}</td>
 87            <td>{exp_config.status}</td>
 88        </tr>
 89        """
 90            for exp_config in self
 91        ]
 92        table = """<table>
 93            <tr>
 94                <th>Id</th>
 95                <th>Status</th>
 96            </tr>
 97            {0}
 98        </table>""".format(
 99            "\n".join(rows)
100        )
101
102        return table

Built-in mutable sequence.

If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.

Inherited Members
builtins.list
list
clear
copy
append
insert
extend
pop
remove
index
count
reverse
sort
class ExplainabilityRequestList(builtins.list):
105class ExplainabilityRequestList(list):
106    """Adds a _repr_html_ to a list of explainability requests."""
107
108    def _repr_html_(self):
109        """Assuming all items are explainability requests generates an HTML
110        table for jupyter."""
111        rows = [
112            f"""
113        <tr">
114            <td>{exp_req.id}</td>
115            <td>{exp_req.status}</td>
116            <td>{exp_req.reference_config}</td>
117            <td>{exp_req.window_config}</td>
118            <td>{exp_req.use_adhoc_data}</td>
119        </tr>
120        """
121            for exp_req in self
122        ]
123        table = """<table>
124            <tr>
125                <th>Id</th>
126                <th>Status</th>
127                <th>Reference Data</th>
128                <th>Window Data</th>
129                <th>Adhoc</th>
130            </tr>
131            {0}
132        </table>""".format(
133            "\n".join(rows)
134        )
135
136        return table

Adds a _repr_html_ to a list of explainability requests.

Inherited Members
builtins.list
list
clear
copy
append
insert
extend
pop
remove
index
count
reverse
sort
@dataclass
class FeatureBounds:
139@dataclass
140class FeatureBounds:
141    min: float
142    max: float
143    xs: List[float]

FeatureBounds(min: float, max: float, xs: List[float])

FeatureBounds(min: float, max: float, xs: List[float])
@dataclass
class ExplainabilityConfig:
@dataclass
class ExplainabilityConfig:
    """This class specifies an explainability configuration that can be used
    to later submit explainability requests which cause the server to do the
    analysis and create explainability results.

    ExplainabilityConfig are necessary to ensure the explainability pipeline
    is created and is deployed and so that various requests are processed in
    the same manner and can be compared.

    id, status, feature_bounds and reference_pipeline_version are optional
    and will be filled out when processed and saved to the database.

    workspace id must match the users/pipelines workspace and
    reference_pipeline_version must refer to a valid pipeline version that the
    user has access to.

    num_points specifies how many samples to take when varying the values of a
    feature for the PDP/ICE analysis through the feature_bounds.

    feature_names are a convenience for the user. output_names is not currently used.
    """

    id: Optional[UUID]
    workspace_id: int
    status: Optional[Dict[str, Any]]
    reference_pipeline_version: str
    explainability_pipeline_version: Optional[str]
    num_points: int = 10
    feature_names: Optional[List[str]] = None
    feature_bounds: Optional[Dict[str, FeatureBounds]] = None
    output_names: Optional[List[str]] = None

    def _repr_html_(self):
        """Render all fields as an HTML field/value table for jupyter."""
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def list_explainability_requests(self) -> List["ExplainabilityRequest"]:
        """List the explainability requests we've created."""

        # NOTE: self.client is attached externally after construction,
        # hence the type: ignore.
        client = self.client  # type: ignore
        result = client._post_rest_api_json(
            "v1/api/explainability/list_requests",
            {"explainability_config_id": self.id},
        )
        erl = [ExplainabilityRequest(**ec) for ec in result]
        for er in erl:
            er.client = client  # type: ignore
        return ExplainabilityRequestList(erl)

    def get_explainability_request(
        self, expr: Union[str, "ExplainabilityConfig"]
    ) -> Optional["ExplainabilityRequest"]:
        """Get the full explainability request whether completed or not.

        expr may be a request id string or an object with an `id` attribute.

        NOTE(review): the annotation says ExplainabilityConfig but only
        `expr.id` is used and the method returns an ExplainabilityRequest;
        possibly the annotation was meant to be ExplainabilityRequest —
        confirm with callers before changing it.
        """

        if isinstance(expr, str):
            explain_id = expr
        else:
            explain_id = str(expr.id)

        client = self.client  # type: ignore
        result = client._post_rest_api_json(
            "v1/api/explainability/get_request",
            {"explainability_request_id": explain_id},
        )

        exp_cfg = ExplainabilityRequest(**result)
        exp_cfg.client = client  # type: ignore
        return exp_cfg

    def submit_explainability_request(
        self,
        reference_start: Optional[datetime] = None,
        reference_end: Optional[datetime] = None,
        reference_num_samples: Optional[int] = None,
        use_reference: Optional[bool] = True,
        window_start: Optional[datetime] = None,
        window_end: Optional[datetime] = None,
        window_num_samples: Optional[int] = None,
        adhoc_data: Optional[Union[List[List[float]], np.ndarray, pd.DataFrame]] = None,
    ):
        """Submit an analysis on reference or adhoc data using a particular config.

        At least one of a reference config (use_reference=True), a window
        config (any of the window_* arguments) or adhoc_data must be given,
        otherwise an Exception is raised.
        """

        reference_config = None
        if use_reference:
            reference_config = WindowRequestInput(
                start=maybe_format_date(reference_start),
                end=maybe_format_date(reference_end),
                num_samples=reference_num_samples,
            )

        window_config = None
        if window_start or window_end or window_num_samples:
            window_config = WindowRequestInput(
                start=maybe_format_date(window_start),
                end=maybe_format_date(window_end),
                num_samples=window_num_samples,
            )

        use_adhoc_data = adhoc_data is not None

        # Normalize adhoc data to a plain list-of-lists for JSON serialization.
        adhoc_data_list = None
        if use_adhoc_data:
            if isinstance(adhoc_data, list):
                adhoc_data_list = adhoc_data
            elif isinstance(adhoc_data, np.ndarray):
                adhoc_data_list = adhoc_data.tolist()
            elif isinstance(adhoc_data, pd.DataFrame):
                adhoc_data_list = adhoc_data.values.tolist()
            else:
                raise Exception(f"Unknown adhoc data type {type(adhoc_data)}")

        if reference_config is None and window_config is None and not use_adhoc_data:
            raise Exception(
                "You must specify a reference config, a window config or adhoc data"
            )

        client = self.client  # type: ignore
        exp_config_id = self.id
        workspace_id = client.get_current_workspace().id()

        ear = ExplainabilityRequest(
            id=None,
            explainability_config_id=exp_config_id,
            workspace_id=workspace_id,
            reference_config=reference_config,
            window_config=window_config,
            use_adhoc_data=use_adhoc_data,
            adhoc_data=adhoc_data_list,
        )

        result = client._post_rest_api_json(
            "v1/api/explainability/create_request", asdict(ear)
        )
        expr_id = result["id"]
        return self.get_explainability_request(expr_id)

This class specifies an explainability configuration that can be used to later submit explainability requests which cause the server to do the analysis and create explainability results.

ExplainabilityConfig are necessary to ensure the explainability pipeline is created and is deployed and so that various requests are processed in the same manner and can be compared.

id, status, feature_bounds and reference_pipeline_version are optional and will be filled out when processed and saved to the database.

workspace id must match the users/pipelines workspace and reference_pipeline_version must refer to a valid pipeline version that the user has access to.

num_points specifies how many samples to take when varying the values of a feature for the PDP/ICE analysis through the feature_bounds.

feature_names are a convenience for the user. output_names is not currently used.

ExplainabilityConfig( id: Optional[uuid.UUID], workspace_id: int, status: Optional[Dict[str, Any]], reference_pipeline_version: str, explainability_pipeline_version: Optional[str], num_points: int = 10, feature_names: Optional[List[str]] = None, feature_bounds: Optional[Dict[str, wallaroo.explainability.FeatureBounds]] = None, output_names: Optional[List[str]] = None)
num_points: int = 10
feature_names: Optional[List[str]] = None
feature_bounds: Optional[Dict[str, wallaroo.explainability.FeatureBounds]] = None
output_names: Optional[List[str]] = None
def list_explainability_requests(self) -> List[wallaroo.explainability.ExplainabilityRequest]:
186    def list_explainability_requests(self) -> List["ExplainabilityRequest"]:
187        """List the explainability requests we've created."""
188
189        client = self.client  # type: ignore
190        result = client._post_rest_api_json(
191            f"v1/api/explainability/list_requests",
192            {"explainability_config_id": self.id},
193        )
194        erl = [ExplainabilityRequest(**ec) for ec in result]
195        for er in erl:
196            er.client = client  # type: ignore
197        return ExplainabilityRequestList(erl)

List the explainability requests we've created.

def get_explainability_request( self, expr: Union[str, wallaroo.explainability.ExplainabilityConfig]) -> Optional[wallaroo.explainability.ExplainabilityRequest]:
199    def get_explainability_request(
200        self, expr: Union[str, "ExplainabilityConfig"]
201    ) -> Optional["ExplainabilityRequest"]:
202        """Get the full explainability result whether completed or not."""
203
204        if isinstance(expr, str):
205            explain_id = expr
206        else:
207            explain_id = str(expr.id)
208
209        client = self.client  # type: ignore
210        result = client._post_rest_api_json(
211            f"v1/api/explainability/get_request",
212            {"explainability_request_id": explain_id},
213        )
214
215        exp_cfg = ExplainabilityRequest(**result)
216        exp_cfg.client = client  # type: ignore
217        return exp_cfg

Get the full explainability result whether completed or not.

def submit_explainability_request( self, reference_start: Optional[datetime.datetime] = None, reference_end: Optional[datetime.datetime] = None, reference_num_samples: Optional[int] = None, use_reference: Optional[bool] = True, window_start: Optional[datetime.datetime] = None, window_end: Optional[datetime.datetime] = None, window_num_samples: Optional[int] = None, adhoc_data: Union[List[List[float]], numpy.ndarray, pandas.core.frame.DataFrame, NoneType] = None):
219    def submit_explainability_request(
220        self,
221        reference_start: Optional[datetime] = None,
222        reference_end: Optional[datetime] = None,
223        reference_num_samples: Optional[int] = None,
224        use_reference: Optional[bool] = True,
225        window_start: Optional[datetime] = None,
226        window_end: Optional[datetime] = None,
227        window_num_samples: Optional[int] = None,
228        adhoc_data: Optional[Union[List[List[float]], np.ndarray, pd.DataFrame]] = None,
229    ):
230        """Submit an analysis on reference or adhoc data using a particular config"""
231
232        reference_config = None
233        if use_reference:
234            reference_config = WindowRequestInput(
235                start=maybe_format_date(reference_start),
236                end=maybe_format_date(reference_end),
237                num_samples=reference_num_samples,
238            )
239
240        window_config = None
241        if window_start or window_end or window_num_samples:
242            window_config = WindowRequestInput(
243                start=maybe_format_date(window_start),
244                end=maybe_format_date(window_end),
245                num_samples=window_num_samples,
246            )
247
248        use_adhoc_data = False
249        if adhoc_data is not None:
250            use_adhoc_data = True
251
252        adhoc_data_list = None
253        if use_adhoc_data:
254            if isinstance(adhoc_data, list):
255                adhoc_data_list = adhoc_data
256            elif isinstance(adhoc_data, np.ndarray):
257                adhoc_data_list = adhoc_data.tolist()
258            elif isinstance(adhoc_data, pd.DataFrame):
259                adhoc_data_list = adhoc_data.values.tolist()
260            else:
261                raise Exception(f"Unknown adhoc data type {type(adhoc_data)}")
262
263        if reference_config is None and window_config is None and not use_adhoc_data:
264            raise Exception(
265                "You must specify a reference config, a window config or adhoc data"
266            )
267
268        client = self.client  # type: ignore
269        exp_config_id = self.id
270        workspace_id = client.get_current_workspace().id()
271
272        ear = ExplainabilityRequest(
273            id=None,
274            explainability_config_id=exp_config_id,
275            workspace_id=workspace_id,
276            reference_config=reference_config,
277            window_config=window_config,
278            use_adhoc_data=use_adhoc_data,
279            adhoc_data=adhoc_data_list,
280        )
281
282        result = client._post_rest_api_json(
283            f"v1/api/explainability/create_request", asdict(ear)
284        )
285        expr_id = result["id"]
286        return self.get_explainability_request(expr_id)

Submit an analysis on reference or adhoc data using a particular config

def maybe_format_date(d: Optional[datetime.datetime]) -> Optional[str]:
def maybe_format_date(d: Optional[datetime]) -> Optional[str]:
    """Format an optional datetime as an ISO-8601/RFC 3339 string in UTC.

    None (or a missing date) passes through as None.
    """
    return d.astimezone(tz=timezone.utc).isoformat() if d else None
@dataclass
class ExplainabilityRequest:
@dataclass
class ExplainabilityRequest:
    """This class outlines what should be submitted to start the explainability
    analysis with a particular config.

    The request can be to analyze reference data, historical data from the ref
    pipeline, or new adhoc data submitted with the request or both.

    id and status are optional and are filled in by the processing steps.

    If the request has use_reference_data = True, num_sample inference logs are
    sampled from between the start and end dates or the entire (last 100_000)
    inferences.
    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: Optional[UUID] = None
    status: Optional[Dict[str, Any]] = None
    reference_config: Optional[WindowRequestInput] = None
    window_config: Optional[WindowRequestInput] = None
    use_adhoc_data: bool = False
    adhoc_data: Optional[List[List[float]]] = None

    def _repr_html_(self):
        """Render all fields as an HTML field/value table for jupyter."""
        fields = [f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in asdict(self).items()]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def get_explainability_result(self) -> Optional["ExplainabilityResult"]:
        """Get the full explainability result whether completed or not."""

        # NOTE: self.client is attached externally after construction.
        client = self.client  # type: ignore

        result = client._post_rest_api_json(
            "v1/api/explainability/get_result",
            {"explainability_result_id": self.id},
        )
        return build_explainability_result(result)

This class outlines what should be submitted to start the explainability analysis with a particular config.

The request can be to analyze reference data, historical data from the ref pipeline, or new adhoc data submitted with the request or both.

id and status are optional and are filled in by the processing steps.

If the request has use_reference_data = True, num_sample inference logs are sampled from between the start and end dates or the entire (last 100_000) inferences.

ExplainabilityRequest( id: Optional[uuid.UUID], workspace_id: int, explainability_config_id: Optional[uuid.UUID] = None, status: Optional[Dict[str, Any]] = None, reference_config: Optional[wallaroo.explainability.WindowRequestInput] = None, window_config: Optional[wallaroo.explainability.WindowRequestInput] = None, use_adhoc_data: bool = False, adhoc_data: Optional[List[List[float]]] = None)
explainability_config_id: Optional[uuid.UUID] = None
status: Optional[Dict[str, Any]] = None
reference_config: Optional[wallaroo.explainability.WindowRequestInput] = None
window_config: Optional[wallaroo.explainability.WindowRequestInput] = None
use_adhoc_data: bool = False
adhoc_data: Optional[List[List[float]]] = None
def get_explainability_result(self) -> Optional[wallaroo.explainability.ExplainabilityResult]:
323    def get_explainability_result(self) -> Optional["ExplainabilityResult"]:
324        """Get the full explainability result whether completed or not."""
325
326        client = self.client  # type: ignore
327
328        result = client._post_rest_api_json(
329            f"v1/api/explainability/get_result",
330            {"explainability_result_id": self.id},
331        )
332        return build_explainability_result(result)

Get the full explainability result whether completed or not.

@dataclass
class PDPResult:
@dataclass
class PDPResult:
    """This class holds the PDP/ICE part of the results.
    PDP/ICE results are generated for each observation by holding
    all but one feature constant, varying that feature and analyzing
    that prediction. Thus the results are per inference per feature.

    feature_name is the feature that this result is for.
    xs is the list of x values that the feature was varied through.

    pdp_vals is the list of resulting values.
    model, shap and feature expected values are the mean/expected values
    for that model, shap and feature.
    """

    feature_name: str  # feature these PDP/ICE values were computed for
    # Per-observation predictions; plotted with one column per observation
    # (indexed as ice_vals[:, idx] in ExplainabilityResult.plot_ice_values),
    # so the shape is presumably (num_x_points, num_observations) — confirm.
    ice_vals: np.ndarray
    pdp_vals: List[float]  # mean prediction at each x point (the PDP curve)
    model_expected_value: List[float]
    shap_expected_value: List[float]
    feature_expected_value: List[float]

This class holds the PDP/ICE part of the results. PDP/ICE results are generated for each observation by holding all but one feature constant, varying that feature and analyzing that prediction. Thus the results are per inference per feature.

feature_name is the feature that this result is for. xs is the list of x values that the feature was varied through.

pdp_vals is the list of resulting values. model, shap and feature expected values are the mean/expected values for that model, shap and feature.

PDPResult( feature_name: str, ice_vals: numpy.ndarray, pdp_vals: List[float], model_expected_value: List[float], shap_expected_value: List[float], feature_expected_value: List[float])
@dataclass
class WindowResult:
@dataclass
class WindowResult:
    """Shap and PDP/ICE analysis output for one window of inference data
    (reference, window or adhoc)."""

    data: NDArray  # the original data used for the analysis
    shap_values: NDArray  # the calculated shap values
    base_values: NDArray  # The expected value (mean) for each prediction.
    pdp_results: List[PDPResult]  # Partial dependence plot data for each feature

WindowResult(data: numpy.ndarray, shap_values: numpy.ndarray, base_values: numpy.ndarray, pdp_results: List[wallaroo.explainability.PDPResult])

WindowResult( data: numpy.ndarray, shap_values: numpy.ndarray, base_values: numpy.ndarray, pdp_results: List[wallaroo.explainability.PDPResult])
def find_pdp_result( window_result: wallaroo.explainability.WindowResult, feature_name: str) -> wallaroo.explainability.PDPResult:
def find_pdp_result(window_result: WindowResult, feature_name: str) -> PDPResult:
    """Gets the pdp result object for the specified feature."""

    candidates = (r for r in window_result.pdp_results if r.feature_name == feature_name)
    found = next(candidates, None)
    if found is None:
        raise Exception(f"Did not find feature {feature_name} in pdp_results.")
    return found

Gets the pdp result object for the specified feature.

@dataclass
class ExplainabilityResult:
@dataclass
class ExplainabilityResult:
    """This class holds the explainability result created by processing an
    explainability request.

    id and status are optional and will be filled in by processing. The id
    will be the same as the request id since the results are stored in minio.

    num_inferences and num_batches are nice to know status information and
    could be brought into the status object in the future.

    reference and adhoc data are the actual inferences used in the analysis.

    reference and adhoc shap values are the shap values for each feature for
    each prediction.

    base_values are the expected value for each prediction. These values will
    all be the same so may be changed to a single float in the future.

    pdp results are a list of pdp/ice results for each feature.

    """

    id: Optional[UUID]
    workspace_id: int
    explainability_config_id: UUID
    num_inferences: int
    num_batches: int
    compute_time: float
    status: Dict[str, Any]
    feature_names: List[str]  # The names of the columns
    feature_bounds: Dict[str, FeatureBounds]
    reference_result: Optional[WindowResult]
    window_result: Optional[WindowResult]
    adhoc_result: Optional[WindowResult]

    # Fixed plot colors shared by all plotting helpers (RGB in [0, 1]):
    # light blue for reference data, pink for window data.
    reference_color = np.array([127.0, 196, 252]) / 255
    window_color = np.array([252, 127.0, 196]) / 255

    def _repr_html_(self):
        """Render all fields as an HTML field/value table for jupyter.

        truncate() keeps large arrays from flooding the notebook output.
        """
        fields = [
            f"<tr><td>{k}</td><td>{truncate(v)}</td></tr>"
            for k, v in asdict(self).items()
        ]
        return f"<table><tr><th>Field</th><th>Value</th></tr>{''.join(fields)}</table>"

    def feature_effects(self) -> pd.DataFrame:
        """Returns a dataframe summarizing the mean feature effects of the reference
        data as well as the feature effects for each adhoc inference.

        Rows are feature names. Columns added depend on which results are
        present: ref_/window_ mean-abs, std-dev and max-abs columns, plus one
        input_N column per adhoc inference.
        """

        df = pd.DataFrame(index=self.feature_names)

        if self.reference_result is not None:
            vals = np.array(self.reference_result.shap_values)
            if vals.shape[0] > 0:
                df[REF_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["ref_std_dev"] = vals.std(axis=0)
                df[REF_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.window_result is not None:
            vals = np.array(self.window_result.shap_values)
            if vals.shape[0] > 0:
                df[WINDOW_MEAN_ABS] = np.abs(vals).mean(axis=0)
                df["window_std_dev"] = vals.std(axis=0)
                df[WINDOW_MAX_ABS] = np.abs(vals).max(axis=0)

        if self.adhoc_result is not None:
            vals = np.array(self.adhoc_result.shap_values)
            if vals.shape[0] > 0:
                # One column per adhoc inference row.
                for idx in range(vals.shape[0]):
                    df[f"input_{idx+1}"] = vals[idx, :]

        # Sort most important features first when reference data was analyzed.
        if REF_MEAN_ABS in df.columns:
            return df.sort_values(by=REF_MEAN_ABS, ascending=False)
        return df

    def effect_summary(self) -> pd.DataFrame:
        """Returns a dataframe with the expected/mean values and the shap adjustments.

        Requires a reference_result (its first base value is used as the
        baseline); one column per adhoc input plus the base_value column.
        """

        effects = self.feature_effects()
        base_value = unwrap(self.reference_result).base_values[0]
        data = {}
        data["base_value"] = [0, base_value]

        # we should rename inputs_X to 'house_X' or similar
        input_cols = [c for c in effects.columns if "input_" in c]
        # Sum the per-feature shap values to get each input's total adjustment.
        effect_sums = effects[input_cols].sum(axis=0)
        for c, v in zip(input_cols, effect_sums):
            data[c] = [v, v + base_value]
        return pd.DataFrame(data, index=["adjustment", "total"])

    def check_status(self) -> bool:
        """Ensure we've completed before trying to plot anything.

        Raises if self.status["status"] is not "COMPLETED"; returns True otherwise.
        """

        if self.status["status"] != "COMPLETED":
            raise Exception(f"Analysis has not (yet) completed: {self.status}")
        return True

    def plot_feature_effects(
        self,
        mode: str = "mean",
        top_n=0,
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
    ):
        """Creates a bar plot of the mean or max abs feature effects.

        mode must be 'mean', 'max' or 'individual'. top_n > 0 limits the plot
        to the first top_n features (feature_effects() ordering). The
        plot_* flags filter which column groups are shown.
        """

        if mode not in ["mean", "max", "individual"]:
            raise Exception("Mode must be one of: 'mean', 'max', 'individual'")

        self.check_status()

        df = self.feature_effects()
        if top_n > 0:
            df = df.head(top_n)

        # Filter column groups by prefix according to the plot_* flags.
        cols = df.columns
        if not plot_reference:
            cols = [c for c in cols if not c.startswith("ref_")]
        if not plot_window:
            cols = [c for c in cols if not c.startswith("window_")]
        if not plot_adhoc:
            cols = [c for c in cols if c.startswith("ref_") or c.startswith("window_")]

        df = df[cols]

        if REF_MEAN_ABS in df.columns and mode != "individual":
            # Single summary bar chart from the reference statistics.
            if mode == "max":
                title = "Max Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MAX_ABS])
            else:
                title = "Mean Absolute Feature Effect"
                _ = plt.bar(df.index, df[REF_MEAN_ABS])
        else:
            # Grouped bars, one per remaining column; std-dev columns dropped.
            title = "Feature Effects"
            ax = plt.subplot(1, 1, 1)
            _ = df.drop(
                ["ref_std_dev", "window_std_dev"], axis=1, errors="ignore"
            ).plot(kind="bar", ax=ax)
            plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

        plt.grid()
        plt.xticks(rotation=90)
        plt.title(title)

    def plot_ice_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a combination ICE plot for the adhoc data if any
        in custom colors and the reference data if any in translucent
        blue."""

        self.check_status()

        ice_alpha = 0.2
        ice_width = 5

        # x axis sample points come from the configured feature bounds.
        xs = self.feature_bounds[feature_name].xs

        # pdp_result = unwrap(self.pdp_result(feature_name))
        # xs = pdp_result.xs
        # mean_vals = np.array(pdp_result.pdp_vals)
        # ice_vals = np.array(pdp_result.ice_vals)

        plt.title(f"ICE for {feature_name}")
        plt.xlabel(feature_name)
        plt.ylabel("Prediction")

        if plot_reference and self.reference_result is not None:
            # Translucent per-observation ICE curves plus a solid mean curve.
            pdp_result = find_pdp_result(self.reference_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.reference_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            _ = plt.plot(
                xs, mean_vals, color="black", zorder=10, label="Reference Mean"
            )

        if plot_window and self.window_result is not None:
            pdp_result = find_pdp_result(self.window_result, feature_name)
            ice_vals = pdp_result.ice_vals
            mean_vals = pdp_result.pdp_vals
            _ = plt.plot(
                xs,
                ice_vals,
                color=self.window_color,
                alpha=ice_alpha,
                linewidth=ice_width,
            )
            _ = plt.plot(
                xs, mean_vals, linestyle="dashed", color="black", label="Window Mean"
            )

        if plot_adhoc and self.adhoc_result is not None:
            # One labeled, distinctly colored curve per adhoc observation.
            pdp_result = find_pdp_result(self.adhoc_result, feature_name)
            ice_vals = np.array(pdp_result.ice_vals)
            mean_vals = pdp_result.pdp_vals
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.plot(
                    xs,
                    ice_vals[:, idx : idx + 1],
                    linewidth=3,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )
            plt.legend()

        plt.ylim(0)
        _ = plt.grid()
        plt.show()

    def plot_all_features(
        self,
        title="Feature Effect per Inference",
        plot_reference=True,
        plot_window=True,
        plot_adhoc=True,
        top_n: int = 0,
    ):
        """Creates a 'bee swarm' plot of all/each feature effect.

        One horizontal strip per feature; each point is one observation's shap
        value, colored by the observation's original feature value. Raises if
        no data source is selected/available.
        """
        self.check_status()
        # Fixed seed so the vertical jitter is reproducible across calls.
        np.random.seed(42)

        # Our custom palette will go from black to a weird blue green.
        # Not sure of the best palette to use. We need lighter to be
        # higher values. Started with a red to green ramp but we also
        # need to be aware to color blindness issues.
        rgb = [(0.10, 0.10, 0.10), (0.0, 0.90, 0.90)]

        # Gather the data we'll need
        shap_values_list = []
        feature_values_list = []
        if plot_reference and self.reference_result:
            shap_values_list.append(self.reference_result.shap_values)
            feature_values_list.append(self.reference_result.data)
        if plot_window and self.window_result:
            shap_values_list.append(self.window_result.shap_values)
            feature_values_list.append(self.window_result.data)
        if plot_adhoc and self.adhoc_result:
            shap_values_list.append(self.adhoc_result.shap_values)
            feature_values_list.append(self.adhoc_result.data)

        if not shap_values_list:
            raise Exception("Some data must be specified to create a plot.")

        shap_values = np.vstack(shap_values_list)
        feature_values = np.vstack(feature_values_list)

        # create a df of the shap values / contributions
        df = pd.DataFrame(shap_values)
        df.columns = self.feature_names
        # create a df of the original feature values
        feature_df = pd.DataFrame(feature_values)
        feature_df.columns = self.feature_names

        # We plot in strips from the bottom so put the most important at the top.
        # first taking the top_n if specified.
        feature_names = list(self.feature_effects().index)
        if top_n > 0:
            feature_names = feature_names[:top_n]
        feature_names = list(reversed(feature_names))

        num_features = len(feature_names)
        num_obs = df.shape[0]

        fig = plt.figure()
        fig.patch.set_facecolor("white")
        ax = fig.get_axes()

        for i, col in enumerate(feature_names):
            # create a temp df where the y is the level + jitter
            dfc = pd.DataFrame(df[col])
            dfc["y"] = i + 1 + norm.rvs(loc=0, scale=0.1, size=num_obs)
            # vals is the original feature values and we create a custom palette
            dfc["vals"] = feature_df[col]
            unique_vals = dfc["vals"].unique()
            n_colors = len(unique_vals)
            palette = sns.blend_palette(rgb, n_colors=n_colors)

            # plot a scatter plot strip
            ax = sns.scatterplot(
                x=col,
                y="y",
                data=dfc,
                alpha=0.75,
                hue="vals",
                palette=palette,
                legend=None,
            )

        # change the tick labels from strip number to feature name
        ax.set_yticks(range(num_features + 1))
        ticks = [""]
        ticks.extend(feature_names)
        ax.set_yticklabels(ticks)

        plt.xlabel("Shap Value")
        plt.ylabel("")
        plt.title(title)
        plt.ylim(0, num_features + 1)
        plt.grid()
        plt.show()

    def plot_feature_values_vs_shap_values(
        self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
    ):
        """Creates a scatter plot of the feature vs shap values.
        adhoc data if any is in custom colors. reference data in translucent
        blue."""

        self.check_status()

        alpha = 0.5

        # Column index of the requested feature; raises ValueError if unknown.
        i = self.feature_names.index(feature_name)

        plt.title(f"Shap Values for {feature_name} for Each Input")
        plt.xlabel(feature_name)
        plt.ylabel("Shap Values")
        plt.grid()

        if plot_reference and self.reference_result:
            _ = plt.scatter(
                self.reference_result.data[:, i],
                self.reference_result.shap_values[:, i],
                s=100,
                alpha=alpha,
                color=self.reference_color,
                label="Reference",
            )

        if plot_window and self.window_result:
            _ = plt.scatter(
                self.window_result.data[:, i],
                self.window_result.shap_values[:, i],
                marker="p",
                s=100,
                alpha=alpha,
                color=self.window_color,
                label="Window",
            )

        if plot_adhoc and self.adhoc_result:
            # One labeled diamond marker per adhoc observation.
            for idx in range(self.adhoc_result.data.shape[0]):
                _ = plt.scatter(
                    self.adhoc_result.data[idx, i],
                    self.adhoc_result.shap_values[idx, i],
                    marker="d",  # type: ignore
                    s=200,
                    label=f"input_{idx+1}",
                    color=colors[idx % len(colors)],
                )

        # expected_shap_values = (
        #     reference_shap_values
        #     if self.reference_result.shap_values.shape[0] > 0
        #     else adhoc_shap_values
        # )
        # expected_x = reference_x if reference_x.shape[0] > 0 else adhoc_x

        # mean_feature_value = expected_x[:, i].mean()
        # min_y = expected_shap_values[:, i].min()
        # max_y = expected_shap_values[:, i].max()
        # plt.vlines(
        #     mean_feature_value,
        #     min_y,
        #     max_y,
        #     colors="gray",
        #     linestyle="dotted",
        #     label=f"Mean {feature_name}",
        # )
        plt.legend()
        plt.show()

This class holds the explainability result created by processing an explainability request.

id and status are optional and will be filled in by processing. The id will be the same as the request id since the results are stored in minio.

num_inferences and num_batches are nice to know status information and could be brought into the status object in the future.

reference and adhoc data are the actual inferences used in the analysis.

reference and adhoc shap values are the shap values for each feature for each prediction.

base_values are the expected value for each prediction. These values will all be the same so may be changed to a single float in the future.

pdp results are a list of pdp/ice results for each feature.

ExplainabilityResult( id: Optional[uuid.UUID], workspace_id: int, explainability_config_id: uuid.UUID, num_inferences: int, num_batches: int, compute_time: float, status: Dict[str, Any], feature_names: List[str], feature_bounds: Dict[str, wallaroo.explainability.FeatureBounds], reference_result: Optional[wallaroo.explainability.WindowResult], window_result: Optional[wallaroo.explainability.WindowResult], adhoc_result: Optional[wallaroo.explainability.WindowResult])
reference_color = array([0.49803922, 0.76862745, 0.98823529])
window_color = array([0.98823529, 0.49803922, 0.76862745])
def feature_effects(self) -> pd.DataFrame:
    """Returns a dataframe summarizing the mean feature effects of the reference
    data as well as the feature effects for each adhoc inference.

    Rows are indexed by feature name. Reference/window data contribute
    mean-abs, std-dev, and max-abs columns; each adhoc inference contributes
    an `input_N` column of raw shap values. When reference stats are present
    the frame is sorted by mean absolute reference effect, descending.
    """
    summary = pd.DataFrame(index=self.feature_names)

    # Reference and window results are summarized identically, just under
    # different column names.
    for result, mean_col, std_col, max_col in (
        (self.reference_result, REF_MEAN_ABS, "ref_std_dev", REF_MAX_ABS),
        (self.window_result, WINDOW_MEAN_ABS, "window_std_dev", WINDOW_MAX_ABS),
    ):
        if result is None:
            continue
        shap = np.array(result.shap_values)
        if shap.shape[0] > 0:
            abs_shap = np.abs(shap)
            summary[mean_col] = abs_shap.mean(axis=0)
            summary[std_col] = shap.std(axis=0)
            summary[max_col] = abs_shap.max(axis=0)

    # Adhoc inferences get one column of raw (signed) shap values apiece.
    if self.adhoc_result is not None:
        adhoc_shap = np.array(self.adhoc_result.shap_values)
        for row in range(adhoc_shap.shape[0]):
            summary[f"input_{row+1}"] = adhoc_shap[row, :]

    if REF_MEAN_ABS in summary.columns:
        return summary.sort_values(by=REF_MEAN_ABS, ascending=False)
    return summary

Returns a dataframe summarizing the mean feature effects of the reference data as well as the feature effects for each adhoc inference.

def effect_summary(self) -> pd.DataFrame:
    """Returns a dataframe with the expected/mean values and the shap adjustments.

    Columns are `base_value` plus one per adhoc input; rows are
    "adjustment" (summed shap effect) and "total" (base value + adjustment).
    """
    effects = self.feature_effects()
    base_value = unwrap(self.reference_result).base_values[0]

    # we should rename inputs_X to 'house_X' or similar
    summary = {"base_value": [0, base_value]}
    input_cols = [c for c in effects.columns if "input_" in c]
    for name, total in effects[input_cols].sum(axis=0).items():
        summary[name] = [total, total + base_value]
    return pd.DataFrame(summary, index=["adjustment", "total"])

Returns a dataframe with the expected/mean values and the shap adjustments.

def check_status(self) -> bool:
    """Ensure we've completed before trying to plot anything.

    Returns True when the analysis status is COMPLETED; raises otherwise.
    """
    if self.status["status"] == "COMPLETED":
        return True
    raise Exception(f"Analysis has not (yet) completed: {self.status}")

Ensure we've completed before trying to plot anything.

def plot_feature_effects(
    self,
    mode: str = "mean",
    top_n=0,
    plot_reference=True,
    plot_window=True,
    plot_adhoc=True,
):
    """Creates a bar plot of the mean or max abs feature effects.

    mode: 'mean' or 'max' plot a single aggregate bar per feature from the
    reference stats; 'individual' (or missing reference stats) plots one
    bar per column. top_n > 0 limits the plot to the top features.
    """
    if mode not in ["mean", "max", "individual"]:
        raise Exception("Mode must be one of: 'mean', 'max', 'individual'")

    self.check_status()

    effects = self.feature_effects()
    if top_n > 0:
        effects = effects.head(top_n)

    # Filter columns down to the requested data sources.
    keep = list(effects.columns)
    if not plot_reference:
        keep = [c for c in keep if not c.startswith("ref_")]
    if not plot_window:
        keep = [c for c in keep if not c.startswith("window_")]
    if not plot_adhoc:
        keep = [c for c in keep if c.startswith("ref_") or c.startswith("window_")]
    effects = effects[keep]

    if REF_MEAN_ABS in effects.columns and mode != "individual":
        # Aggregate modes: one bar per feature from the reference stats.
        if mode == "max":
            title = "Max Absolute Feature Effect"
            _ = plt.bar(effects.index, effects[REF_MAX_ABS])
        else:
            title = "Mean Absolute Feature Effect"
            _ = plt.bar(effects.index, effects[REF_MEAN_ABS])
    else:
        # Individual mode: grouped bars, one per remaining column
        # (std-dev columns are noise here, so drop them).
        title = "Feature Effects"
        axis = plt.subplot(1, 1, 1)
        trimmed = effects.drop(
            ["ref_std_dev", "window_std_dev"], axis=1, errors="ignore"
        )
        _ = trimmed.plot(kind="bar", ax=axis)
        plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

    plt.grid()
    plt.xticks(rotation=90)
    plt.title(title)

Creates a bar plot of the mean or max abs feature effects.

def plot_ice_values(
    self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
):
    """Creates a combination ICE plot for the adhoc data if any
    in custom colors and the reference data if any in translucent
    blue."""
    self.check_status()

    line_alpha = 0.2
    line_width = 5

    # x positions come from the precomputed feature bounds for this feature.
    xs = self.feature_bounds[feature_name].xs

    plt.title(f"ICE for {feature_name}")
    plt.xlabel(feature_name)
    plt.ylabel("Prediction")

    if plot_reference and self.reference_result is not None:
        result = find_pdp_result(self.reference_result, feature_name)
        # Thick translucent strands per observation, solid mean on top.
        _ = plt.plot(
            xs,
            result.ice_vals,
            color=self.reference_color,
            alpha=line_alpha,
            linewidth=line_width,
        )
        _ = plt.plot(
            xs, result.pdp_vals, color="black", zorder=10, label="Reference Mean"
        )

    if plot_window and self.window_result is not None:
        result = find_pdp_result(self.window_result, feature_name)
        _ = plt.plot(
            xs,
            result.ice_vals,
            color=self.window_color,
            alpha=line_alpha,
            linewidth=line_width,
        )
        _ = plt.plot(
            xs, result.pdp_vals, linestyle="dashed", color="black", label="Window Mean"
        )

    if plot_adhoc and self.adhoc_result is not None:
        result = find_pdp_result(self.adhoc_result, feature_name)
        adhoc_ice = np.array(result.ice_vals)
        # One distinctly colored line per adhoc input.
        for idx in range(self.adhoc_result.data.shape[0]):
            _ = plt.plot(
                xs,
                adhoc_ice[:, idx : idx + 1],
                linewidth=3,
                label=f"input_{idx+1}",
                color=colors[idx % len(colors)],
            )
        plt.legend()

    plt.ylim(0)
    _ = plt.grid()
    plt.show()

Creates a combination ICE plot for the adhoc data if any in custom colors and the reference data if any in translucent blue.

def plot_all_features(
    self,
    title="Feature Effect per Inference",
    plot_reference=True,
    plot_window=True,
    plot_adhoc=True,
    top_n: int = 0,
):
    """Creates a 'bee swarm' plot of all/each feature effect.

    One horizontal strip per feature: each point is one inference's shap
    value for that feature, jittered vertically, and colored by the
    original feature value (darker = lower, lighter = higher).
    Raises if none of the reference/window/adhoc sources are selected
    or present.
    """
    self.check_status()
    # Deterministic jitter so repeated plots of the same data look identical.
    np.random.seed(42)

    # Our custom palette will go from black to a weird blue green.
    # Not sure of the best palette to use. We need lighter to be
    # higher values. Started with a red to green ramp but we also
    # need to be aware to color blindness issues.
    rgb = [(0.10, 0.10, 0.10), (0.0, 0.90, 0.90)]

    # Gather the data we'll need
    shap_values_list = []
    feature_values_list = []
    if plot_reference and self.reference_result:
        shap_values_list.append(self.reference_result.shap_values)
        feature_values_list.append(self.reference_result.data)
    if plot_window and self.window_result:
        shap_values_list.append(self.window_result.shap_values)
        feature_values_list.append(self.window_result.data)
    if plot_adhoc and self.adhoc_result:
        shap_values_list.append(self.adhoc_result.shap_values)
        feature_values_list.append(self.adhoc_result.data)

    if not shap_values_list:
        raise Exception("Some data must be specified to create a plot.")

    shap_values = np.vstack(shap_values_list)
    feature_values = np.vstack(feature_values_list)

    # create a df of the shap values / contributions
    df = pd.DataFrame(shap_values)
    df.columns = self.feature_names
    # create a df of the original feature values
    feature_df = pd.DataFrame(feature_values)
    feature_df.columns = self.feature_names

    # We plot in strips from the bottom so put the most important at the top.
    # first taking the top_n if specified.
    feature_names = list(self.feature_effects().index)
    if top_n > 0:
        feature_names = feature_names[:top_n]
    feature_names = list(reversed(feature_names))

    num_features = len(feature_names)
    num_obs = df.shape[0]

    fig = plt.figure()
    fig.patch.set_facecolor("white")
    # BUG FIX: fig.get_axes() returns a *list* of Axes, so if the loop below
    # never ran (no features) the set_yticks call would raise AttributeError.
    # fig.gca() always yields a single Axes.
    ax = fig.gca()

    for i, col in enumerate(feature_names):
        # create a temp df where the y is the level + jitter
        dfc = pd.DataFrame(df[col])
        dfc["y"] = i + 1 + norm.rvs(loc=0, scale=0.1, size=num_obs)
        # vals is the original feature values and we create a custom palette
        dfc["vals"] = feature_df[col]
        unique_vals = dfc["vals"].unique()
        n_colors = len(unique_vals)
        palette = sns.blend_palette(rgb, n_colors=n_colors)

        # plot a scatter plot strip
        ax = sns.scatterplot(
            x=col,
            y="y",
            data=dfc,
            alpha=0.75,
            hue="vals",
            palette=palette,
            legend=None,
        )

    # change the tick labels from strip number to feature name
    ax.set_yticks(range(num_features + 1))
    ticks = [""]
    ticks.extend(feature_names)
    ax.set_yticklabels(ticks)

    plt.xlabel("Shap Value")
    plt.ylabel("")
    plt.title(title)
    plt.ylim(0, num_features + 1)
    plt.grid()
    plt.show()

Creates a 'bee swarm' plot of all/each feature effect.

def plot_feature_values_vs_shap_values(
    self, feature_name: str, plot_reference=True, plot_window=True, plot_adhoc=True
):
    """Creates a scatter plot of the feature vs shap values.
    adhoc data if any is in custom colors. reference data in translucent
    blue."""
    self.check_status()

    marker_alpha = 0.5

    # Column index of the requested feature in the data/shap arrays.
    col = self.feature_names.index(feature_name)

    plt.title(f"Shap Values for {feature_name} for Each Input")
    plt.xlabel(feature_name)
    plt.ylabel("Shap Values")
    plt.grid()

    if plot_reference and self.reference_result:
        _ = plt.scatter(
            self.reference_result.data[:, col],
            self.reference_result.shap_values[:, col],
            s=100,
            alpha=marker_alpha,
            color=self.reference_color,
            label="Reference",
        )

    if plot_window and self.window_result:
        _ = plt.scatter(
            self.window_result.data[:, col],
            self.window_result.shap_values[:, col],
            marker="p",
            s=100,
            alpha=marker_alpha,
            color=self.window_color,
            label="Window",
        )

    if plot_adhoc and self.adhoc_result:
        # One larger diamond per adhoc input in a distinct color.
        for idx in range(self.adhoc_result.data.shape[0]):
            _ = plt.scatter(
                self.adhoc_result.data[idx, col],
                self.adhoc_result.shap_values[idx, col],
                marker="d",  # type: ignore
                s=200,
                label=f"input_{idx+1}",
                color=colors[idx % len(colors)],
            )

    plt.legend()
    plt.show()

Creates a scatter plot of the feature vs shap values. adhoc data if any is in custom colors. reference data in translucent blue.

def build_window_result(data) -> Optional[WindowResult]:
    """Parse a raw result dict into a WindowResult, or return None when the
    data is absent."""
    if data is None:
        return None
    return WindowResult(
        data=np.array(data["data"]),
        shap_values=np.array(data["shap_values"]),
        base_values=np.array(data["base_values"]),
        pdp_results=[PDPResult(**d) for d in data["pdp_results"]],
    )
def build_explainability_result(data):
    """Convenience function to parse json into the full result object
    we want.

    `data` is the decoded JSON dict returned by the explainability
    service; the nested reference/window/adhoc results may each be None
    and are parsed via build_window_result.
    """
    return ExplainabilityResult(
        id=data["id"],
        workspace_id=data["workspace_id"],
        explainability_config_id=data["explainability_config_id"],
        num_inferences=data["num_inferences"],
        num_batches=data["num_batches"],
        compute_time=data["compute_time"],
        status=data["status"],
        feature_names=data["feature_names"],
        feature_bounds={
            k: FeatureBounds(**v) for k, v in data["feature_bounds"].items()
        },
        reference_result=build_window_result(data["reference_result"]),
        window_result=build_window_result(data["window_result"]),
        adhoc_result=build_window_result(data["adhoc_result"]),
    )

Convenience function to parse JSON into the full result object we want.