in lib/minecraft_util.py [0:0]
import torch


def get_norm_entropy_from_cat_head(module, name, masks, logits):
    # Note that the mask has already been applied to the logits at this point,
    # so `logits` are log-probabilities and this is the Shannon entropy of
    # each categorical distribution.
    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
if name in masks:
n = torch.sum(masks[name], dim=-1, dtype=torch.float)
norm_entropy = entropy / torch.log(n)
        # When the mask only allows one option, the normalized entropy makes no
        # sense, as it is simultaneously maximal (the distribution is as uniform
        # as it can be) and minimal (there is no variance at all).
        # As such, we ignore those entries when calculating entropy.
zero = torch.zeros_like(norm_entropy)
norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy)
count = n.not_equal(1.0).int()
else:
n = torch.tensor(logits.shape[-1], dtype=torch.float)
norm_entropy = entropy / torch.log(n)
count = torch.ones_like(norm_entropy, dtype=torch.int)
    # entropy is per-entry, still of size module.output_shape[:-1]; we need to
    # reduce over the rest of it.
for _ in module.output_shape[:-1]:
norm_entropy = norm_entropy.sum(dim=-1)
count = count.sum(dim=-1)
return norm_entropy, count
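
A minimal usage sketch (hypothetical: the `DummyHead` module, the "place" head name, the shapes, and the -1e9 mask fill are illustrative assumptions, not taken from this repo):

import torch

# Hypothetical head: 2 independent slots with 4 options each; the function
# reduces over every dim of output_shape except the last.
class DummyHead:
    output_shape = (2, 4)

# Slot 1 only allows two options. Emulate "mask already applied to the logits"
# with a large negative fill before log_softmax (an assumed masking scheme,
# chosen so exp() underflows cleanly to 0 instead of producing NaNs).
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.float)
raw = torch.randn(2, 4).masked_fill(mask.eq(0), -1e9)
logits = torch.log_softmax(raw, dim=-1)  # valid log-probabilities

norm_entropy, count = get_norm_entropy_from_cat_head(
    DummyHead(), "place", {"place": mask}, logits
)
print(norm_entropy.item(), count.item())  # summed normalized entropy, entries counted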