# lib/impala_cnn.py
import math
from copy import deepcopy
from typing import Dict, List, Optional
from torch import nn
from torch.nn import functional as F
from lib import misc
from lib import torch_util as tu
from lib.util import FanInInitReLULayer


class CnnBasicBlock(nn.Module):
    """
    Residual basic block, as in the Impala CNN. Preserves channel count and spatial shape.

    :param inchan: number of input channels
    :param init_scale: weight init scale multiplier
    :param log_scope: prefix for this block's layer names in logging/diagnostics
    :param init_norm_kwargs: kwargs forwarded to FanInInitReLULayer (initialization/normalization options)
    """

    def __init__(
self,
inchan: int,
init_scale: float = 1,
log_scope="",
init_norm_kwargs: Dict = {},
**kwargs,
):
super().__init__()
self.inchan = inchan
        # each of the two convs gets sqrt(init_scale), so the residual branch as a
        # whole is scaled by init_scale
        s = math.sqrt(init_scale)
self.conv0 = FanInInitReLULayer(
self.inchan,
self.inchan,
kernel_size=3,
padding=1,
init_scale=s,
log_scope=f"{log_scope}/conv0",
**init_norm_kwargs,
)
self.conv1 = FanInInitReLULayer(
self.inchan,
self.inchan,
kernel_size=3,
padding=1,
init_scale=s,
log_scope=f"{log_scope}/conv1",
**init_norm_kwargs,
        )

    def forward(self, x):
        # residual connection: add the two-conv branch back onto the input
        x = x + self.conv1(self.conv0(x))
        return x
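# Usage sketch (illustrative only, not exercised by this module): because the block
# preserves both channel count and spatial shape, the residual sum is always
# well-defined, e.g.
#
#     block = CnnBasicBlock(inchan=64, init_scale=1.0)
#     out = block(torch.zeros(8, 64, 16, 16))  # out.shape == (8, 64, 16, 16)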
class CnnDownStack(nn.Module):
    """
    Downsampling stack from the Impala CNN.

    :param inchan: number of input channels
    :param nblock: number of residual blocks after downsampling
    :param outchan: number of output channels
    :param init_scale: weight init scale multiplier
    :param pool: if True, downsample with a stride-2 max pool after the first conv
    :param post_pool_groups: if not None, normalize after the pool with GroupNorm using this many groups
    :param log_scope: prefix for this stack's layer names in logging/diagnostics
    :param init_norm_kwargs: kwargs forwarded to FanInInitReLULayer (initialization/normalization options)
    :param first_conv_norm: if False, disable normalization on the first conv of the stack
    :param kwargs: remaining kwargs are passed into the blocks and layers
    """

    name = "Impala_CnnDownStack"

    def __init__(
self,
inchan: int,
nblock: int,
outchan: int,
init_scale: float = 1,
pool: bool = True,
post_pool_groups: Optional[int] = None,
log_scope: str = "",
init_norm_kwargs: Dict = {},
first_conv_norm=False,
**kwargs,
):
super().__init__()
self.inchan = inchan
self.outchan = outchan
self.pool = pool
        # optionally strip normalization from the first conv of the stack
        first_conv_init_kwargs = deepcopy(init_norm_kwargs)
        if not first_conv_norm:
            first_conv_init_kwargs["group_norm_groups"] = None
            first_conv_init_kwargs["batch_norm"] = False
self.firstconv = FanInInitReLULayer(
inchan,
outchan,
kernel_size=3,
padding=1,
log_scope=f"{log_scope}/firstconv",
**first_conv_init_kwargs,
)
self.post_pool_groups = post_pool_groups
if post_pool_groups is not None:
self.n = nn.GroupNorm(post_pool_groups, outchan)
        # each block receives init_scale / sqrt(nblock), so the accumulated residual
        # contributions stay at roughly the requested scale
        self.blocks = nn.ModuleList(
[
CnnBasicBlock(
outchan,
init_scale=init_scale / math.sqrt(nblock),
log_scope=f"{log_scope}/block{i}",
init_norm_kwargs=init_norm_kwargs,
**kwargs,
)
for i in range(nblock)
]
        )

    def forward(self, x):
        x = self.firstconv(x)
        if self.pool:
            # stride-2 max pool: (h, w) -> ((h + 1) // 2, (w + 1) // 2)
            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        if self.post_pool_groups is not None:
            x = self.n(x)
        x = tu.sequential(self.blocks, x, diag_name=self.name)
        return x

    def output_shape(self, inshape):
        c, h, w = inshape
        assert c == self.inchan
        if self.pool:
            return (self.outchan, (h + 1) // 2, (w + 1) // 2)
        else:
            return (self.outchan, h, w)
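# Shape arithmetic sketch (illustrative only): with pool=True a stack halves each
# spatial dimension, rounding up, while changing the channel count, e.g.
#
#     stack = CnnDownStack(inchan=3, nblock=2, outchan=16)
#     stack.output_shape((3, 64, 64))  # -> (16, 32, 32)
#     stack.output_shape((3, 63, 63))  # -> (16, 32, 32)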
class ImpalaCNN(nn.Module):
    """
    :param inshape: input image shape (height, width, channels)
    :param chans: one entry per downsampling stack; each entry is the number of
        filters per convolution in that stack
    :param outsize: output hidden size
    :param nblock: number of residual blocks per stack. Each block has 2 convs and a residual connection
    :param init_norm_kwargs: arguments passed to the convolutional layers. Options can be found
        in lib.util:FanInInitReLULayer
    :param dense_init_norm_kwargs: arguments passed to the final dense layer. Options can be found
        in lib.util:FanInInitReLULayer
    :param first_conv_norm: if False, disable normalization on the first conv of the first stack
    :param kwargs: remaining kwargs are passed into the CnnDownStacks
    """

    name = "ImpalaCNN"

    def __init__(
self,
inshape: List[int],
chans: List[int],
outsize: int,
nblock: int,
init_norm_kwargs: Dict = {},
dense_init_norm_kwargs: Dict = {},
first_conv_norm=False,
**kwargs,
):
        super().__init__()
        # inshape is (height, width, channels); shapes are tracked internally as (c, h, w)
        h, w, c = inshape
        curshape = (c, h, w)
self.stacks = nn.ModuleList()
for i, outchan in enumerate(chans):
stack = CnnDownStack(
curshape[0],
nblock=nblock,
outchan=outchan,
init_scale=math.sqrt(len(chans)),
log_scope=f"downstack{i}",
init_norm_kwargs=init_norm_kwargs,
                # only the first conv of the first stack may skip normalization
                first_conv_norm=first_conv_norm if i == 0 else True,
**kwargs,
)
self.stacks.append(stack)
curshape = stack.output_shape(curshape)
self.dense = FanInInitReLULayer(
misc.intprod(curshape),
outsize,
layer_type="linear",
log_scope="imapala_final_dense",
init_scale=1.4,
**dense_init_norm_kwargs,
)
        self.outsize = outsize

    def forward(self, x):
        # input arrives as (batch, time, height, width, channels); fold time into batch
        b, t = x.shape[:-3]
        x = x.reshape(b * t, *x.shape[-3:])
        x = misc.transpose(x, "bhwc", "bchw")
        x = tu.sequential(self.stacks, x, diag_name=self.name)
        # restore the time dimension, flatten the conv features, and project to outsize
        x = x.reshape(b, t, *x.shape[1:])
        x = tu.flatten_image(x)
        x = self.dense(x)
        return x
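if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): build a small
    # ImpalaCNN and push a dummy (batch, time, height, width, channels) tensor
    # through it. Assumes the lib helpers imported above are on the path and that
    # FanInInitReLULayer works with its default init/norm options.
    import torch

    net = ImpalaCNN(inshape=[64, 64, 3], chans=[16, 32, 32], outsize=256, nblock=2)
    frames = torch.zeros(4, 5, 64, 64, 3)  # batch of 4 clips, 5 frames of 64x64 RGB each
    features = net(frames)
    print(features.shape)  # expected: torch.Size([4, 5, 256])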