in econml/_tree_exporter.py [0:0]
def node_replacement_text(self, tree, node_id, criterion):
    """Return the text describing the CATE estimate stored for one tree node.

    The text has the form::

        CATE mean<sep><mean values><sep>CATE std<sep><std values>

    where ``<sep>`` is ``self.characters[4]``.

    Scalar values are written as a single number. 1d values are written as
    one comma-separated line. 2d values are written as one comma-separated
    line per row, with rows separated by ``<sep>``.

    When the node entry contains ``'ci'``, every mean value is followed by
    its interval, written as `` (lower, upper)``. ``'ci'`` is a
    ``(lower, upper)`` pair whose elements are indexed like ``'mean'``. All
    numbers are rounded with ``np.around(..., self.precision)``.

    Parameters
    ----------
    tree : object
        Not used by this implementation (presumably required by the exporter
        hook this method overrides -- confirm against the base class).
    node_id : hashable
        Key into ``self.node_dict``. The entry must provide ``'mean'`` and
        ``'std'`` and may provide ``'ci'``.
    criterion : object
        Not used by this implementation.

    Returns
    -------
    str
        The formatted node text.

    Raises
    ------
    ValueError
        If the mean or the std has more than two dimensions.
    """
    node_info = self.node_dict[node_id]
    line_break = self.characters[4]

    def pick(values, index):
        # Successive indexing: equivalent to values[i, j] for ndarrays, but it
        # also accepts nested sequences (e.g. interval bounds given as lists).
        for position in index:
            values = values[position]
        return values

    def format_values(values, bounds=None):
        # Format a scalar, 1d or 2d `values` (optionally with (lower, upper)
        # bounds) without any trailing separator.
        shape = values.shape if hasattr(values, 'shape') else ()
        if len(shape) > 2:
            raise ValueError("can only handle up to 2d values")

        def entry(*index):
            # `index` is () for scalars, (i,) for 1d and (i, j) for 2d values.
            text = "{}".format(np.around(pick(values, index), self.precision))
            if bounds is not None:
                text += " ({}, {})".format(np.around(pick(bounds[0], index), self.precision),
                                           np.around(pick(bounds[1], index), self.precision))
            return text

        if len(shape) == 0:
            return entry()
        if len(shape) == 1:
            return ", ".join(entry(i) for i in range(shape[0]))
        return line_break.join(", ".join(entry(i, j) for j in range(shape[1]))
                               for i in range(shape[0]))

    mean_text = format_values(node_info['mean'], node_info.get('ci'))
    std_text = format_values(node_info['std'])
    # The mean section always ends with a separator so that "CATE std" starts
    # on its own line; the std section is left unterminated.
    return ("CATE mean" + line_break + mean_text + line_break
            + "CATE std" + line_break + std_text)