Source code for rocelib.recourse_methods.KDTreeNNCE

import pandas as pd
import torch
from sklearn.neighbors import KDTree

from rocelib.recourse_methods.RecourseGenerator import RecourseGenerator


class KDTreeNNCE(RecourseGenerator):
    """
    A recourse generator that uses a KD-Tree for nearest neighbour counterfactual explanations.

    Inherits from the RecourseGenerator class and implements _generation_method to find
    counterfactual explanations via a KD-Tree nearest-neighbour search.

    Attributes:
        _task (Task): The task to solve, inherited from RecourseGenerator.
        __customFunc (callable, optional): A custom distance function, inherited from RecourseGenerator.
    """

    def _generation_method(self, instance, gamma=0.1, column_name="target", neg_value=0,
                           **kwargs) -> pd.DataFrame:
        """
        Generates a counterfactual explanation using a KD-Tree nearest-neighbour search.

        @param instance: The instance for which to generate a counterfactual.
        @param gamma: The distance threshold for convergence. (Not used in this method)
        @param column_name: The name of the target column. (Not used in this method)
        @param neg_value: The value considered negative in the target variable.
        @param kwargs: Additional keyword arguments.
        @return: A DataFrame containing the nearest counterfactual explanation, or the original
                 instance if no instances with the desired label exist.
        """
        model = self.task.model

        # Convert the X values of the dataset to a tensor
        X_tensor = torch.tensor(self.task.training_data.X.values, dtype=torch.float32)

        # Get the model's predictions and threshold them to 0/1 labels
        model_labels = model.predict(X_tensor)
        model_labels = (model_labels >= 0.5).astype(int)

        # Determine the target (counterfactual) label
        y = neg_value
        nnce_y = 1 - y

        # Convert instance to a DataFrame if it is a Series
        if isinstance(instance, pd.Series):
            instance = instance.to_frame().T

        # Prepare the data
        preds = self.task.training_data.X.copy()
        preds["predicted"] = model_labels

        # Keep only the instances that have the desired counterfactual label
        positive_instances = preds[preds["predicted"] == nnce_y].drop(columns=["predicted"])

        # If there are no such instances, return the original instance unchanged
        if positive_instances.empty:
            return instance

        # Build the KD-Tree over the positively labelled instances
        kd_tree = KDTree(positive_instances.values)

        # Query the KD-Tree for the nearest neighbour of the query instance
        dist, idx = kd_tree.query(instance.values, k=1, return_distance=True)

        nearest_instance = positive_instances.iloc[idx[0]].copy()
        nearest_instance["predicted"] = nnce_y

        # Add the distance as a new column
        nearest_instance["Loss"] = dist[0]

        return nearest_instance
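
For context, the core idea can be reproduced outside rocelib with just scikit-learn and pandas: build a KD-Tree over the training points the model assigns the desired label, then return the nearest such point to the query instance. The sketch below is a hypothetical, self-contained illustration only; the threshold "model" and the synthetic data are assumptions, not part of the library.

# Illustrative sketch (not part of rocelib): KD-Tree nearest-neighbour counterfactuals
# on synthetic data, using a simple threshold rule in place of a trained classifier.
import numpy as np
import pandas as pd
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(100, 2)), columns=["f1", "f2"])

# Hypothetical "model": predict 1 when the feature sum is positive
labels = (X.sum(axis=1) > 0).astype(int)

# Query instance currently predicted as the negative class
query = pd.DataFrame([[-1.0, -0.5]], columns=["f1", "f2"])

# Keep only points predicted as the desired (positive) class and index them with a KD-Tree
positives = X[labels == 1]
tree = KDTree(positives.values)

# The nearest positively classified point is the counterfactual; dist is its L2 distance
dist, idx = tree.query(query.values, k=1, return_distance=True)
counterfactual = positives.iloc[idx[0]].copy()
counterfactual["Loss"] = dist[0]
print(counterfactual)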