def get_diag_guassian_entropy()

in lib/minecraft_util.py [0:0]


def get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]:
    """Average the entropy of the diagonal-Gaussian sub-heads found in ``module``.

    ``module.items()`` is walked once. Entries that are ``DictActionHead``
    instances are handled by recursing into them with their own slice of the
    logits; entries that are ``DiagGaussianActionHead`` instances contribute
    their own entropy. Every other entry is skipped and not counted.

    Args:
        module: Container exposing ``items()`` that yields ``(key, subhead)``
            pairs.
        logits: Object indexable by the same keys as ``module``; ``logits[k]``
            is passed to the sub-head stored under ``k``.
        template: Tensor whose shape and device are used to allocate the
            accumulator (the accumulator itself is always float).

    Returns:
        The accumulated entropy divided by the number of counted entries.
        NOTE(review): when no entry is counted the divisor is zero, so the
        result is a 0/0 division (NaN-filled tensor) rather than ``None``,
        despite the ``Optional`` annotation -- confirm what callers expect.
    """
    entropy_sum = torch.zeros_like(template, dtype=torch.float)
    # One-element counter kept on the same device as the accumulator.
    count = torch.zeros(1, device=template.device, dtype=torch.int)
    for k, subhead in module.items():
        if isinstance(subhead, DictActionHead):
            # A nested container counts as a single entry whose value is the
            # average over its own Gaussian sub-heads.
            entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template)
        elif isinstance(subhead, DiagGaussianActionHead):
            # Use this sub-head and its own logits. Calling the parent's
            # ``module.entropy(logits)`` here would add the entropy of the
            # whole container once per Gaussian sub-head.
            entropy_sum += subhead.entropy(logits[k])
        else:
            continue
        count += 1
    return entropy_sum / count