"""
Manages the tree representation for the APC model. It includes the ``TreeRepresentation`` class, which handles the construction and manipulation of decision trees,
and the ``_TreeNode`` class, which represents a node in the tree.
This module is crucial for profiling aggregated data and extracting valuable insights.
"""
import json
from typing import Union, Any
from pandas import DataFrame, Series
import numpy as np
from MED3pa.models.concrete_regressors import DecisionTreeRegressorModel
from .profiles import Profile
def to_serializable(obj: Any, additional_arg: Any = None) -> Any:
    """Convert an object to a JSON-serializable format.

    Can be called directly or passed as the ``default`` hook of ``json.dump``.
    Containers (``dict`` and ``list``) are converted recursively, and
    ``additional_arg`` is forwarded to every nested conversion.

    Args:
        obj (Any): The object to convert.
        additional_arg (Any): Optional extra argument. When it is not None it is
            passed to ``Profile.to_dict``; otherwise ``Profile.to_dict`` is called
            without arguments. Defaults to None.

    Returns:
        Any: The JSON-serializable representation of the object. Objects of a type
        that is not handled here are returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # NumPy scalars (including booleans) are unwrapped to native Python scalars.
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, dict):
        return {key: to_serializable(value, additional_arg) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_serializable(item, additional_arg) for item in obj]
    if isinstance(obj, Profile):
        if additional_arg is not None:
            return obj.to_dict(additional_arg)
        return obj.to_dict()
    if isinstance(obj, _TreeNode):
        return obj.to_dict()
    return obj
class TreeRepresentation:
    """
    Represents the structure of a decision tree for a given set of features.

    Attributes:
        features (list): Feature names; split feature indices read from the fitted tree
            are looked up in this list.
        head: Root node of the tree representation, or None until it is assigned.
        nb_nodes (int): Number of nodes created so far by ``build_tree``.
    """

    def __init__(self, features: list) -> None:
        """
        Initializes the TreeRepresentation with a list of feature names.

        Args:
            features (List[str]): List of feature names used in the decision tree.
        """
        self.features = features
        self.head = None
        self.nb_nodes = 0

    def build_tree(self, dtr: 'DecisionTreeRegressorModel', X: DataFrame, y: Series, node_id: int = 0,
                   path: list = None) -> '_TreeNode':
        """
        Recursively builds the tree representation starting from the specified node.

        The structure (children, thresholds, split features, sample counts) is read from
        ``dtr.model.tree_``, while the node value and maximum are computed from the part
        of ``y`` that reaches the node. This method does not assign ``self.head`` and does
        not reset ``self.nb_nodes``; the counter is incremented once per created node and
        its current value is used as the node's identifier.

        Args:
            dtr (DecisionTreeRegressorModel): Trained decision tree regressor model.
            X (DataFrame): Training data observations.
            y (Series): Training data labels.
            node_id (int): Node ID to start building from. Defaults to 0.
            path (Optional[List[str]]): Path to the current node. Defaults to None,
                which is interpreted as ``['*']``.

        Returns:
            _TreeNode: The root node of the tree representation.
        """
        # A fresh list per call: a mutable default would be shared between calls.
        curr_path = ['*'] if path is None else list(path)

        self.nb_nodes += 1
        tree = dtr.model.tree_
        left_child = tree.children_left[node_id]
        right_child = tree.children_right[node_id]

        node_value = y.mean()
        node_max = y.max()
        # Share of the root's samples that reached this node, as a percentage.
        node_samples_ratio = tree.n_node_samples[node_id] / tree.n_node_samples[0] * 100

        # If we are at a leaf
        if left_child == -1:
            return _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio,
                             node_id=self.nb_nodes, path=curr_path)

        node_thresh = tree.threshold[node_id]
        node_feature_id = tree.feature[node_id]
        node_feature = self.features[node_feature_id]

        # Both masks are computed explicitly (rather than negating one) so that rows
        # failing both comparisons, e.g. NaN, are excluded from both sides as before.
        left_mask = X[node_feature] <= node_thresh
        right_mask = X[node_feature] > node_thresh
        y_left = y[left_mask]
        y_right = y[right_mask]

        # Check if the split would result in an empty set, if so, stop the recursion
        if y_left.size == 0 or y_right.size == 0:
            print("split would results in an empty data section")
            return _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio,
                             node_id=self.nb_nodes, path=curr_path)

        # The node is created before recursing so that it keeps the current counter value as its ID.
        curr_node = _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio,
                              threshold=node_thresh, feature=node_feature, feature_id=node_feature_id,
                              node_id=self.nb_nodes, path=curr_path)

        # Update paths for child nodes
        left_path = curr_path + [f"{node_feature} <= {node_thresh}"]
        right_path = curr_path + [f"{node_feature} > {node_thresh}"]

        curr_node.c_left = self.build_tree(dtr, X=X.loc[left_mask], y=y_left,
                                           node_id=left_child, path=left_path)
        curr_node.c_right = self.build_tree(dtr, X=X.loc[right_mask], y=y_right,
                                            node_id=right_child, path=right_path)
        return curr_node

    def get_all_profiles(self, min_ca: float = 0, min_samples_ratio: float = 0) -> list:
        """
        Retrieves all profiles from the tree that meet the minimum criteria for value and sample ratio.

        Args:
            min_ca (float): Minimum value threshold for profiles. Defaults to 0.
            min_samples_ratio (float): Minimum sample ratio threshold for profiles. Defaults to 0.

        Returns:
            List[Profile]: A list of Profile instances meeting the specified criteria.

        Raises:
            ValueError: If ``head`` is None, i.e. the tree has not been built yet.
        """
        if self.head is None:
            raise ValueError("Tree has not been built yet.")
        return self.head.get_profile(min_samples_ratio=min_samples_ratio, min_ca=min_ca)

    def get_all_nodes(self) -> list:
        """
        Retrieves all nodes from the tree with their paths.

        Returns:
            List[dict]: A list of dictionaries representing nodes with their paths.

        Raises:
            ValueError: If ``head`` is None, i.e. the tree has not been built yet.
        """
        if self.head is None:
            raise ValueError("Tree has not been built yet.")
        return self.head.get_all_nodes()

    def save_tree(self, file_path: str) -> None:
        """
        Saves the tree structure to a JSON file.

        The dictionary produced by the root node's ``to_dict`` is extended with a
        ``'features'`` entry holding the feature list before being written.

        Args:
            file_path (str): The file path where the tree structure will be saved.

        Raises:
            ValueError: If ``head`` is None, i.e. the tree has not been built yet.
        """
        if self.head is None:
            raise ValueError("Tree has not been built yet.")
        tree_dict = self.head.to_dict()
        tree_dict['features'] = self.features
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(tree_dict, file, default=to_serializable, indent=4)
class _TreeNode:
"""
Represents a node in the tree structure.
"""
def __init__(self, value: float =None, value_max: float=None, samples_ratio: float=None, threshold: float = None,
feature: str = None, feature_id: int = None, node_id: int = 0, path: list = None) -> None:
"""
Initializes a _TreeNode object.
Args:
value (float): The average value at the node.
value_max (float): The maximum value at the node.
samples_ratio (float): The percentage of total samples present at the node.
threshold (Optional[float]): The threshold used for splitting at this node. Defaults to None.
feature (Optional[str]): The feature used for splitting at this node. Defaults to None.
feature_id (Optional[int]): The identifier of the feature used for splitting. Defaults to None.
node_id (int): Unique identifier for the node. Defaults to 0.
path (Optional[List[str]]): The path from the root to this node. Defaults to an empty list.
"""
self.c_left = None
self.c_right = None
self.value = value
self.value_max = value_max
self.samples_ratio = samples_ratio
self.threshold = threshold
self.feature = feature
self.feature_id = feature_id
self.node_id = node_id
self.path = path if path is not None else []
def assign_node(self, X: Union[DataFrame, Series]) -> float:
"""
Assigns a value to a node based on input observations, navigating the tree until a leaf node is reached.
Args:
X (Union[DataFrame, Series]): Input observations used to navigate and determine the value at a node.
min_samples_ratio (float): The minimum sample ratio to consider a node as valid for value assignment. Nodes
with a sample ratio below this threshold will use the value from the nearest valid ancestor.
Defaults to 0, which considers all nodes valid regardless of sample ratio.
Returns:
float: The value assigned based on the input observations and the structure of the tree.
Raises:
TypeError: If the input X is neither a pandas DataFrame nor a pandas Series.
"""
# Check if the current node is a leaf node
if self.c_left is None and self.c_right is None:
return self.value
if isinstance(X, DataFrame):
X_value = X[self.feature].values[0]
elif isinstance(X, Series):
X_value = X[self.feature]
else:
raise TypeError(f"Parameter X is of type {type(X)}, but it must be of type 'pandas.DataFrame' or 'pandas.Series'.")
if X_value <= self.threshold: # If node split condition is true, then left children
c_node = self.c_left
else:
c_node = self.c_right
return c_node.assign_node(X)
def assign_node_deprecated(self, X: Union[DataFrame, Series], depth: int = None, min_samples_ratio: float = 0) -> float:
"""
Assigns a value to a node based on input observations, potentially navigating the tree up to a certain depth.
Args:
X (Union[DataFrame, Series]): Input observations used to navigate and determine the value at a node.
depth (Optional[int]): The maximum depth to navigate in the tree for value assignment. Defaults to None,
which means navigating until a leaf node is reached.
min_samples_ratio (float): The minimum sample ratio to consider a node as valid for value assignment. Nodes
with a sample ratio below this threshold will use the value from the nearest valid ancestor.
Defaults to 0, which considers all nodes valid regardless of sample ratio.
Returns:
float: The value assigned based on the input observations and the structure of the tree.
Raises:
TypeError: If the input X is neither a pandas DataFrame nor a pandas Series.
"""
if depth == 0 or self.c_left is None:
return self.value
if isinstance(X, DataFrame):
X_value = X[self.feature].values[0]
elif isinstance(X, Series):
X_value = X[self.feature]
else:
raise TypeError(f"Parameter X is of type {type(X)}, but it must be of type 'pandas.DataFrame' or 'pandas.Series'.")
if depth is not None:
depth -= 1
if X_value <= self.threshold: # If node split condition is true, then left children
c_node = self.c_left
else:
c_node = self.c_right
if c_node.samples_ratio < min_samples_ratio: # If not enough samples in child node
return self.value
return c_node.assign_node(X, depth, min_samples_ratio)
def get_profile(self, min_samples_ratio: float, min_ca: float) -> list:
"""
Retrieves profiles from the subtree rooted at this node that meet the specified criteria.
Args:
min_samples_ratio (float): The minimum sample ratio a node must have to be included in the output profiles.
min_ca (float): The minimum value a node must have to be included in the output profiles.
Returns:
List[Profile]: A list of Profile instances representing nodes that meet the criteria.
"""
profiles = []
if self.c_left is not None and self.c_left.samples_ratio >= min_samples_ratio:
# Recursively retrieve profiles from the left child
profiles.extend(self.c_left.get_profile(min_samples_ratio, min_ca))
if self.c_right is not None and self.c_right.samples_ratio >= min_samples_ratio:
# Recursively retrieve profiles from the right child
profiles.extend(self.c_right.get_profile(min_samples_ratio, min_ca))
# Check if the current node meets the criteria
if self.samples_ratio >= min_samples_ratio and self.value >= min_ca:
profile = Profile(node_id=self.node_id, path=self.path)
profiles.append(profile)
return profiles
def get_all_nodes(self) -> list:
"""
Retrieves all nodes in the subtree rooted at this node with their paths.
Returns:
List[dict]: A list of dictionaries representing nodes with their paths.
"""
nodes = [{
'node_id': self.node_id,
'path': self.path
}]
if self.c_left is not None:
nodes.extend(self.c_left.get_all_nodes())
if self.c_right is not None:
nodes.extend(self.c_right.get_all_nodes())
return nodes
def to_dict(self) -> dict:
"""
Converts the node and its children to a dictionary.
Returns:
dict: A dictionary representation of the node and its children.
"""
node_dict = {
'threshold': self.threshold,
'feature': self.feature,
'feature_id': self.feature_id,
'node_id': self.node_id,
'path': self.path
}
if self.c_left is not None:
node_dict['c_left'] = self.c_left.to_dict()
if self.c_right is not None:
node_dict['c_right'] = self.c_right.to_dict()
return node_dict
def print_tree(self, depth=0):
"""
Prints the tree structure.
"""
indent = " " * depth
if self.threshold is None:
print(f"{indent}Leaf node (ID: {self.node_id}, Value: {self.value:.4f}, Max Value: {self.value_max:.4f}, Samples Ratio: {self.samples_ratio:.2f}%)")
else:
print(f"{indent}Node (ID: {self.node_id}, Feature: {self.feature}, Threshold: {self.threshold:.4f}, Value: {self.value:.4f}, Max Value: {self.value_max:.4f}, Samples Ratio: {self.samples_ratio:.2f}%)")
if self.c_left:
print(f"{indent} Left:")
self.c_left.print_tree(depth + 1)
if self.c_right:
print(f"{indent} Right:")
self.c_right.print_tree(depth + 1)