in ppo_ewma/impala_cnn.py [0:0]
def forward(self, x):
    """Encode a (b, t) grid of HWC image observations into feature vectors.

    Steps:
    1. Cast ``x`` to float32 and divide by ``self.scale_ob``.
    2. Fold the two leading dims (b, t) into one so each frame is processed
       independently. ``x`` must have exactly two dims before the trailing
       three image dims, or the unpacking below raises.
    3. Reorder frames from HWC to CHW, then run them through ``self.stacks``.
    4. Restore the (b, t) leading dims.
    5. Flatten with ``tu.flatten_image``, apply ReLU, then ``self.dense``.
       NOTE(review): presumably ``tu.flatten_image`` flattens each frame's
       feature map into one vector; confirm in torch_util.
    6. Optionally apply a final ReLU when ``self.final_relu`` is truthy.
    """
    obs = x.to(dtype=th.float32) / self.scale_ob

    # Merge (b, t) so the conv stacks see a flat batch of single frames.
    n_batch, n_time = obs.shape[:-3]
    frames = obs.reshape(n_batch * n_time, *obs.shape[-3:])
    frames = tu.transpose(frames, "bhwc", "bchw")

    feats = tu.sequential(self.stacks, frames, diag_name=self.name)
    # Split the merged batch back into its original (b, t) layout.
    feats = feats.reshape(n_batch, n_time, *feats.shape[1:])

    hidden = self.dense(th.relu(tu.flatten_image(feats)))
    return th.relu(hidden) if self.final_relu else hidden