in causalml/inference/tree/plot.py [0:0]
def get_fill_color(self, tree: _tree.Tree, node_id: int) -> str:
    """
    Fetch appropriate color for node

    On the first call this lazily initialises ``self.colors["rgb"]`` (a
    palette from ``_color_brew`` sized by ``tree.n_classes[0]``) and, where
    applicable, ``self.colors["bounds"]`` (the value range used to scale
    color intensity). Later calls reuse the cached palette and bounds.

    Node value passed to ``self.get_color``:
        - multi-output trees: the negated node impurity;
        - regression (single class): the node's raw value;
        - classification: the node's value divided by its weighted sample
          count.

    Args:
        tree: Tree class
        node_id: int, node index
    Returns: str, the color produced by ``self.get_color`` for this node
    """
    if "rgb" not in self.colors:
        # Initialize colors and bounds if required
        self.colors["rgb"] = _color_brew(tree.n_classes[0])
        if tree.n_outputs != 1:
            # Find max and min impurities for multi-output
            self.colors["bounds"] = (
                np.nanmin(-tree.impurity),
                np.nanmax(-tree.impurity),
            )
        elif tree.n_classes[0] == 1:
            # Find max and min values in nodes for regression, ignoring NaNs.
            # Bounds are only set for a non-degenerate range: a zero-width
            # range (e.g. a single distinct finite value mixed with NaNs)
            # would make intensity scaling divide by zero.
            values = np.asarray(tree.value)
            values = values[~np.isnan(values)]
            if values.size and values.min() < values.max():
                self.colors["bounds"] = (values.min(), values.max())

    if tree.n_outputs != 1:
        # If multi-output color node by impurity
        return self.get_color(-tree.impurity[node_id])

    if tree.n_classes[0] == 1:
        # Regression: color by the node's raw value
        node_val = tree.value[node_id][0, :]
        if self.colors.get("bounds") is not None and np.size(node_val) == 1:
            # When bounds are set the value is scaled as a number; pass a
            # plain scalar rather than a 1-element array so the conversion
            # does not rely on NumPy's deprecated array-to-scalar path
            # (scikit-learn's get_color does int(round(...)) on it --
            # confirm if this class overrides get_color).
            node_val = node_val.item()
    else:
        # Classification: normalise the node value by its weighted sample count
        node_val = tree.value[node_id][0, :] / tree.weighted_n_node_samples[node_id]
    return self.get_color(node_val)