lib/action_head.py (157 lines of code) (raw):
import logging
from typing import Any, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from gym3.types import DictType, Discrete, Real, TensorType, ValType
# Value written into masked-out logits by CategoricalActionHead.forward before log_softmax.
# NOTE(review): the name suggests a finite stand-in for log(0) - confirm the choice of -100.
LOG0 = -100
def fan_in_linear(module: nn.Module, scale=1.0, bias=True):
    """
    Fan-in initialisation, applied in place.

    The weight is divided by its own L2 norm taken along dim 1 and multiplied by `scale`,
    so for a 2-D weight every row ends up with norm `scale`.

    :param module Module exposing a `weight` (and, when `bias` is true, a `bias`) tensor
    :param scale Target norm for each dim-1 slice of the weight
    :param bias When true, the module's bias is multiplied by zero in place
    """
    weights = module.weight
    slice_norms = weights.norm(dim=1, p=2, keepdim=True)
    weights.data *= scale / slice_norms
    if not bias:
        return
    module.bias.data *= 0
class ActionHead(nn.Module):
    """
    Abstract base class for action heads.

    A head maps network features to the parameters of a probability distribution over actions
    and knows how to score, sample from and compare such distributions. Every method here
    raises NotImplementedError; concrete heads override them.
    """

    def forward(self, input_data: torch.Tensor) -> Any:
        """
        Run the head on `input_data`.

        :returns pd_params - parameters describing the probability distribution
        """
        raise NotImplementedError()

    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
        """Logarithm of the probability of drawing `action_sample` from the distribution given by `pd_params`"""
        raise NotImplementedError()

    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
        """Entropy of the distribution given by `pd_params`"""
        raise NotImplementedError()

    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> Any:
        """
        Draw an action from the distribution given by `pd_params`.

        :param pd_params Parameters of a probability distribution
        :param deterministic Return the mode of the distribution instead of a stochastic sample
        """
        raise NotImplementedError()

    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
        """KL divergence between the two distributions described by `params_q` and `params_p`"""
        raise NotImplementedError()
