in lib/policy.py
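# Assumed module-level imports for this excerpt (not shown in the snippet):
#   import torch as th
#   plus a tree_map utility for mapping a function over nested containers of
#   tensors (e.g. torch.utils._pytree.tree_map, or the repo's own equivalent).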
def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False):
    """Act for a single environment step.

    The forward pass is written for (batch, time, ...) sequences, so the
    single-step inputs get a fictitious time dimension of length 1, which
    is stripped from every output before returning.
    """
    # Add the fictitious time dimension everywhere.
    obs = tree_map(lambda x: x.unsqueeze(1), obs)
    first = first.unsqueeze(1)
    (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)
    if taken_action is None:
        # Sample a fresh action; deterministic sampling takes the mode.
        ac = self.pi_head.sample(pd, deterministic=not stochastic)
    else:
        # Evaluate a caller-supplied action instead of sampling one.
        ac = tree_map(lambda x: x.unsqueeze(1), taken_action)
    log_prob = self.pi_head.logprob(ac, pd)
    assert not th.isnan(log_prob).any()
    # Index [:, 0] to strip the fictitious time dimension again; the value
    # prediction is mapped back out of its normalized space first.
    result = {"log_prob": log_prob[:, 0], "vpred": self.value_head.denormalize(vpred)[:, 0]}
    if return_pd:
        result["pd"] = tree_map(lambda x: x[:, 0], pd)
    ac = tree_map(lambda x: x[:, 0], ac)
    return ac, state_out, result
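
# --- Illustrative sketch (not from lib/policy.py) --------------------------
# The unsqueeze / [:, 0] dance above exists because the forward pass
# (including any recurrent core) consumes (batch, time, ...) sequences even
# when acting one step at a time. Below is a minimal, runnable sketch of the
# same pattern with a toy GRU policy; ToyPolicy, its layer sizes, and the
# Categorical head are assumptions for illustration, and the real value
# head's denormalize step is omitted.
import torch as th
import torch.nn as nn

class ToyPolicy(nn.Module):
    """Toy recurrent policy whose forward expects (batch, time, ...) inputs."""
    def __init__(self, obs_dim=4, n_actions=3, hidden=8):
        super().__init__()
        self.core = nn.GRU(obs_dim, hidden, batch_first=True)
        self.pi = nn.Linear(hidden, n_actions)
        self.v = nn.Linear(hidden, 1)

    def forward(self, obs, state_in):
        h, state_out = self.core(obs, state_in)    # h: (B, T, hidden)
        return self.pi(h), self.v(h).squeeze(-1), state_out

policy = ToyPolicy()
obs = th.randn(2, 4)        # single-step observation: (batch, obs_dim)
state = th.zeros(1, 2, 8)   # GRU hidden state: (num_layers, batch, hidden)

# Add the fictitious T=1 dimension, run the sequence model, strip it again.
logits, vpred, state = policy(obs.unsqueeze(1), state)
dist = th.distributions.Categorical(logits=logits[:, 0])
action = dist.sample()               # (batch,)
log_prob = dist.log_prob(action)     # analogous to result["log_prob"]
value = vpred[:, 0]                  # analogous to result["vpred"]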