in lib/minecraft_util.py [0:0]
def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]:
    """Accumulate per-head entropy and count tensors over the categorical heads of ``module``.

    Iterates ``module.items()``:

    * A ``DictActionHead`` is handled by recursing into it with ``logits[name]``.
    * A ``CategoricalActionHead`` is delegated to ``get_norm_entropy_from_cat_head``.
    * Any other head type is skipped, and its ``logits`` entry is never read.

    Args:
        module: Mapping-like container of action heads; must provide ``.items()``.
        masks: Passed through unchanged to every recursive call and helper call.
        logits: Indexed with the same keys as ``module``.
        template: Only used via ``torch.zeros_like`` to give the accumulators their
            shape/device.

    Returns:
        ``(entropy_sum, counts)``. Both are shaped like ``template``:
        ``entropy_sum`` is ``torch.float`` and ``counts`` is ``torch.int``.
        Each is the in-place sum of the per-head results.

    NOTE(review): the "normalized" semantics live in
    ``get_norm_entropy_from_cat_head``, which is not visible here.
    """
    total_entropy = torch.zeros_like(template, dtype=torch.float)
    total_count = torch.zeros_like(template, dtype=torch.int)

    for name, head in module.items():
        if not isinstance(head, (DictActionHead, CategoricalActionHead)):
            # Heads of any other type contribute nothing; their logits are not touched.
            continue
        head_logits = logits[name]
        if isinstance(head, DictActionHead):
            # Nested group of heads: recurse with the same masks and the sub-logits.
            head_entropy, head_count = get_norm_cat_entropy(head, masks, head_logits, template)
        else:
            head_entropy, head_count = get_norm_entropy_from_cat_head(head, name, masks, head_logits)
        # In-place accumulation (equivalent to ``+=``), so the accumulator dtypes are kept.
        total_entropy.add_(head_entropy)
        total_count.add_(head_count)

    return total_entropy, total_count