# lib/policy.py
from copy import deepcopy
from email import policy
from typing import Dict, Optional
import numpy as np
import torch as th
from gym3.types import DictType
from torch import nn
from torch.nn import functional as F
from lib.action_head import make_action_head
from lib.action_mapping import CameraHierarchicalMapping
from lib.impala_cnn import ImpalaCNN
from lib.normalize_ewma import NormalizeEwma
from lib.scaled_mse_head import ScaledMSEHead
from lib.tree_util import tree_map
from lib.util import FanInInitReLULayer, ResidualRecurrentBlocks
from lib.misc import transpose
class ImgPreprocessing(nn.Module):
    """Normalize incoming images.

    :param img_statistics: remote path to npz file with a mean and std image. If specified
        normalize images using this.
    :param scale_img: If true and img_statistics not specified, scale incoming images by 1/255.
    """

    def __init__(self, img_statistics: Optional[str] = None, scale_img: bool = True):
        super().__init__()
        self.img_mean = None
        if img_statistics is None:
            # No per-pixel statistics available: fall back to a scalar divisor.
            self.ob_scale = 255.0 if scale_img else 1.0
        else:
            stats = dict(**np.load(img_statistics))
            # Stored as frozen nn.Parameters so they travel with state_dict
            # but are never updated by the optimizer.
            self.img_mean = nn.Parameter(th.Tensor(stats["mean"]), requires_grad=False)
            self.img_std = nn.Parameter(th.Tensor(stats["std"]), requires_grad=False)

    def forward(self, img):
        x = img.to(dtype=th.float32)
        if self.img_mean is None:
            return x / self.ob_scale
        return (x - self.img_mean) / self.img_std
