ppo_ewma/impala_cnn.py (128 lines of code) (raw):

import math import torch as th from torch import nn from torch.nn import functional as F from . import torch_util as tu from gym3.types import Real, TensorType REAL = Real() class Encoder(nn.Module): """ Takes in seq of observations and outputs sequence of codes Encoders can be stateful, meaning that you pass in one observation at a time and update the state, which is a separate object. (This object doesn't store any state except parameters) """ def __init__(self, obtype, codetype): super().__init__() self.obtype = obtype self.codetype = codetype def initial_state(self, batchsize): raise NotImplementedError def empty_state(self): return None def stateless_forward(self, obs): """ inputs: obs: array or dict, all with preshape (B, T) returns: codes: array or dict, all with preshape (B, T) """ code, _state = self(obs, None, self.empty_state()) return code def forward(self, obs, first, state_in): """ inputs: obs: array or dict, all with preshape (B, T) first: float array shape (B, T) state_in: array or dict, all with preshape (B,) returns: codes: array or dict state_out: array or dict """ raise NotImplementedError class CnnBasicBlock(nn.Module): """ Residual basic block (without batchnorm), as in ImpalaCNN Preserves channel number and shape """ def __init__(self, inchan, scale=1, batch_norm=False): super().__init__() self.inchan = inchan self.batch_norm = batch_norm s = math.sqrt(scale) self.conv0 = tu.NormedConv2d(self.inchan, self.inchan, 3, padding=1, scale=s) self.conv1 = tu.NormedConv2d(self.inchan, self.inchan, 3, padding=1, scale=s) if self.batch_norm: self.bn0 = nn.BatchNorm2d(self.inchan) self.bn1 = nn.BatchNorm2d(self.inchan) def residual(self, x): # inplace should be False for the first relu, so that it does not change the input, # which will be used for skip connection. # getattr is for backwards compatibility with loaded models if getattr(self, "batch_norm", False): x = self.bn0(x) x = F.relu(x, inplace=False) x = self.conv0(x) if getattr(self, "batch_norm", False): x = self.bn1(x) x = F.relu(x, inplace=True) x = self.conv1(x) return x def forward(self, x): return x + self.residual(x) class CnnDownStack(nn.Module): """ Downsampling stack from Impala CNN """ def __init__(self, inchan, nblock, outchan, scale=1, pool=True, **kwargs): super().__init__() self.inchan = inchan self.outchan = outchan self.pool = pool self.firstconv = tu.NormedConv2d(inchan, outchan, 3, padding=1) s = scale / math.sqrt(nblock) self.blocks = nn.ModuleList( [CnnBasicBlock(outchan, scale=s, **kwargs) for _ in range(nblock)] ) def forward(self, x): x = self.firstconv(x) if getattr(self, "pool", True): x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) for block in self.blocks: x = block(x) return x def output_shape(self, inshape): c, h, w = inshape assert c == self.inchan if getattr(self, "pool", True): return (self.outchan, (h + 1) // 2, (w + 1) // 2) else: return (self.outchan, h, w) class ImpalaCNN(nn.Module): name = "ImpalaCNN" # put it here to preserve pickle compat def __init__( self, inshape, chans, outsize, scale_ob, nblock, final_relu=True, **kwargs ): super().__init__() self.scale_ob = scale_ob h, w, c = inshape curshape = (c, h, w) s = 1 / math.sqrt(len(chans)) # per stack scale self.stacks = nn.ModuleList() for outchan in chans: stack = CnnDownStack( curshape[0], nblock=nblock, outchan=outchan, scale=s, **kwargs ) self.stacks.append(stack) curshape = stack.output_shape(curshape) self.dense = tu.NormedLinear(tu.intprod(curshape), outsize, scale=1.4) self.outsize = outsize self.final_relu = final_relu def forward(self, x): x = x.to(dtype=th.float32) / self.scale_ob b, t = x.shape[:-3] x = x.reshape(b * t, *x.shape[-3:]) x = tu.transpose(x, "bhwc", "bchw") x = tu.sequential(self.stacks, x, diag_name=self.name) x = x.reshape(b, t, *x.shape[1:]) x = tu.flatten_image(x) x = th.relu(x) x = self.dense(x) if self.final_relu: x = th.relu(x) return x class ImpalaEncoder(Encoder): def __init__( self, inshape, outsize=256, chans=(16, 32, 32), scale_ob=255.0, nblock=2, **kwargs ): codetype = TensorType(eltype=REAL, shape=(outsize,)) obtype = TensorType(eltype=REAL, shape=inshape) super().__init__(codetype=codetype, obtype=obtype) self.cnn = ImpalaCNN( inshape=inshape, chans=chans, scale_ob=scale_ob, nblock=nblock, outsize=outsize, **kwargs ) def forward(self, x, first, state_in): x = self.cnn(x) return x, state_in def initial_state(self, batchsize): return tu.zeros(batchsize, 0)