def cross_entropy_p_q()

in mico/utils/net_utils.py [0:0]


def cross_entropy_p_q(p, q):
    """Compute the cross entropy H(p, q) = -sum(p * log(q)), averaged over distributions.

    Each slice of the (broadcast) product along dim 0 is treated as one
    distribution whose support spans dims 1 and 2. The cross entropy of each
    slice is summed over those two dims, and the results are averaged over
    dim 0. If `q` is 2-D, it is shared by every distribution in `p`. If it is
    3-D, it is paired with `p` slice by slice, or broadcast when its leading
    dim is 1.

    Args:
        p: Tensor of target probabilities. It is presumably 3-D (one
            distribution per leading index); a 2-D `p` is treated as a single
            distribution. TODO confirm with callers.
        q: Tensor of predicted probabilities, either 2-D (shared across the
            batch) or 3-D.

    Returns:
        A 0-dim tensor holding the mean cross entropy.

    Notes:
        This uses the convention 0 * log(0) = 0 (via ``torch.xlogy``). Entries
        where `p` is zero contribute nothing, even when `q` is also zero,
        instead of turning the whole loss into NaN. Entries with p > 0 and
        q == 0 still yield +inf, which is the mathematically correct value.
    """
    if q.dim() == 2:
        # Add a singleton leading dim so broadcasting shares q across the
        # batch. This avoids materializing p.size(0) copies with `repeat`
        # (and taking the log of each copy), with identical results.
        q = q.unsqueeze(0)
    # xlogy(p, q) == p * log(q), except it returns 0 where p == 0.
    return (-torch.xlogy(p, q).sum(dim=(1, 2))).mean()