class ImgObsProcess(nn.Module):
    """ImpalaCNN followed by a linear layer.

    :param cnn_outsize: impala output dimension
    :param output_size: output size of the linear layer.
    :param dense_init_norm_kwargs: kwargs for linear FanInInitReLULayer
    :param init_norm_kwargs: kwargs for 2d and 3d conv FanInInitReLULayer
    """

    def __init__(
        self,
        cnn_outsize: int,
        output_size: int,
        dense_init_norm_kwargs: Optional[Dict] = None,
        init_norm_kwargs: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__()
        # BUGFIX: avoid mutable default arguments ({}): a shared dict default
        # would be aliased across all instances. None means "no overrides".
        dense_init_norm_kwargs = {} if dense_init_norm_kwargs is None else dense_init_norm_kwargs
        init_norm_kwargs = {} if init_norm_kwargs is None else init_norm_kwargs
        self.cnn = ImpalaCNN(
            outsize=cnn_outsize,
            init_norm_kwargs=init_norm_kwargs,
            dense_init_norm_kwargs=dense_init_norm_kwargs,
            **kwargs,
        )
        self.linear = FanInInitReLULayer(
            cnn_outsize,
            output_size,
            layer_type="linear",
            **dense_init_norm_kwargs,
        )

    def forward(self, img):
        """Encode an image batch with the CNN, then project to output_size."""
        return self.linear(self.cnn(img))
class MinecraftPolicy(nn.Module):
    """Image encoder + recurrent trunk producing policy/value latents.

    :param recurrence_type:
        None - No recurrence, adds no extra layers
        lstm - (Deprecated). Singular LSTM
        multi_layer_lstm - Multi-layer LSTM. Uses n_recurrence_layers to determine number of consecutive LSTMs
            Does NOT support ragged batching
        multi_masked_lstm - Multi-layer LSTM that supports ragged batching via the first vector. This model is slower
            Uses n_recurrence_layers to determine number of consecutive LSTMs
        transformer - Dense transformer
    :param init_norm_kwargs: kwargs for all FanInInitReLULayers.
    """

    def __init__(
        self,
        recurrence_type="lstm",
        impala_width=1,
        impala_chans=(16, 32, 32),
        obs_processing_width=256,
        hidsize=512,
        single_output=False,  # True if we don't need separate outputs for action/value outputs
        img_shape=None,
        scale_input_img=True,
        only_img_input=False,
        init_norm_kwargs=None,
        impala_kwargs=None,
        # Unused argument assumed by forc.
        input_shape=None,  # pylint: disable=unused-argument
        active_reward_monitors=None,
        img_statistics=None,
        first_conv_norm=False,
        diff_mlp_embedding=False,  # accepted but never read in this implementation
        attention_mask_style="clipped_causal",
        attention_heads=8,
        attention_memory_size=2048,
        use_pointwise_layer=True,
        pointwise_ratio=4,
        pointwise_use_activation=False,
        n_recurrence_layers=1,
        recurrence_is_residual=True,
        timesteps=None,
        use_pre_lstm_ln=True,  # Not needed for transformer
        **unused_kwargs,
    ):
        super().__init__()
        # NOTE(review): the default recurrence_type="lstm" is NOT in this list,
        # so constructing with no explicit recurrence_type fails this assert.
        # Left unchanged since all working callers must already pass a value.
        assert recurrence_type in [
            "multi_layer_lstm",
            "multi_layer_bilstm",
            "multi_masked_lstm",
            "transformer",
            "none",
        ]
        # BUGFIX: avoid mutable default arguments ({}): a shared dict default
        # would be aliased across all instances. None means "no overrides".
        init_norm_kwargs = {} if init_norm_kwargs is None else init_norm_kwargs
        impala_kwargs = {} if impala_kwargs is None else impala_kwargs
        active_reward_monitors = active_reward_monitors or {}

        self.single_output = single_output

        chans = tuple(int(impala_width * c) for c in impala_chans)
        self.hidsize = hidsize

        # Dense init kwargs replaces batchnorm/groupnorm with layernorm
        self.init_norm_kwargs = init_norm_kwargs
        self.dense_init_norm_kwargs = deepcopy(init_norm_kwargs)
        if self.dense_init_norm_kwargs.get("group_norm_groups", None) is not None:
            self.dense_init_norm_kwargs.pop("group_norm_groups", None)
            self.dense_init_norm_kwargs["layer_norm"] = True
        if self.dense_init_norm_kwargs.get("batch_norm", False):
            self.dense_init_norm_kwargs.pop("batch_norm", False)
            self.dense_init_norm_kwargs["layer_norm"] = True

        # Setup inputs
        self.img_preprocess = ImgPreprocessing(img_statistics=img_statistics, scale_img=scale_input_img)
        self.img_process = ImgObsProcess(
            cnn_outsize=256,
            output_size=hidsize,
            inshape=img_shape,
            chans=chans,
            nblock=2,
            dense_init_norm_kwargs=self.dense_init_norm_kwargs,
            init_norm_kwargs=init_norm_kwargs,
            first_conv_norm=first_conv_norm,
            **impala_kwargs,
        )

        self.pre_lstm_ln = nn.LayerNorm(hidsize) if use_pre_lstm_ln else None
        self.diff_obs_process = None

        self.recurrence_type = recurrence_type

        # The recurrent trunk is always constructed here; subclasses may set
        # self.recurrent_layer = None to disable it (forward handles both).
        self.recurrent_layer = ResidualRecurrentBlocks(
            hidsize=hidsize,
            timesteps=timesteps,
            recurrence_type=recurrence_type,
            is_residual=recurrence_is_residual,
            use_pointwise_layer=use_pointwise_layer,
            pointwise_ratio=pointwise_ratio,
            pointwise_use_activation=pointwise_use_activation,
            attention_mask_style=attention_mask_style,
            attention_heads=attention_heads,
            attention_memory_size=attention_memory_size,
            n_block=n_recurrence_layers,
        )

        self.lastlayer = FanInInitReLULayer(hidsize, hidsize, layer_type="linear", **self.dense_init_norm_kwargs)
        self.final_ln = nn.LayerNorm(hidsize)

    def output_latent_size(self):
        """Size of the latent vectors returned by forward()."""
        return self.hidsize

    def forward(self, ob, state_in, context):
        """Encode `ob["img"]`, run the recurrent trunk, and return latents.

        Returns ``(pi_latent, state_out)`` when single_output is set, else
        ``((pi_latent, vf_latent), state_out)`` — note pi/vf latents are the
        same tensor here.
        """
        first = context["first"]
        x = self.img_preprocess(ob["img"])
        x = self.img_process(x)

        if self.diff_obs_process:
            processed_obs = self.diff_obs_process(ob["diff_goal"])
            x = processed_obs + x

        if self.pre_lstm_ln is not None:
            x = self.pre_lstm_ln(x)

        if self.recurrent_layer is not None:
            x, state_out = self.recurrent_layer(x, first, state_in)
        else:
            state_out = state_in

        x = F.relu(x, inplace=False)

        x = self.lastlayer(x)
        x = self.final_ln(x)
        pi_latent = vf_latent = x
        if self.single_output:
            return pi_latent, state_out
        return (pi_latent, vf_latent), state_out

    def initial_state(self, batchsize):
        """Initial recurrent state for a batch, or None when no recurrence."""
        if self.recurrent_layer is not None:
            return self.recurrent_layer.initial_state(batchsize)
        else:
            return None
class MinecraftAgentPolicy(nn.Module):
    """Agent wrapper: MinecraftPolicy trunk plus action (pi) and value heads."""

    def __init__(self, action_space, policy_kwargs, pi_head_kwargs):
        super().__init__()
        self.net = MinecraftPolicy(**policy_kwargs)

        self.action_space = action_space

        self.value_head = self.make_value_head(self.net.output_latent_size())
        self.pi_head = self.make_action_head(self.net.output_latent_size(), **pi_head_kwargs)

    def make_value_head(self, v_out_size: int, norm_type: str = "ewma", norm_kwargs: Optional[Dict] = None):
        return ScaledMSEHead(v_out_size, 1, norm_type=norm_type, norm_kwargs=norm_kwargs)

    def make_action_head(self, pi_out_size: int, **pi_head_opts):
        return make_action_head(self.action_space, pi_out_size, **pi_head_opts)

    def initial_state(self, batch_size: int):
        return self.net.initial_state(batch_size)

    def reset_parameters(self):
        # BUGFIX: removed super().reset_parameters() — nn.Module defines no such
        # method, so the call always raised AttributeError.
        # NOTE(review): MinecraftPolicy (self.net) does not define
        # reset_parameters in this file either — confirm it exists elsewhere
        # before relying on this method.
        self.net.reset_parameters()
        self.pi_head.reset_parameters()
        self.value_head.reset_parameters()

    def forward(self, obs, first: th.Tensor, state_in):
        """Run trunk + heads; returns ((pi_logits, vpred, None), state_out)."""
        if isinstance(obs, dict):
            # We don't want to mutate the obs input.
            obs = obs.copy()
            # If special "mask" key is in obs,
            # It's for masking the logits.
            # We take it out (the network doesn't need it)
            mask = obs.pop("mask", None)
        else:
            mask = None
        (pi_h, v_h), state_out = self.net(obs, state_in, context={"first": first})
        pi_logits = self.pi_head(pi_h, mask=mask)
        vpred = self.value_head(v_h)

        return (pi_logits, vpred, None), state_out

    def get_logprob_of_action(self, pd, action):
        """
        Get logprob of taking action `action` given probability distribution
        (see `get_gradient_for_action` to get this distribution)
        """
        ac = tree_map(lambda x: x.unsqueeze(1), action)
        log_prob = self.pi_head.logprob(ac, pd)
        assert not th.isnan(log_prob).any()
        return log_prob[:, 0]

    def get_kl_of_action_dists(self, pd1, pd2):
        """
        Get the KL divergence between two action probability distributions
        """
        return self.pi_head.kl_divergence(pd1, pd2)

    def get_output_for_observation(self, obs, state_in, first):
        """
        Return gradient-enabled outputs for given observation.

        Use `get_logprob_of_action` to get log probability of action
        with the given probability distribution.

        Returns:
            - probability distribution given observation
            - value prediction for given observation
            - new state
        """
        # We need to add a fictitious time dimension everywhere
        obs = tree_map(lambda x: x.unsqueeze(1), obs)
        first = first.unsqueeze(1)

        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)

        return pd, self.value_head.denormalize(vpred)[:, 0], state_out

    @th.no_grad()
    def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False):
        """Sample (or score a given) action for one environment step.

        Returns ``(action, state_out, result)`` where result holds log_prob,
        denormalized vpred, and optionally the distribution (return_pd).
        """
        # We need to add a fictitious time dimension everywhere
        obs = tree_map(lambda x: x.unsqueeze(1), obs)
        first = first.unsqueeze(1)

        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)

        if taken_action is None:
            ac = self.pi_head.sample(pd, deterministic=not stochastic)
        else:
            ac = tree_map(lambda x: x.unsqueeze(1), taken_action)
        log_prob = self.pi_head.logprob(ac, pd)
        assert not th.isnan(log_prob).any()

        # After unsqueezing, squeeze back to remove fictitious time dimension
        result = {"log_prob": log_prob[:, 0], "vpred": self.value_head.denormalize(vpred)[:, 0]}
        if return_pd:
            result["pd"] = tree_map(lambda x: x[:, 0], pd)
        ac = tree_map(lambda x: x[:, 0], ac)

        return ac, state_out, result

    @th.no_grad()
    def v(self, obs, first, state_in):
        """Predict value for a given mdp observation"""
        obs = tree_map(lambda x: x.unsqueeze(1), obs)
        first = first.unsqueeze(1)

        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)

        # After unsqueezing, squeeze back
        return self.value_head.denormalize(vpred)[:, 0]
class InverseActionNet(MinecraftPolicy):
    """Backbone for the inverse dynamics model: optional 3D conv, then the
    MinecraftPolicy stack, returning only a policy latent (no value latent).

    Args:
        conv3d_params: PRE impala 3D CNN params. They are just passed into th.nn.Conv3D.
    """

    def __init__(
        self,
        hidsize=512,
        conv3d_params=None,
        **mcpolicy_kwargs,
    ):
        super().__init__(
            hidsize=hidsize,
            # If we're using 3dconv, then we normalize entire impala otherwise don't
            # normalize the first impala layer since we normalize the input
            first_conv_norm=conv3d_params is not None,
            **mcpolicy_kwargs,
        )
        self.conv3d_layer = None
        if conv3d_params is not None:
            # 3D conv is the first layer, so don't normalize its input
            conv3d_init_params = deepcopy(self.init_norm_kwargs)
            conv3d_init_params["group_norm_groups"] = None
            conv3d_init_params["batch_norm"] = False
            self.conv3d_layer = FanInInitReLULayer(
                layer_type="conv3d",
                log_scope="3d_conv",
                **conv3d_params,
                **conv3d_init_params,
            )

    def forward(self, ob, state_in, context):
        first = context["first"]
        x = self.img_preprocess(ob["img"])

        # Conv3D Prior to Impala
        if self.conv3d_layer is not None:
            x = self._conv3d_forward(x)

        # Impala Stack
        x = self.img_process(x)

        if self.recurrent_layer is not None:
            x, state_out = self.recurrent_layer(x, first, state_in)
        else:
            # BUGFIX: state_out was previously unbound (NameError) when
            # no recurrent layer is present.
            state_out = state_in

        x = F.relu(x, inplace=False)

        pi_latent = self.lastlayer(x)
        # BUGFIX: final_ln previously re-read `x`, silently discarding
        # lastlayer's output. NOTE(review): checkpoints trained with the old
        # code effectively skipped `lastlayer`; loading them through this
        # fixed path changes outputs — confirm against the training code.
        pi_latent = self.final_ln(pi_latent)

        return (pi_latent, None), state_out

    def _conv3d_forward(self, x):
        # Convert from (B, T, H, W, C) -> (B, C, T, H, W), the layout Conv3d expects.
        x = transpose(x, "bthwc", "bcthw")
        new_x = []
        # Run one sample at a time to bound the 3D conv's peak memory use.
        for mini_batch in th.split(x, 1):
            new_x.append(self.conv3d_layer(mini_batch))
        x = th.cat(new_x)
        # Convert back
        x = transpose(x, "bcthw", "bthwc")
        return x
class InverseActionPolicy(nn.Module):
    """Inverse dynamics model: predicts the action taken between frames."""

    def __init__(
        self,
        action_space,
        pi_head_kwargs=None,
        idm_net_kwargs=None,
    ):
        super().__init__()

        self.action_space = action_space

        # BUGFIX: a None default previously crashed at **idm_net_kwargs with
        # TypeError; treat None as "no overrides".
        idm_net_kwargs = {} if idm_net_kwargs is None else idm_net_kwargs
        self.net = InverseActionNet(**idm_net_kwargs)

        pi_out_size = self.net.output_latent_size()

        pi_head_kwargs = {} if pi_head_kwargs is None else pi_head_kwargs

        self.pi_head = self.make_action_head(pi_out_size=pi_out_size, **pi_head_kwargs)

    def make_action_head(self, **kwargs):
        return make_action_head(self.action_space, **kwargs)

    def reset_parameters(self):
        # BUGFIX: removed super().reset_parameters() — nn.Module defines no
        # such method, so the call always raised AttributeError.
        self.net.reset_parameters()
        self.pi_head.reset_parameters()

    def forward(self, obs, first: th.Tensor, state_in, **kwargs):
        """Run the IDM net + pi head; returns ((pi_logits, None, None), state_out)."""
        if isinstance(obs, dict):
            # We don't want to mutate the obs input.
            obs = obs.copy()
            # If special "mask" key is in obs,
            # It's for masking the logits.
            # We take it out (the network doesn't need it)
            mask = obs.pop("mask", None)
        else:
            mask = None

        (pi_h, _), state_out = self.net(obs, state_in=state_in, context={"first": first}, **kwargs)

        pi_logits = self.pi_head(pi_h, mask=mask)

        return (pi_logits, None, None), state_out

    @th.no_grad()
    def predict(
        self,
        obs,
        deterministic: bool = True,
        **kwargs,
    ):
        """Predict actions for `obs`; returns (action, state_out, result dict)."""
        (pd, _, _), state_out = self(obs=obs, **kwargs)

        ac = self.pi_head.sample(pd, deterministic=deterministic)
        log_prob = self.pi_head.logprob(ac, pd)
        assert not th.isnan(log_prob).any()

        result = {"log_prob": log_prob, "pd": pd}

        return ac, state_out, result

    def initial_state(self, batch_size: int):
        return self.net.initial_state(batch_size)