class DiagGaussianActionHead(ActionHead):
    """
    Action head where actions are normally distributed uncorrelated variables.

    Means are computed from the input by a linear layer, while the log standard deviations are
    a learned parameter of this module (one value per action dimension, independent of the input).

    Distribution parameters are exchanged as one tensor whose last axis has size 2:
    index 0 along that axis holds the means, index 1 the log standard deviations.
    """

    LOG2PI = np.log(2.0 * np.pi)

    def __init__(self, input_dim: int, num_dimensions: int):
        """
        :param input_dim Size of the last axis of the input features
        :param num_dimensions Number of independent Gaussian action dimensions
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_dimensions = num_dimensions
        self.linear_layer = nn.Linear(input_dim, num_dimensions)
        self.log_std = nn.Parameter(torch.zeros(num_dimensions), requires_grad=True)

    def reset_parameters(self):
        """Re-initialise the linear layer: orthogonal weights with gain 0.01 and zero bias."""
        init.orthogonal_(self.linear_layer.weight, gain=0.01)
        init.constant_(self.linear_layer.bias, 0.0)

    def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor:
        """
        Compute distribution parameters for `input_data`.

        :param input_data Features whose last axis is fed to the linear layer
        :param mask Must be None; masking is not supported by this head
        :returns Tensor of shape means.shape + (2,) holding means and log standard deviations
        :raises AssertionError if a mask is supplied
        """
        # Compare against None explicitly: `not mask` would evaluate the truthiness of a tensor,
        # which raises RuntimeError for multi-element masks and silently accepts a
        # single-element falsy mask.
        assert mask is None, "Can not use a mask in a gaussian action head"
        means = self.linear_layer(input_data)
        # Unsqueeze many times to get to the same shape
        logstd = self.log_std[(None,) * (len(means.shape) - 1)]
        mean_view, logstd = torch.broadcast_tensors(means, logstd)
        return torch.stack([mean_view, logstd], dim=-1)

    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
        """
        Log-likelihood of `action_sample`, summed over the last (action dimension) axis.

        Per dimension: -0.5 * (z^2 + log(2 pi)) - log(sigma), with z = (a - mu) / sigma.
        """
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]
        std = torch.exp(log_std)
        z_score = (action_sample - means) / std
        return -(0.5 * ((z_score ** 2 + self.LOG2PI).sum(dim=-1)) + log_std.sum(dim=-1))

    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
        """
        Entropy of the diagonal Gaussian, summed over the last (action dimension) axis.

        Per dimension: 1/2 log(2 pi e sigma^2) = log(sigma) + 1/2 (log(2 pi) + 1)
        """
        log_std = pd_params[..., 1]
        return (log_std + 0.5 * (self.LOG2PI + 1)).sum(dim=-1)

    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """
        Draw an action from the distribution given by `pd_params`.

        :param pd_params Parameters as returned by `forward`
        :param deterministic Return the means instead of a random draw
        """
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]
        if deterministic:
            return means
        else:
            return torch.randn_like(means) * torch.exp(log_std) + means

    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
        """
        KL(Q || P) between two diagonal Gaussians.

        Per dimension the formula is:
        log(sigma_p) - log(sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2

        The per-dimension terms are summed over the last axis with keepdim=True, so the result
        keeps a trailing axis of size 1 (unlike `logprob` and `entropy`).
        """
        means_q = params_q[..., 0]
        log_std_q = params_q[..., 1]
        means_p = params_p[..., 0]
        log_std_p = params_p[..., 1]
        std_q = torch.exp(log_std_q)
        std_p = torch.exp(log_std_p)
        kl_div = log_std_p - log_std_q + (std_q ** 2 + (means_q - means_p) ** 2) / (2.0 * std_p ** 2) - 0.5
        return kl_div.sum(dim=-1, keepdim=True)
class CategoricalActionHead(ActionHead):
    """
    Action head with categorical actions.

    The head produces, for every entry of `shape`, a log-probability vector over `num_actions`
    choices. Distribution parameters are therefore log-probabilities whose trailing axes are
    `shape + (num_actions,)`.
    """

    def __init__(
        self, input_dim: int, shape: Tuple[int, ...], num_actions: int, builtin_linear_layer: bool = True, temperature: float = 1.0
    ):
        """
        :param input_dim Size of the last axis of the input features
        :param shape Shape of the action tensor (one categorical choice per entry)
        :param num_actions Number of choices of each categorical variable
        :param builtin_linear_layer Whether to project the input with an internal linear layer;
            if false, the input is used directly as logits and input_dim must equal num_actions
        :param temperature Divisor applied to the logits before the softmax
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_actions = num_actions
        self.output_shape = tuple(shape) + (num_actions,)
        self.temperature = temperature
        if builtin_linear_layer:
            # np.prod returns a numpy integer; hand nn.Linear a plain Python int.
            self.linear_layer = nn.Linear(input_dim, int(np.prod(self.output_shape)))
        else:
            assert (
                input_dim == num_actions
            ), f"If input_dim ({input_dim}) != num_actions ({num_actions}), you need a linear layer to convert them."
            self.linear_layer = None

    def reset_parameters(self):
        """Re-initialise the internal linear layer, if there is one."""
        if self.linear_layer is not None:
            init.orthogonal_(self.linear_layer.weight, gain=0.01)
            init.constant_(self.linear_layer.bias, 0.0)
            # `fan_in_linear` is the module-level helper of this file; the previous
            # `finit.fan_in_linear` referenced a name that is never imported (NameError).
            fan_in_linear(self.linear_layer, scale=0.01)

    def forward(self, input_data: torch.Tensor, mask=None) -> Any:
        """
        Compute log-probabilities for `input_data`.

        :param input_data Features (or raw logits when there is no internal linear layer)
        :param mask Optional mask used to index the shaped logits; entries selected by `~mask`
            are set to LOG0 before normalisation (expected to be a boolean tensor)
        :returns float32 log-probabilities of shape input_data.shape[:-1] + self.output_shape
        """
        if self.linear_layer is not None:
            flat_out = self.linear_layer(input_data)
        else:
            flat_out = input_data
        shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape)
        # Divide out of place: without a linear layer `shaped_out` can be a view of the caller's
        # `input_data`, and an in-place division / masked write would silently modify it.
        shaped_out = shaped_out / self.temperature
        if mask is not None:
            # Safe to write in place here: `shaped_out` is the fresh tensor created above.
            shaped_out[~mask] = LOG0
        # Convert to float32 to avoid RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'
        return F.log_softmax(shaped_out.float(), dim=-1)

    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """
        Log-probability of `actions` under `logits`, summed over the axes of the action shape.

        :param actions Chosen action indices (cast to long), one per categorical variable
        :param logits Log-probabilities as returned by `forward`
        """
        value = actions.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        result = log_pmf.gather(-1, value).squeeze(-1)
        # result is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it.
        for _ in self.output_shape[:-1]:
            result = result.sum(dim=-1)
        return result

    def entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Categorical distribution entropy calculation: -sum probs * log(probs)"""
        probs = torch.exp(logits)
        entropy = -torch.sum(probs * logits, dim=-1)
        # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it.
        for _ in self.output_shape[:-1]:
            entropy = entropy.sum(dim=-1)
        return entropy

    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
        """
        Draw action indices from the distribution given by `logits`.

        :param logits Log-probabilities as returned by `forward`
        :param deterministic Return the argmax over the last axis instead of a random draw
        """
        if deterministic:
            return torch.argmax(logits, dim=-1)
        else:
            # Gumbel-max trick: argmax of (logits + Gumbel noise), with noise = -log(-log(u)).
            u = torch.rand_like(logits)
            # In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0
            # Then the log(-log(1.0)) will give -inf when it should give +inf
            # This is a silly hack to get around that.
            # This hack does not skew the probability distribution, because this event can't possibly win the argmax.
            u[u == 1.0] = 0.999
            return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1)

    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution KL divergence calculation
        KL(Q || P) = sum Q_i log (Q_i / P_i)
        When talking about logits this is:
        sum exp(Q_i) * (Q_i - P_i)

        The result keeps a trailing axis of size 1 (keepdim=True on the first reduction).
        """
        kl = (torch.exp(logits_q) * (logits_q - logits_p)).sum(-1, keepdim=True)
        # kl is per-entry, still of size self.output_shape; we need to reduce of the rest of it.
        for _ in self.output_shape[:-1]:
            kl = kl.sum(dim=-2)  # dim=-2 because we use keepdim=True above.
        return kl
