in lib/policy.py [0:0]
def forward(self, ob, state_in, context):
    """Run the observation trunk once and return its latent(s) plus recurrent state.

    Args:
        ob: Observation mapping. ``ob["img"]`` is always read. ``ob["diff_goal"]``
            is read only when ``self.diff_obs_process`` is truthy.
        state_in: State passed to ``self.recurrent_layer``. It is returned
            unchanged when no recurrent layer is configured.
        context: Mapping whose ``"first"`` entry is forwarded to the recurrent
            layer (presumably per-step episode-start flags — TODO confirm).

    Returns:
        ``(latent, state_out)`` when ``self.single_output`` is truthy, otherwise
        ``((pi_latent, vf_latent), state_out)``. In the second form both latents
        are the very same object.
    """
    # Read eagerly: a missing "first" key raises before any model compute runs.
    episode_first = context["first"]

    hidden = self.img_preprocess(ob["img"])
    hidden = self.img_process(hidden)

    # Optional goal conditioning is added onto the image features.
    # NOTE(review): this is a truthiness test, unlike the `is not None` checks
    # below. A module type defining __len__ (e.g. nn.Sequential) is falsy when
    # empty and would be skipped here.
    if self.diff_obs_process:
        hidden = self.diff_obs_process(ob["diff_goal"]) + hidden

    if self.pre_lstm_ln is not None:
        hidden = self.pre_lstm_ln(hidden)

    if self.recurrent_layer is None:
        # No recurrence: state passes straight through.
        state_out = state_in
    else:
        hidden, state_out = self.recurrent_layer(hidden, episode_first, state_in)

    # Out-of-place ReLU, so the tensor handed in is not modified.
    hidden = F.relu(hidden, inplace=False)
    hidden = self.lastlayer(hidden)
    latent = self.final_ln(hidden)

    # pi and vf latents share one object; the tuple form only duplicates the reference.
    outputs = latent if self.single_output else (latent, latent)
    return outputs, state_out