lib/minecraft_util.py:

import functools
import inspect
from typing import Optional, Tuple

import numpy as np
import torch

from lib.action_head import (CategoricalActionHead, DiagGaussianActionHead,
                             DictActionHead)


def store_args(method):
    """Stores provided method args as instance attributes."""
    argspec = inspect.getfullargspec(method)
    defaults = {}
    if argspec.defaults is not None:
        defaults = dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
    if argspec.kwonlydefaults is not None:
        defaults.update(argspec.kwonlydefaults)
    arg_names = argspec.args[1:]

    @functools.wraps(method)
    def wrapper(*positional_args, **keyword_args):
        self = positional_args[0]
        # Get default arg values
        args = defaults.copy()
        # Add provided arg values
        for name, value in zip(arg_names, positional_args[1:]):
            args[name] = value
        args.update(keyword_args)
        self.__dict__.update(args)
        return method(*positional_args, **keyword_args)

    return wrapper


def get_norm_entropy_from_cat_head(module, name, masks, logits):
    # Note that the mask has already been applied to the logits at this point
    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
    if name in masks:
        n = torch.sum(masks[name], dim=-1, dtype=torch.float)
        norm_entropy = entropy / torch.log(n)
        # When the mask only allows one option, the normalized entropy makes no sense:
        # it is simultaneously maximal (the distribution is as uniform as it can be)
        # and minimal (there is no variance at all).
        # As such, we ignore those entries when calculating entropy.
        zero = torch.zeros_like(norm_entropy)
        norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy)
        count = n.not_equal(1.0).int()
    else:
        n = torch.tensor(logits.shape[-1], dtype=torch.float)
        norm_entropy = entropy / torch.log(n)
        count = torch.ones_like(norm_entropy, dtype=torch.int)
    # Entropy is per-entry, still of size self.output_shape[:-1];
    # we need to reduce over the rest of it.
    for _ in module.output_shape[:-1]:
        norm_entropy = norm_entropy.sum(dim=-1)
        count = count.sum(dim=-1)
    return norm_entropy, count


def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]:
    entropy_sum = torch.zeros_like(template, dtype=torch.float)
    counts = torch.zeros_like(template, dtype=torch.int)
    for k, subhead in module.items():
        if isinstance(subhead, DictActionHead):
            entropy, count = get_norm_cat_entropy(subhead, masks, logits[k], template)
        elif isinstance(subhead, CategoricalActionHead):
            entropy, count = get_norm_entropy_from_cat_head(subhead, k, masks, logits[k])
        else:
            continue
        entropy_sum += entropy
        counts += count
    return entropy_sum, counts


def get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]:
    entropy_sum = torch.zeros_like(template, dtype=torch.float)
    count = torch.zeros(1, device=template.device, dtype=torch.int)
    for k, subhead in module.items():
        if isinstance(subhead, DictActionHead):
            entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template)
        elif isinstance(subhead, DiagGaussianActionHead):
            # Compute this Gaussian subhead's entropy on its own slice of the
            # logits dict. (The original accumulated module.entropy(logits),
            # which re-adds the entropy of the entire dict once per Gaussian
            # subhead rather than just this subhead's contribution.)
            entropy_sum += subhead.entropy(logits[k])
        else:
            continue
        count += 1
    return entropy_sum / count
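

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hypothetical example of what store_args buys you: every
# constructor argument, including defaults and keyword-only arguments, lands
# on the instance without explicit self.x = x assignments. The _Agent class
# below is illustrative only.
if __name__ == "__main__":
    class _Agent:
        @store_args
        def __init__(self, lr, gamma=0.99, *, device="cpu"):
            pass

    agent = _Agent(1e-4)
    assert (agent.lr, agent.gamma, agent.device) == (1e-4, 0.99, "cpu")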
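
# --- Normalization sketch (not part of the original file) -------------------
# Standalone illustration of the normalization used in
# get_norm_entropy_from_cat_head: dividing categorical entropy by log(n)
# maps it into [0, 1], with 1 for a uniform distribution over n options.
# This mirrors the formula only; it does not construct the action-head
# modules and masks the helper itself expects.
if __name__ == "__main__":
    logits = torch.log_softmax(torch.zeros(4), dim=-1)  # uniform over 4 options
    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
    norm_entropy = entropy / torch.log(torch.tensor(4.0))
    assert torch.isclose(norm_entropy, torch.tensor(1.0))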