class DictActionHead(nn.ModuleDict):
    """
    Action head made of several named sub-heads.

    Distribution parameters and actions are dicts keyed like this ModuleDict; scalar
    quantities (logprob, entropy, KL) are the sums of the sub-heads' values.
    """

    def reset_parameters(self):
        """Call `reset_parameters` on every sub-head."""
        for head in self.values():
            head.reset_parameters()

    def forward(self, input_data: torch.Tensor, **kwargs) -> Any:
        """
        Run every sub-head on `input_data` and collect the results in a dict.

        :param kwargs: each kwarg should be a dict with keys corresponding to self.keys()
            e.g. if this ModuleDict has submodules keyed by 'A', 'B', and 'C', we could call:
                forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7})
            Then children will be called with:
                A: forward(input_data, foo=True, bar=7)
                B: forward(input_data)
                C: forward(input_data, foo=False)
            A kwarg whose value is None is not forwarded to any sub-head.
        """
        outputs = {}
        for head_name, subhead in self.items():
            per_head_kwargs = {}
            for kwarg_name, per_head_values in kwargs.items():
                if per_head_values is None or head_name not in per_head_values:
                    continue
                per_head_kwargs[kwarg_name] = per_head_values[head_name]
            outputs[head_name] = subhead(input_data, **per_head_kwargs)
        return outputs

    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """Sum of the sub-heads' log-probabilities; `actions` and `logits` are indexed by head name."""
        total = 0
        for name, subhead in self.items():
            total = total + subhead.logprob(actions[name], logits[name])
        return total

    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
        """Sample from every sub-head; returns a dict of actions keyed by head name."""
        actions = {}
        for name, subhead in self.items():
            actions[name] = subhead.sample(logits[name], deterministic)
        return actions

    def entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Sum of the sub-heads' entropies; `logits` is indexed by head name."""
        total = 0
        for name, subhead in self.items():
            total = total + subhead.entropy(logits[name])
        return total

    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
        """Sum of the sub-heads' KL divergences; both arguments are indexed by head name."""
        total = 0
        for name, subhead in self.items():
            total = total + subhead.kl_divergence(logits_q[name], logits_p[name])
        return total
def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):
    """
    Build the action head matching an environment action space.

    - TensorType with Discrete elements -> CategoricalActionHead
    - TensorType with Real elements (one-dimensional shape only) -> DiagGaussianActionHead
    - DictType -> DictActionHead with one recursively built sub-head per entry

    :param ac_space Action space description
    :param pi_out_size Size of the policy features fed to the head(s)
    :param temperature Passed to categorical heads; only triggers a warning for Gaussian heads
    :raises NotImplementedError for any other action space
    """
    if isinstance(ac_space, TensorType):
        element_type = ac_space.eltype
        if isinstance(element_type, Discrete):
            return CategoricalActionHead(pi_out_size, ac_space.shape, element_type.n, temperature=temperature)
        if isinstance(element_type, Real):
            if temperature != 1.0:
                logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.")
            assert len(ac_space.shape) == 1, "Nontrivial shapes not yet implemented."
            return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])
    elif isinstance(ac_space, DictType):
        subheads = {}
        for key, subspace in ac_space.items():
            subheads[key] = make_action_head(subspace, pi_out_size, temperature)
        return DictActionHead(subheads)
    raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported")