# lib/action_head.py

import logging
from typing import Any, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from gym3.types import DictType, Discrete, Real, TensorType, ValType

# Log-probability assigned to masked-out (invalid) categorical actions.
LOG0 = -100


def fan_in_linear(module: nn.Module, scale=1.0, bias=True):
    """Fan-in init: rescale each output row of the weight to L2-norm `scale`; zero the bias."""
    module.weight.data *= scale / module.weight.norm(dim=1, p=2, keepdim=True)
    if bias:
        module.bias.data *= 0


class ActionHead(nn.Module):
    """Abstract base class for action heads compatible with forc"""

    def forward(self, input_data: torch.Tensor) -> Any:
        """
        Just a forward pass through this head
        :returns pd_params - parameters describing the probability distribution
        """
        raise NotImplementedError

    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
        """Logarithm of probability of sampling `action_sample` from a probability described by `pd_params`"""
        raise NotImplementedError

    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
        """Entropy of this distribution"""
        raise NotImplementedError

    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> Any:
        """
        Draw a sample from probability distribution given by those params

        :param pd_params Parameters of a probability distribution
        :param deterministic Whether to return a stochastic sample or deterministic mode of a distribution
        """
        raise NotImplementedError

    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
        """KL divergence between two distribution described by these two params"""
        raise NotImplementedError


class DiagGaussianActionHead(ActionHead):
    """
    Action head where actions are normally distributed uncorrelated variables with specific means and variances.

    Means are calculated directly from the network while standard deviations are a parameter of this module
    """

    LOG2PI = np.log(2.0 * np.pi)

    def __init__(self, input_dim: int, num_dimensions: int):
        super().__init__()

        self.input_dim = input_dim
        self.num_dimensions = num_dimensions

        self.linear_layer = nn.Linear(input_dim, num_dimensions)
        # One learned log-std per action dimension, shared across the batch.
        self.log_std = nn.Parameter(torch.zeros(num_dimensions), requires_grad=True)

    def reset_parameters(self):
        init.orthogonal_(self.linear_layer.weight, gain=0.01)
        init.constant_(self.linear_layer.bias, 0.0)

    def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor:
        assert not mask, "Can not use a mask in a gaussian action head"
        means = self.linear_layer(input_data)
        # Unsqueeze many times to get to the same shape as `means`.
        logstd = self.log_std[(None,) * (len(means.shape) - 1)]
        mean_view, logstd = torch.broadcast_tensors(means, logstd)
        # Returned pd_params have shape (..., num_dimensions, 2): [mean, log_std] stacked last.
        return torch.stack([mean_view, logstd], dim=-1)

    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
        """Log-likelihood"""
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]

        std = torch.exp(log_std)
        z_score = (action_sample - means) / std

        return -(0.5 * ((z_score ** 2 + self.LOG2PI).sum(dim=-1)) + log_std.sum(dim=-1))

    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution entropy calculation - sum probs * log(probs).
        In case of diagonal gaussian distribution - 1/2 log(2 pi e sigma^2)
        """
        log_std = pd_params[..., 1]
        return (log_std + 0.5 * (self.LOG2PI + 1)).sum(dim=-1)

    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]

        if deterministic:
            return means
        else:
            # Reparameterized sample: mean + std * N(0, 1).
            return torch.randn_like(means) * torch.exp(log_std) + means

    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution KL divergence calculation
        KL(Q || P) = sum Q_i log (Q_i / P_i)

        Formula is:
        log(sigma_p) - log(sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2))/(2 * sigma_p^2)
        """
        means_q = params_q[..., 0]
        log_std_q = params_q[..., 1]

        means_p = params_p[..., 0]
        log_std_p = params_p[..., 1]

        std_q = torch.exp(log_std_q)
        std_p = torch.exp(log_std_p)

        kl_div = log_std_p - log_std_q + (std_q ** 2 + (means_q - means_p) ** 2) / (2.0 * std_p ** 2) - 0.5

        return kl_div.sum(dim=-1, keepdim=True)


