lib/util.py

from typing import Dict, Optional

import torch as th
from torch import nn
from torch.nn import functional as F

import lib.torch_util as tu
from lib.masked_attention import MaskedAttention
from lib.minecraft_util import store_args
from lib.tree_util import tree_map


def get_module_log_keys_recursive(m: nn.Module):
    """Recursively get all keys that a module and its children want to log."""
    keys = []
    if hasattr(m, "get_log_keys"):
        keys += m.get_log_keys()
    for c in m.children():
        keys += get_module_log_keys_recursive(c)
    return keys


class FanInInitReLULayer(nn.Module):
    """Implements a slightly modified init that correctly produces std 1 outputs given ReLU activation.

    :param inchan: number of input channels
    :param outchan: number of output channels
    :param layer_args: positional layer args
    :param layer_type: options are "linear" (dense layer), "conv" (2D convolution), "conv3d" (3D convolution)
    :param init_scale: multiplier on initial weights
    :param batch_norm: use batch norm after the layer (for 2D data)
    :param group_norm_groups: if not None, use group norm with this many groups after the layer.
        Group norm 1 would be equivalent of layernorm for 2D data.
    :param layer_norm: use layernorm after the layer (for 1D data)
    :param layer_kwargs: keyword arguments for the layer
    """

    @store_args
    def __init__(
        self,
        inchan: int,
        outchan: int,
        *layer_args,
        layer_type: str = "conv",
        init_scale: int = 1,
        batch_norm: bool = False,
        batch_norm_kwargs: Dict = {},
        group_norm_groups: Optional[int] = None,
        layer_norm: bool = False,
        use_activation=True,
        log_scope: Optional[str] = None,
        **layer_kwargs,
    ):
        super().__init__()

        # Normalization
        self.norm = None
        if batch_norm:
            self.norm = nn.BatchNorm2d(inchan, **batch_norm_kwargs)
        elif group_norm_groups is not None:
            self.norm = nn.GroupNorm(group_norm_groups, inchan)
        elif layer_norm:
            self.norm = nn.LayerNorm(inchan)

        layer = dict(conv=nn.Conv2d, conv3d=nn.Conv3d, linear=nn.Linear)[layer_type]
        self.layer = layer(inchan, outchan, bias=self.norm is None, *layer_args, **layer_kwargs)

        # Init Weights (Fan-In)
        self.layer.weight.data *= init_scale / self.layer.weight.norm(
            dim=tuple(range(1, self.layer.weight.data.ndim)), p=2, keepdim=True
        )
        # Init Bias
        if self.layer.bias is not None:
            self.layer.bias.data *= 0

    def forward(self, x):
        """Norm after the activation. Experimented with this for both IAM and BC and it was slightly better."""
        if self.norm is not None:
            x = self.norm(x)
        x = self.layer(x)
        if self.use_activation:
            x = F.relu(x, inplace=True)
        return x

    def get_log_keys(self):
        return [
            f"activation_mean/{self.log_scope}",
            f"activation_std/{self.log_scope}",
        ]
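
# A minimal usage sketch (illustrative only; this helper is not used elsewhere in the
# module). It builds a dense FanInInitReLULayer with layer norm and runs one forward
# pass; the sizes and the log_scope string are arbitrary placeholders.
def _example_fan_in_linear_layer():
    layer = FanInInitReLULayer(
        512,
        512,
        layer_type="linear",
        layer_norm=True,
        log_scope="example/linear",
    )
    x = th.randn(16, 512)
    out = layer(x)  # LayerNorm -> Linear (bias disabled because a norm is present) -> ReLU
    assert out.shape == (16, 512)
    return out
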
class ResidualRecurrentBlocks(nn.Module):
    @store_args
    def __init__(
        self,
        n_block=2,
        recurrence_type="multi_layer_lstm",
        is_residual=True,
        **block_kwargs,
    ):
        super().__init__()
        init_scale = n_block ** -0.5 if is_residual else 1
        self.blocks = nn.ModuleList(
            [
                ResidualRecurrentBlock(
                    **block_kwargs,
                    recurrence_type=recurrence_type,
                    is_residual=is_residual,
                    init_scale=init_scale,
                    block_number=i,
                )
                for i in range(n_block)
            ]
        )

    def forward(self, x, first, state):
        state_out = []
        assert len(state) == len(
            self.blocks
        ), f"Length of state {len(state)} did not match length of blocks {len(self.blocks)}"
        for block, _s_in in zip(self.blocks, state):
            x, _s_o = block(x, first, _s_in)
            state_out.append(_s_o)
        return x, state_out

    def initial_state(self, batchsize):
        if "lstm" in self.recurrence_type:
            return [None for b in self.blocks]
        else:
            return [b.r.initial_state(batchsize) for b in self.blocks]


class ResidualRecurrentBlock(nn.Module):
    @store_args
    def __init__(
        self,
        hidsize,
        timesteps,
        init_scale=1,
        recurrence_type="multi_layer_lstm",
        is_residual=True,
        use_pointwise_layer=True,
        pointwise_ratio=4,
        pointwise_use_activation=False,
        attention_heads=8,
        attention_memory_size=2048,
        attention_mask_style="clipped_causal",
        log_scope="resblock",
        block_number=0,
    ):
        super().__init__()
        self.log_scope = f"{log_scope}{block_number}"
        s = init_scale
        if use_pointwise_layer:
            if is_residual:
                s *= 2 ** -0.5  # second residual
            self.mlp0 = FanInInitReLULayer(
                hidsize,
                hidsize * pointwise_ratio,
                init_scale=1,
                layer_type="linear",
                layer_norm=True,
                log_scope=self.log_scope + "/ptwise_mlp0",
            )
            self.mlp1 = FanInInitReLULayer(
                hidsize * pointwise_ratio,
                hidsize,
                init_scale=s,
                layer_type="linear",
                use_activation=pointwise_use_activation,
                log_scope=self.log_scope + "/ptwise_mlp1",
            )

        self.pre_r_ln = nn.LayerNorm(hidsize)
        if recurrence_type in ["multi_layer_lstm", "multi_layer_bilstm"]:
            self.r = nn.LSTM(hidsize, hidsize, batch_first=True)
            nn.init.normal_(self.r.weight_hh_l0, std=s * (self.r.weight_hh_l0.shape[0] ** -0.5))
            nn.init.normal_(self.r.weight_ih_l0, std=s * (self.r.weight_ih_l0.shape[0] ** -0.5))
            self.r.bias_hh_l0.data *= 0
            self.r.bias_ih_l0.data *= 0
        elif recurrence_type == "transformer":
            self.r = MaskedAttention(
                input_size=hidsize,
                timesteps=timesteps,
                memory_size=attention_memory_size,
                heads=attention_heads,
                init_scale=s,
                norm="none",
                log_scope=log_scope + "/sa",
                use_muP_factor=True,
                mask=attention_mask_style,
            )

    def forward(self, x, first, state):
        residual = x
        x = self.pre_r_ln(x)
        x, state_out = recurrent_forward(
            self.r,
            x,
            first,
            state,
            reverse_lstm=self.recurrence_type == "multi_layer_bilstm" and (self.block_number + 1) % 2 == 0,
        )
        if self.is_residual and "lstm" in self.recurrence_type:  # Transformer already residual.
            x = x + residual
        if self.use_pointwise_layer:
            # Residual MLP
            residual = x
            x = self.mlp1(self.mlp0(x))
            if self.is_residual:
                x = x + residual
        return x, state_out
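
# A minimal usage sketch (illustrative only; this helper is not used elsewhere in the
# module). It stacks two LSTM residual blocks and runs one forward pass. The batch,
# time, and hidden sizes are arbitrary; "first" flags episode starts so recurrent
# state can be reset between episodes.
def _example_lstm_stack():
    batchsize, timesteps, hidsize = 4, 8, 64
    stack = ResidualRecurrentBlocks(
        n_block=2,
        recurrence_type="multi_layer_lstm",
        hidsize=hidsize,
        timesteps=timesteps,
    )
    x = th.randn(batchsize, timesteps, hidsize)
    first = th.zeros(batchsize, timesteps, dtype=th.bool)
    state = stack.initial_state(batchsize)  # [None, None] for LSTM recurrence; nn.LSTM then uses a zero state
    out, state_out = stack(x, first, state)
    assert out.shape == (batchsize, timesteps, hidsize)
    return out, state_out
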
def recurrent_forward(module, x, first, state, reverse_lstm=False):
    if isinstance(module, nn.LSTM):
        if state is not None:
            # In case recurrent models do not accept a "first" argument we zero out the hidden state here
            mask = 1 - first[:, 0, None, None].to(th.float)
            state = tree_map(lambda _s: _s * mask, state)
            state = tree_map(lambda _s: _s.transpose(0, 1), state)  # NL, B, H
        if reverse_lstm:
            x = th.flip(x, [1])
        x, state_out = module(x, state)
        if reverse_lstm:
            x = th.flip(x, [1])
        state_out = tree_map(lambda _s: _s.transpose(0, 1), state_out)  # B, NL, H
        return x, state_out
    else:
        return module(x, first, state)


def _banded_repeat(x, t):
    """
    Repeats x with a shift.
    For example (ignoring the batch dimension):

    _banded_repeat([A B C D E], 4)
    =
    [D E 0 0 0]
    [C D E 0 0]
    [B C D E 0]
    [A B C D E]
    """
    b, T = x.shape
    x = th.cat([x, x.new_zeros(b, t - 1)], dim=1)
    result = x.unfold(1, T, 1).flip(1)
    return result


def bandify(b_nd, t, T):
    """
    b_nd -> D_ntT, where
        "n" indexes over basis functions
        "d" indexes over time differences
        "t" indexes over output time
        "T" indexes over input time
        only t >= T is nonzero

    B_ntT[n, t, T] = b_nd[n, t - T]
    """
    nbasis, bandsize = b_nd.shape
    b_nd = b_nd[:, th.arange(bandsize - 1, -1, -1)]
    if bandsize >= T:
        b_nT = b_nd[:, -T:]
    else:
        b_nT = th.cat([b_nd.new_zeros(nbasis, T - bandsize), b_nd], dim=1)
    D_tnT = _banded_repeat(b_nT, t)
    return D_tnT


def get_norm(name, d, dtype=th.float32):
    if name == "none":
        return lambda x: x
    elif name == "layer":
        return tu.LayerNorm(d, dtype=dtype)
    else:
        raise NotImplementedError(name)
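
# A minimal sanity-check sketch (illustrative only; not used elsewhere in the module).
# It reproduces the _banded_repeat docstring example with concrete numbers and shows
# the shape bandify produces for a small, arbitrary basis.
def _example_bandify():
    x = th.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])  # "[A B C D E]" with a batch dimension of 1
    banded = _banded_repeat(x, 4)
    # banded[0] is:
    # [[4., 5., 0., 0., 0.],
    #  [3., 4., 5., 0., 0.],
    #  [2., 3., 4., 5., 0.],
    #  [1., 2., 3., 4., 5.]]
    assert banded.shape == (1, 4, 5)

    b_nd = th.randn(3, 6)  # 3 basis functions over 6 time differences
    D_ntT = bandify(b_nd, t=4, T=8)
    assert D_ntT.shape == (3, 4, 8)  # (basis, output time, input time)
    return banded, D_ntT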