in causalml/inference/tree/uplift.pyx [0:0]
def arr_normI(self, cur_node_summary_n, left_node_summary_n,
              alpha: cython.float = 0.9, currentDivergence: cython.float = 0.0) -> cython.float:
    '''
    Normalization factor.

    Combines the treatment/control shares of the current node with how the
    samples of each group are sent to the left child. In both count arrays,
    index 0 is the control group and indices 1..n_class-1 are the treatment
    groups.

    Args
    ----
    cur_node_summary_n : array of shape [n_class]
        Has type numpy.int32 and must be C-contiguous (it is bound to a
        ``[::1]`` memoryview).
        The counts of each of the control
        and treatment groups of the current node, i.e. [N(T=i)...]
    left_node_summary_n : array of shape [n_class]
        Has type numpy.int32 and must be C-contiguous.
        The counts of each of the control
        and treatment groups of the left node, i.e. [N(T=i)...]
    alpha : float
        The weight used to balance different normalization parts.
        Only used when the evaluation function is not IDDP.
    currentDivergence : float
        Divergence value supplied by the caller (presumably that of the
        current node -- confirm with the caller). Only used when the
        evaluation function is IDDP. There it takes the place of the
        ``kl_divergence(pt_a, pc_a)`` term in Part 1.

    Returns
    -------
    norm_res : float
        Normalization factor.
    '''
    cdef N_TYPE_t[::1] cur_summary_n = cur_node_summary_n
    cdef N_TYPE_t[::1] left_summary_n = left_node_summary_n
    cdef int n_class = cur_summary_n.shape[0]
    cdef int i = 0
    cdef P_TYPE_t norm_res = 0.0
    # Index 0 holds the control group.
    cdef P_TYPE_t n_c = cur_summary_n[0]
    cdef P_TYPE_t n_c_left = left_summary_n[0]
    cdef P_TYPE_t pt_a = 0.0, pt_a_i = 0.0, pc_a = 0.0, sum_n_t_left = 0.0, sum_n_t = 0.0
    cdef P_TYPE_t n_total = 0.0, share_t = 0.0, share_c = 0.0, n_arm = 0.0

    # Pool all treatment groups (indices 1..n_class-1).
    for i in range(1, n_class):
        sum_n_t_left += left_summary_n[i]
        sum_n_t += cur_summary_n[i]

    # Fractions of the pooled treatment / control samples that go to the left
    # child. The +0.1 keeps the ratio defined when a group has no samples.
    pt_a = sum_n_t_left / (sum_n_t + 0.1)
    pc_a = n_c_left / (n_c + 0.1)

    # Treatment and control shares of the current node. The denominator is
    # invariant across all parts below, so it is computed once.
    # NOTE(review): assumes the current node is non-empty. If
    # sum_n_t + n_c == 0 these divisions fail (ZeroDivisionError, or inf/nan
    # if cdivision is enabled for this module) -- confirm callers guarantee it.
    n_total = sum_n_t + n_c
    share_t = sum_n_t / n_total
    share_c = n_c / n_total

    if self.evaluationFunction == self.evaluate_IDDP:
        # Normalization Part 1 (IDDP): currentDivergence replaces the KL term.
        # Parts 2 & 3 are not applied on this path and alpha is unused.
        norm_res += (entropyH(share_t, share_c) * currentDivergence)
        norm_res += (share_t * entropyH(pt_a))
    else:
        # Normalization Part 1, weighted by alpha.
        # Operand order is kept as-is: alpha and the helper results are
        # float-typed, so regrouping would change rounding.
        norm_res += (alpha * entropyH(share_t, share_c) * kl_divergence(pt_a, pc_a))
        # Normalization Part 2 & 3: one pair of terms per treatment group.
        # Part 2 is weighted by (1 - alpha).
        for i in range(1, n_class):
            # Fraction of treatment group i sent left (+0.1 as above).
            pt_a_i = 1. * left_summary_n[i] / (cur_summary_n[i] + 0.1)
            # Size of the (treatment i, control) pair in the current node.
            n_arm = cur_summary_n[i] + n_c
            norm_res += ((1 - alpha) * entropyH(1. * cur_summary_n[i] / n_arm, n_c / n_arm) * kl_divergence(pt_a_i, pc_a))
            norm_res += (1. * cur_summary_n[i] / n_total * entropyH(pt_a_i))
    # Normalization Part 4: control-group term, applied on both paths.
    norm_res += share_c * entropyH(pc_a)
    # Normalization Part 5: constant offset of 0.5 (presumably keeps the
    # normalizer bounded away from zero -- confirm against the reference).
    norm_res += 0.5
    return norm_res