in lib/minecraft_util.py [0:0]
def get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]:
    """Average the entropy of the diagonal-Gaussian heads found in ``module``.

    Walks ``module.items()``. A ``DictActionHead`` entry is handled by
    recursing into it with ``logits[k]``; a ``DiagGaussianActionHead`` entry
    contributes ``subhead.entropy(logits[k])``. Entries of any other type are
    skipped and do not count towards the average.

    Args:
        module: Mapping-like container of action heads (must provide
            ``items()``).
        logits: Per-head logits, indexed by the same keys as ``module``.
        template: Tensor whose shape and device are used for the accumulator
            (via ``torch.zeros_like``) and the counter.

    Returns:
        The accumulated entropy divided by the number of contributing entries.
        Each nested ``DictActionHead`` counts as a single entry, whatever it
        contains.

    Note:
        If no entry contributes, the counter stays at zero and the division
        yields NaN rather than ``None``. A nested ``DictActionHead`` without
        any Gaussian head therefore propagates NaN into the result.
        NOTE(review): the ``Optional`` return annotation suggests ``None`` may
        have been intended for that case -- confirm against callers.
    """
    entropy_sum = torch.zeros_like(template, dtype=torch.float)
    count = torch.zeros(1, device=template.device, dtype=torch.int)
    for k, subhead in module.items():
        if isinstance(subhead, DictActionHead):
            entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template)
        elif isinstance(subhead, DiagGaussianActionHead):
            # Query the Gaussian sub-head itself with its own logits slice.
            # Asking the parent container for the entropy of the full logits
            # would ignore which head is being visited here.
            entropy_sum += subhead.entropy(logits[k])
        else:
            continue
        count += 1
    return entropy_sum / count