Module imodels.util.tree
Expand source code
from sklearn import datasets
from sklearn.tree import DecisionTreeRegressor
def compute_tree_complexity(tree, complexity_measure='num_rules'):
"""Calculate number of non-leaf nodes
"""
children_left = tree.children_left
children_right = tree.children_right
# num_split_nodes = 0
complexity = 0
stack = [(0, 0)] # start with the root node id (0) and its depth (0)
while len(stack) > 0:
# `pop` ensures each node is only visited once
node_id, depth = stack.pop()
# If the left and right child of a node is not the same we have a split
# node
is_split_node = children_left[node_id] != children_right[node_id]
# If a split node, append left and right children and depth to `stack`
# so we can loop through them
if is_split_node:
if complexity_measure == 'num_rules':
complexity += 1
stack.append((children_left[node_id], depth + 1))
stack.append((children_right[node_id], depth + 1))
else:
if complexity_measure != 'num_rules':
complexity += 1
return complexity
if __name__ == '__main__':
X, y = datasets.fetch_california_housing(return_X_y=True) # regression
m = DecisionTreeRegressor(random_state=42, max_leaf_nodes=4)
m.fit(X, y)
print(compute_tree_complexity(m.tree_, complexity_measure='num_leaves'))
Functions
def compute_tree_complexity(tree, complexity_measure='num_rules')
-
Calculate number of non-leaf nodes
Expand source code
def compute_tree_complexity(tree, complexity_measure='num_rules'): """Calculate number of non-leaf nodes """ children_left = tree.children_left children_right = tree.children_right # num_split_nodes = 0 complexity = 0 stack = [(0, 0)] # start with the root node id (0) and its depth (0) while len(stack) > 0: # `pop` ensures each node is only visited once node_id, depth = stack.pop() # If the left and right child of a node is not the same we have a split # node is_split_node = children_left[node_id] != children_right[node_id] # If a split node, append left and right children and depth to `stack` # so we can loop through them if is_split_node: if complexity_measure == 'num_rules': complexity += 1 stack.append((children_left[node_id], depth + 1)) stack.append((children_right[node_id], depth + 1)) else: if complexity_measure != 'num_rules': complexity += 1 return complexity