in mico/utils/net_utils.py [0:0]
def cross_entropy_p_q(p, q):
    """Compute the cross entropy H(p, q) = -sum(p * log(q)), averaged over a batch.

    Each distribution is a 2-D tensor; the sum runs over its last two
    dimensions. Either argument may be a single 2-D distribution or a 3-D
    stack of them (leading dimension = batch). A single distribution is
    broadcast against the other argument's batch. The cross entropy is
    computed for each pair, and the mean over the batch is returned.

    Args:
        p: Tensor of shape ``(M, N)`` or ``(B, M, N)``.
        q: Tensor of shape ``(M, N)`` or ``(B, M, N)``. Expected to be a
            floating-point tensor (``torch.finfo`` is queried on its dtype).

    Returns:
        A 0-dim tensor holding the mean cross entropy.
    """
    # Promote a lone 2-D distribution to a batch of one. Broadcasting then
    # lines up the batch axis without copying q once per batch element.
    if p.dim() == 2:
        p = p.unsqueeze(0)
    if q.dim() == 2:
        q = q.unsqueeze(0)
    # Clamp q to the smallest positive normal value of its dtype before the
    # log. With a plain log, an entry where p == 0 and q == 0 evaluates
    # 0 * -inf = NaN and poisons the loss and its gradients. Only zeros and
    # subnormal values of q are affected by the clamp.
    tiny = torch.finfo(q.dtype).tiny
    log_q = torch.log(q.clamp(min=tiny))
    return (-(p * log_q).sum(dim=(1, 2))).mean()