def get_norm_entropy_from_cat_head()

in lib/minecraft_util.py [0:0]


import torch


def get_norm_entropy_from_cat_head(module, name, masks, logits):
    # Note that the mask has already been applied to the logits at this point,
    # and the logits are assumed to be log-probabilities (e.g. the output of a
    # log-softmax), so exp(logits) recovers the probabilities.
    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
    if name in masks:
        n = torch.sum(masks[name], dim=-1, dtype=torch.float)
        norm_entropy = entropy / torch.log(n)
        # When the mask only allows one option, the normalized entropy makes no
        # sense: it is simultaneously maximal (the distribution is as uniform as
        # it can be) and minimal (there is no variance at all). In addition,
        # log(1) = 0, so the division above is undefined for those entries.
        # As such, we ignore them for the purposes of calculating entropy.
        zero = torch.zeros_like(norm_entropy)
        norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy)
        count = n.not_equal(1.0).int()
    else:
        n = torch.tensor(logits.shape[-1], dtype=torch.float)
        norm_entropy = entropy / torch.log(n)
        count = torch.ones_like(norm_entropy, dtype=torch.int)

    # The entropy is still per-entry, of shape module.output_shape[:-1]; we need to reduce over the remaining dimensions.
    for _ in module.output_shape[:-1]:
        norm_entropy = norm_entropy.sum(dim=-1)
        count = count.sum(dim=-1)
    return norm_entropy, count
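
Below is a minimal usage sketch. The DummyHead class, the "buttons" key, and the concrete tensor shapes are illustrative assumptions, not part of the real model code; the only requirements the function places on its inputs are that logits are masked log-probabilities and that masks[name], when present, is a 0/1 tensor over the last (category) dimension.

import torch
import torch.nn.functional as F


class DummyHead:
    # Hypothetical stand-in for a categorical head; only output_shape is used here.
    output_shape = (2, 4)  # 2 independent slots, 4 categories each


raw = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 0, 0, 0]])          # the second slot allows only one option
masked = raw.masked_fill(mask == 0, float("-inf"))
log_probs = F.log_softmax(masked, dim=-1)    # mask applied before the call, as noted above

norm_entropy, count = get_norm_entropy_from_cat_head(
    DummyHead(), "buttons", {"buttons": mask}, log_probs
)
# norm_entropy sums only over slots with more than one allowed option, and count
# says how many such slots contributed, so norm_entropy / count gives the mean
# normalized entropy in [0, 1].
print(norm_entropy, count)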