def forward()

in ppo_ewma/impala_cnn.py [0:0]


    def forward(self, x):
        """
        Encode a (b, t) batch of images into feature vectors.

        :param x: observation tensor with exactly two leading dims followed by
            an (h, w, c) image layout (per the "bhwc" transpose spec below);
            any other rank fails when the leading shape is unpacked.
        :return: output of ``self.dense``, passed through ReLU when
            ``self.final_relu`` is set. Presumably shaped (b, t, outsize) —
            TODO confirm against ``tu.flatten_image`` and ``self.dense``.
        """
        # Cast to float32 first, then divide by the observation scale.
        obs = x.float() / self.scale_ob

        # Unpacking enforces exactly two leading dims before (h, w, c).
        batch, steps = obs.shape[:-3]

        # Merge (batch, steps) into one axis so the conv stacks see a plain
        # image batch, and move channels ahead of the spatial dims.
        frames = tu.transpose(obs.flatten(0, 1), "bhwc", "bchw")
        frames = tu.sequential(self.stacks, frames, diag_name=self.name)

        # Split the merged axis back into (batch, steps), then flatten each
        # per-step feature map (assumed: flatten_image folds the last 3 dims).
        feats = frames.reshape((batch, steps) + tuple(frames.shape[1:]))
        feats = th.relu(tu.flatten_image(feats))
        out = self.dense(feats)
        return th.relu(out) if self.final_relu else out