Generating counterfactuals for multi-class classification and regression models
This notebook demonstrates how the DiCE library can be used for multiclass classification and regression with scikit-learn models. You can use any of the available methods (“random”, “kdtree”, “genetic”); just specify it in the method argument at the initialization step. The rest of the code is identical across methods. For demonstration, we will use the genetic algorithm to generate CFs.
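For instance, switching the generation method is a one-argument change at initialization (a minimal sketch; d and m stand for the dice_ml.Data and dice_ml.Model objects constructed later in this notebook):

exp_random = Dice(d, m, method="random")    # random sampling of feature values
exp_kdtree = Dice(d, m, method="kdtree")    # CFs drawn from the training data via a KD-tree
exp_genetic = Dice(d, m, method="genetic")  # genetic search, used throughout this notebook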
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import dice_ml
from dice_ml import Dice
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
import pandas as pd
We will use sklearn’s built-in datasets to demonstrate DiCE’s features in this notebook.
[3]:
outcome_name = 'target'
# Function to process sklearn's internal datasets
def sklearn_to_df(sklearn_dataset):
    df = pd.DataFrame(sklearn_dataset.data, columns=sklearn_dataset.feature_names)
    df[outcome_name] = pd.Series(sklearn_dataset.target)
    return df
Multiclass Classification
For multiclass classification, we will use sklearn’s Iris dataset. This dataset consists of sepal and petal length/width measurements for 3 types of irises (Setosa, Versicolour, and Virginica). More information at https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-plants-dataset
[4]:
from sklearn.datasets import load_iris
df_iris = sklearn_to_df(load_iris())
df_iris.head()
[4]:
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
[5]:
df_iris.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype
---  ------             --------------  -----
 0   sepal length (cm)  150 non-null    float64
 1   sepal width (cm)   150 non-null    float64
 2   petal length (cm)  150 non-null    float64
 3   petal width (cm)   150 non-null    float64
 4   target             150 non-null    int64
dtypes: float64(4), int64(1)
memory usage: 6.0 KB
[6]:
continuous_features_iris = df_iris.drop(outcome_name, axis=1).columns.tolist()
target = df_iris[outcome_name]
[7]:
# Split data into train and test
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
datasetX = df_iris.drop(outcome_name, axis=1)
x_train, x_test, y_train, y_test = train_test_split(datasetX,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0,
                                                    stratify=target)
categorical_features = x_train.columns.difference(continuous_features_iris)
# We create the preprocessing pipelines for both numeric and categorical data.
numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])
transformations = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, continuous_features_iris),
        ('cat', categorical_transformer, categorical_features)])
# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf_iris = Pipeline(steps=[('preprocessor', transformations),
                           ('classifier', RandomForestClassifier())])
model_iris = clf_iris.fit(x_train, y_train)
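Before generating counterfactuals, it can be useful to confirm the pipeline actually fits the data; a quick sanity check using sklearn's standard score method (mean accuracy for classifiers):

# Optional sanity check: held-out accuracy of the fitted pipeline
print(model_iris.score(x_test, y_test))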
[8]:
d_iris = dice_ml.Data(dataframe=df_iris,
                      continuous_features=continuous_features_iris,
                      outcome_name=outcome_name)
# We provide the type of model as a parameter (model_type)
m_iris = dice_ml.Model(model=model_iris, backend="sklearn", model_type='classifier')
[9]:
exp_genetic_iris = Dice(d_iris, m_iris, method="genetic")
As we can see below, all the generated counterfactuals lie in the desired class.
[10]:
# Single input
query_instances_iris = x_train[2:3]
genetic_iris = exp_genetic_iris.generate_counterfactuals(query_instances_iris, total_CFs=7, desired_class=2)
genetic_iris.visualize_as_dataframe()
Query instance (original outcome : 0)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 4.4 | 3.0 | 1.3 | 0.2 | 0 |
Diverse Counterfactual set (new outcome: 2)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 4.3 | 3.3 | 1.0 | 0.1 | 0 |
| 1 | 4.3 | 3.0 | 1.0 | 0.2 | 0 |
| 2 | 4.3 | 2.7 | 1.0 | 0.1 | 0 |
| 3 | 4.3 | 3.2 | 1.0 | 0.1 | 0 |
| 4 | 4.3 | 2.0 | 1.0 | 0.1 | 0 |
| 5 | 4.3 | 2.0 | 1.0 | 0.1 | 0 |
| 6 | 4.3 | 3.0 | 1.0 | 0.1 | 0 |
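If you want to consume the counterfactuals programmatically rather than visually, the returned CounterfactualExplanations object exposes them as DataFrames (a minimal sketch; cf_examples_list and final_cfs_df are attribute names from the dice-ml API):

# Access the generated CFs for the first query instance as a DataFrame
cfs_df = genetic_iris.cf_examples_list[0].final_cfs_df
print(cfs_df)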
[11]:
# Multiple queries can be given as input at once
query_instances_iris = x_train[17:19]
genetic_iris = exp_genetic_iris.generate_counterfactuals(query_instances_iris, total_CFs=7, desired_class=2)
genetic_iris.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 1)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.7 | 2.9 | 4.2 | 1.3 | 1 |
Diverse Counterfactual set (new outcome: 2)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 6.0 | - | 4.8 | 1.8 | 2.0 |
| 1 | 6.2 | - | 4.8 | 1.8 | 2.0 |
| 2 | 6.1 | - | 4.9 | 1.8 | 2.0 |
| 3 | - | - | 4.9 | 2.0 | 2.0 |
| 4 | 4.9 | 2.5 | 4.5 | 1.7 | 2.0 |
| 5 | 5.9 | - | 5.1 | 1.8 | 2.0 |
| 6 | 6.3 | 2.7 | 4.9 | 1.8 | 2.0 |
Query instance (original outcome : 1)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 6.3 | 3.3 | 4.7 | 1.6 | 1 |
Diverse Counterfactual set (new outcome: 2)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 6.1 | 3.0 | - | 1.8 | 2.0 |
| 1 | 6.0 | 3.0 | - | 1.8 | 2.0 |
| 2 | - | 2.8 | - | 1.8 | 2.0 |
| 3 | 6.5 | - | 5.1 | 2.0 | 2.0 |
| 4 | - | 2.8 | 5.1 | - | 2.0 |
| 5 | - | 2.7 | - | 1.8 | 2.0 |
| 6 | 5.9 | 3.0 | 5.1 | 1.8 | 2.0 |
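The search can also be constrained, for example to vary only a subset of features via the features_to_vary parameter of generate_counterfactuals (a sketch reusing the setup above; the choice of features to hold fixed is illustrative):

# Only allow the petal measurements to change
genetic_iris = exp_genetic_iris.generate_counterfactuals(
    query_instances_iris, total_CFs=7, desired_class=2,
    features_to_vary=['petal length (cm)', 'petal width (cm)'])
genetic_iris.visualize_as_dataframe(show_only_changes=True)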
Regression
For regression, we will use sklearn’s Boston housing dataset, which contains Boston house prices. More information at https://scikit-learn.org/stable/datasets/toy_dataset.html#boston-house-prices-dataset (note that load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so this section requires an older scikit-learn version).
[12]:
from sklearn.datasets import load_boston
df_boston = sklearn_to_df(load_boston())
df_boston.head()
[12]:
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.00632 | 18.0 | 2.31 | 0.0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1.0 | 296.0 | 15.3 | 396.90 | 4.98 | 24.0 |
| 1 | 0.02731 | 0.0 | 7.07 | 0.0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2.0 | 242.0 | 17.8 | 396.90 | 9.14 | 21.6 |
| 2 | 0.02729 | 0.0 | 7.07 | 0.0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2.0 | 242.0 | 17.8 | 392.83 | 4.03 | 34.7 |
| 3 | 0.03237 | 0.0 | 2.18 | 0.0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3.0 | 222.0 | 18.7 | 394.63 | 2.94 | 33.4 |
| 4 | 0.06905 | 0.0 | 2.18 | 0.0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3.0 | 222.0 | 18.7 | 396.90 | 5.33 | 36.2 |
[13]:
df_boston.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 506 entries, 0 to 505
Data columns (total 14 columns):
 #   Column   Non-Null Count  Dtype
---  ------   --------------  -----
 0   CRIM     506 non-null    float64
 1   ZN       506 non-null    float64
 2   INDUS    506 non-null    float64
 3   CHAS     506 non-null    float64
 4   NOX      506 non-null    float64
 5   RM       506 non-null    float64
 6   AGE      506 non-null    float64
 7   DIS      506 non-null    float64
 8   RAD      506 non-null    float64
 9   TAX      506 non-null    float64
 10  PTRATIO  506 non-null    float64
 11  B        506 non-null    float64
 12  LSTAT    506 non-null    float64
 13  target   506 non-null    float64
dtypes: float64(14)
memory usage: 55.5 KB
[14]:
continuous_features_boston = df_boston.drop(outcome_name, axis=1).columns.tolist()
target = df_boston[outcome_name]
[15]:
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
# Split data into train and test
datasetX = df_boston.drop(outcome_name, axis=1)
x_train, x_test, y_train, y_test = train_test_split(datasetX,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0)
categorical_features = x_train.columns.difference(continuous_features_boston)
# We create the preprocessing pipelines for both numeric and categorical data.
numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])
transformations = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, continuous_features_boston),
        ('cat', categorical_transformer, categorical_features)])
# Append regressor to preprocessing pipeline.
# Now we have a full prediction pipeline.
regr_boston = Pipeline(steps=[('preprocessor', transformations),
                              ('regressor', RandomForestRegressor())])
model_boston = regr_boston.fit(x_train, y_train)
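As with the classifier, a quick sanity check of the fitted regressor; for regressors, the pipeline's standard score method returns R² on the given data:

# Optional sanity check: R^2 of the fitted pipeline on the held-out test set
print(model_boston.score(x_test, y_test))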
[16]:
d_boston = dice_ml.Data(dataframe=df_boston,
                        continuous_features=continuous_features_boston,
                        outcome_name=outcome_name)
# We provide the type of model as a parameter (model_type)
m_boston = dice_ml.Model(model=model_boston, backend="sklearn", model_type='regressor')
[17]:
exp_genetic_boston = Dice(d_boston, m_boston, method="genetic")
As we can see below, all the generated counterfactuals’ target values lie in the desired range.
[18]:
# Single input
query_instances_boston = x_train[2:3]
genetic_boston = exp_genetic_boston.generate_counterfactuals(query_instances_boston,
                                                             total_CFs=2,
                                                             desired_range=[30, 45])
genetic_boston.visualize_as_dataframe(show_only_changes=True)
WARNING:root: MAD for feature ZN is 0, so replacing it with 1.0 to avoid error.
WARNING:root: MAD for feature CHAS is 0, so replacing it with 1.0 to avoid error.
Query instance (original outcome : 24)
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.11329 | 30.0 | 4.93 | 0.0 | 0.428 | 6.897 | 54.3 | 6.3361 | 6.0 | 300.0 | 16.6 | 391.25 | 11.38 | 24.014 |
Diverse Counterfactual set (new outcome: [30, 45])
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.01301 | 34.0 | 6.1 | - | 0.385 | 6.982 | 49.3 | 5.4917 | 7.0 | 329.0 | 15.5 | 390.4 | 4.86 | 33.047000885009766 |
| 1 | 0.10469 | 40.0 | 6.2 | 1.0 | 0.447 | 7.267 | 49.0 | 4.7872 | 4.0 | 254.0 | 17.6 | 389.2 | 6.05 | 33.95500183105469 |
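Beyond desired_range, you can also bound how far individual features may move using the permitted_range parameter of generate_counterfactuals (a sketch; the ranges below are illustrative, not tuned):

# Restrict CFs to plausible intervals for selected features
genetic_boston = exp_genetic_boston.generate_counterfactuals(
    query_instances_boston, total_CFs=2, desired_range=[30, 45],
    permitted_range={'RM': [6.0, 8.0], 'LSTAT': [2.0, 10.0]})
genetic_boston.visualize_as_dataframe(show_only_changes=True)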
[19]:
# Multiple queries can be given as input at once
query_instances_boston = x_train[17:19]
genetic_boston = exp_genetic_boston.generate_counterfactuals(query_instances_boston, total_CFs=4, desired_range=[40, 50])
genetic_boston.visualize_as_dataframe(show_only_changes=True)
WARNING:root: MAD for feature ZN is 0, so replacing it with 1.0 to avoid error.
WARNING:root: MAD for feature CHAS is 0, so replacing it with 1.0 to avoid error.
WARNING:root: MAD for feature ZN is 0, so replacing it with 1.0 to avoid error.
WARNING:root: MAD for feature CHAS is 0, so replacing it with 1.0 to avoid error.
Query instance (original outcome : 49)
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.01501 | 90.0 | 1.21 | 1.0 | 0.401 | 7.923 | 24.8 | 5.885 | 1.0 | 198.0 | 13.6 | 395.52 | 3.16 | 49.108 |
Diverse Counterfactual set (new outcome: [40, 50])
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | - | - | 1.2 | - | - | - | - | - | - | - | - | 395.5 | - | 49.108001708984375 |
| 1 | 0.02009 | 95.0 | 2.7 | 0.0 | 0.416 | 8.034 | 31.9 | 5.118 | 4.0 | 224.0 | 14.7 | 390.5 | 2.88 | 49.50899887084961 |
| 2 | 0.0351 | 95.0 | 2.7 | 0.0 | 0.416 | 7.853 | 33.2 | 5.118 | 4.0 | 224.0 | 14.7 | 392.8 | 3.81 | 48.40999984741211 |
| 3 | 0.01538 | - | 3.8 | 0.0 | 0.394 | 7.454 | 34.2 | 6.3361 | 3.0 | 244.0 | 15.9 | 386.3 | 3.11 | 43.03799819946289 |
Query instance (original outcome : 31)
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.06911 | 45.0 | 3.44 | 0.0 | 0.437 | 6.739 | 30.8 | 6.4798 | 5.0 | 398.0 | 15.2 | 389.71 | 4.69 | 30.827 |
Diverse Counterfactual set (new outcome: [40, 50])
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.02009 | 95.0 | 2.7 | - | 0.416 | 8.034 | 31.9 | 5.118 | 4.0 | 224.0 | 14.7 | 390.5 | 2.88 | 49.50899887084961 |
| 1 | 0.0351 | 95.0 | 2.7 | - | 0.416 | 7.853 | 33.2 | 5.118 | 4.0 | 224.0 | 14.7 | 392.8 | 3.81 | 48.40999984741211 |
| 2 | 0.06129 | 20.0 | 3.3 | 1.0 | 0.443 | 7.645 | 49.7 | 5.2119 | - | 216.0 | 14.9 | 377.1 | 3.01 | 45.68299865722656 |
| 3 | 0.03578 | 20.0 | 3.3 | - | 0.443 | 7.82 | 64.5 | 4.6947 | - | 216.0 | 14.9 | 387.3 | 3.76 | 45.4640007019043 |
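Finally, generated counterfactuals can be saved for downstream use; CounterfactualExplanations objects support JSON serialization via to_json/from_json (a minimal sketch; treat the exact JSON schema as version-dependent):

from dice_ml.counterfactual_explanations import CounterfactualExplanations

# Serialize all CF sets to a JSON string and restore them
json_str = genetic_boston.to_json()
restored = CounterfactualExplanations.from_json(json_str)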