in ppo_ewma/ppg.py [0:0]
def forward(self, ob, first, state_in):
    """Run all encoders, then the policy head and every value head.

    Args:
        ob: Observation batch, passed unchanged to each encoder.
        first: Passed unchanged to each encoder (presumably the
            episode-start flags -- confirm against the encoder API).
        state_in: Mapping from encoder key to that encoder's input state.

    Returns:
        A 4-tuple ``(pd, vfvec, aux, state_out)``:

        * ``pd`` -- result of ``self.make_distr`` applied to the policy
          head output.
        * ``vfvec`` -- prediction of the value head named by
          ``self.true_vf_key``.
        * ``aux`` -- dict with one prediction per key in ``self.vf_keys``
          plus ``"vpredaux"`` (auxiliary value head evaluated on the
          policy features) and ``"vpredtrue"`` (same object as ``vfvec``).
        * ``state_out`` -- mapping from encoder key to the state that
          encoder returned.
    """
    features = {}
    state_out = {}
    for enc_key in self.enc_keys:
        encoder = self.get_encoder(enc_key)
        encoded, state_out[enc_key] = encoder(ob, first, state_in[enc_key])
        features[enc_key] = self.reshape_x(encoded)

    # Grab the policy features *before* the value loop below may replace
    # entries of ``features`` with detached copies.
    policy_features = features[self.pi_key]
    pd = self.make_distr(self.pi_head(policy_features))

    aux = {}
    for head_key in self.vf_keys:
        head_input = features[head_key]
        if self.detach_value_head:
            # Presumably stops value-loss gradients from reaching the
            # encoder; the detached features are also stored back.
            head_input = head_input.detach()
            features[head_key] = head_input
        # ``[..., 0]`` keeps index 0 along the last axis of the head output.
        aux[head_key] = self.get_vhead(head_key)(head_input)[..., 0]

    vfvec = aux[self.true_vf_key]
    aux["vpredaux"] = self.aux_vf_head(policy_features)[..., 0]
    aux["vpredtrue"] = vfvec
    return pd, vfvec, aux, state_out