def get_norm_cat_entropy()

in lib/minecraft_util.py [0:0]


def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]:
    """Accumulate normalized categorical entropy and counts over an action-head tree.

    Sub-heads are visited in ``module.items()`` order:

    * ``DictActionHead`` sub-heads are handled by recursing into this function
      with ``logits[name]``.
    * ``CategoricalActionHead`` sub-heads are delegated to
      ``get_norm_entropy_from_cat_head(head, name, masks, logits[name])``.
    * Any other sub-head is skipped and contributes nothing.

    Args:
        module: Container of sub-heads; only its ``.items()`` is used.
        masks: Forwarded unchanged to the recursive call and to
            ``get_norm_entropy_from_cat_head``.
        logits: Indexed with each handled sub-head's key; the selected entry is
            forwarded to that sub-head's handler.
        template: Tensor the two accumulators are created from via
            ``torch.zeros_like`` (same shape/device, overridden dtype).

    Returns:
        ``(entropy_sum, counts)``: a ``torch.float`` tensor and a ``torch.int``
        tensor shaped like ``template``, each summed over every handled sub-head.
    """
    total_entropy = torch.zeros_like(template, dtype=torch.float)
    total_count = torch.zeros_like(template, dtype=torch.int)

    for name, head in module.items():
        if isinstance(head, DictActionHead):
            # Nested dictionary head: its children are summed by the recursive call.
            child = get_norm_cat_entropy(head, masks, logits[name], template)
        elif isinstance(head, CategoricalActionHead):
            child = get_norm_entropy_from_cat_head(head, name, masks, logits[name])
        else:
            continue

        child_entropy, child_count = child
        # In-place adds keep the accumulators' dtype (float / int) and shape fixed
        # rather than letting type promotion or broadcasting change them.
        total_entropy += child_entropy
        total_count += child_count

    return total_entropy, total_count