class CategoricalActionHead(ActionHead):
    """Action head with categorical actions"""

    def __init__(
        self, input_dim: int, shape: Tuple[int], num_actions: int, builtin_linear_layer: bool = True,
        temperature: float = 1.0
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_actions = num_actions
        self.output_shape = shape + (num_actions,)
        self.temperature = temperature

        if builtin_linear_layer:
            self.linear_layer = nn.Linear(input_dim, np.prod(self.output_shape))
        else:
            assert (
                input_dim == num_actions
            ), f"If input_dim ({input_dim}) != num_actions ({num_actions}), you need a linear layer to convert them."
            self.linear_layer = None

    def reset_parameters(self):
        if self.linear_layer is not None:
            init.orthogonal_(self.linear_layer.weight, gain=0.01)
            init.constant_(self.linear_layer.bias, 0.0)
            # BUGFIX: original called `finit.fan_in_linear`, but `finit` is never
            # imported/defined in this module (NameError at runtime). Use the
            # module-level `fan_in_linear` helper instead.
            fan_in_linear(self.linear_layer, scale=0.01)

    def forward(self, input_data: torch.Tensor, mask=None) -> Any:
        if self.linear_layer is not None:
            flat_out = self.linear_layer(input_data)
        else:
            flat_out = input_data
        shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape)
        # BUGFIX: use out-of-place division. When there is no linear layer,
        # `reshape` may return a *view* of the caller's `input_data`, and the
        # original in-place `/=` would silently mutate the caller's tensor.
        shaped_out = shaped_out / self.temperature
        if mask is not None:
            shaped_out[~mask] = LOG0

        # Convert to float32 to avoid RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'
        return F.log_softmax(shaped_out.float(), dim=-1)

    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        value = actions.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        result = log_pmf.gather(-1, value).squeeze(-1)
        # result is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.
        for _ in self.output_shape[:-1]:
            result = result.sum(dim=-1)
        return result

    def entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Categorical distribution entropy calculation - sum probs * log(probs)"""
        probs = torch.exp(logits)
        entropy = -torch.sum(probs * logits, dim=-1)
        # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.
        for _ in self.output_shape[:-1]:
            entropy = entropy.sum(dim=-1)
        return entropy

    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
        if deterministic:
            return torch.argmax(logits, dim=-1)
        else:
            # Gumbel-Softmax trick.
            u = torch.rand_like(logits)
            # In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0
            # Then the log(-log(1.0)) will give -inf when it should give +inf
            # This is a silly hack to get around that.
            # This hack does not skew the probability distribution, because this event can't possibly win the argmax.
            u[u == 1.0] = 0.999
            return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1)

    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution KL divergence calculation
        KL(Q || P) = sum Q_i log (Q_i / P_i)
        When talking about logits this is:
        sum exp(Q_i) * (Q_i - P_i)
        """
        kl = (torch.exp(logits_q) * (logits_q - logits_p)).sum(-1, keepdim=True)
        # kl is per-entry, still of size self.output_shape; we need to reduce over the rest of it.
        for _ in self.output_shape[:-1]:
            kl = kl.sum(dim=-2)  # dim=-2 because we use keepdim=True above.
        return kl


class DictActionHead(nn.ModuleDict):
    """Action head with multiple sub-actions"""

    def reset_parameters(self):
        for subhead in self.values():
            subhead.reset_parameters()

    def forward(self, input_data: torch.Tensor, **kwargs) -> Any:
        """
        :param kwargs: each kwarg should be a dict with keys corresponding to self.keys()
            e.g. if this ModuleDict has submodules keyed by 'A', 'B', and 'C', we could call:
                forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7})
            Then children will be called with:
                A: forward(input_data, foo=True, bar=7)
                B: forward(input_data)
                C: forward(input_data, foo=False)
        """
        result = {}
        for head_name, subhead in self.items():
            head_kwargs = {
                kwarg_name: kwarg[head_name]
                for kwarg_name, kwarg in kwargs.items()
                if kwarg is not None and head_name in kwarg
            }
            result[head_name] = subhead(input_data, **head_kwargs)
        return result

    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        return sum(subhead.logprob(actions[k], logits[k]) for k, subhead in self.items())

    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
        return {k: subhead.sample(logits[k], deterministic) for k, subhead in self.items()}

    def entropy(self, logits: torch.Tensor) -> torch.Tensor:
        return sum(subhead.entropy(logits[k]) for k, subhead in self.items())

    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
        # BUGFIX: in the original mangled text the `sum(...)` expression was
        # separated from `return`, making it unreachable (the method returned None).
        return sum(subhead.kl_divergence(logits_q[k], logits_p[k]) for k, subhead in self.items())


def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):
    """Helper function to create an action head corresponding to the environment action space"""
    if isinstance(ac_space, TensorType):
        if isinstance(ac_space.eltype, Discrete):
            return CategoricalActionHead(pi_out_size, ac_space.shape, ac_space.eltype.n, temperature=temperature)
        elif isinstance(ac_space.eltype, Real):
            if temperature != 1.0:
                logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.")
            assert len(ac_space.shape) == 1, "Nontrivial shapes not yet implemented."
            return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])
    elif isinstance(ac_space, DictType):
        return DictActionHead({k: make_action_head(v, pi_out_size, temperature) for k, v in ac_space.items()})
    raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported")