in ppo_ewma/ppg.py [0:0]
def act(self, ob, first, state_in):
    """Run the model for a single timestep and sample an action.

    The single-step inputs get a length-1 axis inserted at position 1
    (presumably the time axis — confirm against the forward pass). Then
    ``self`` is called, an action is sampled from the returned distribution,
    and that inserted axis is stripped from the outputs before returning.

    Args:
        ob: Observation structure; every leaf is sliced with ``[:, None]``
            via ``tree_map`` (assumes array/tensor leaves with a leading
            batch axis — TODO confirm).
        first: Per-row flags sliced with ``[:, None]``. Presumably these mark
            episode starts for recurrent-state resets; verify against the model.
        state_in: Recurrent state passed through unchanged to the forward call.

    Returns:
        A 3-tuple ``(action, state_out, extras)``:
          * ``action`` — the sampled action with index 0 taken on axis 1 of
            every leaf.
          * ``state_out`` — the state returned by the forward call.
          * ``extras`` — dict with ``"vpred"`` (value prediction) and
            ``"logp"`` (log-probability of the sampled action), both indexed
            ``[:, 0]``.

    NOTE(review): gradient tracking is not disabled here; presumably callers
    wrap this in ``torch.no_grad()`` — verify at the call site.
    """
    # Add a length-1 axis at position 1 so single-step inputs match the
    # multi-step layout the forward pass is called with.
    ob_seq = tree_map(lambda leaf: leaf[:, None], ob)
    first_seq = first[:, None]
    dist, value_seq, _, next_state = self(
        ob=ob_seq,
        first=first_seq,
        state_in=state_in,
    )

    sampled = dist.sample()
    # sum_nonbatch presumably reduces per-dimension log-probs to one value per
    # (batch, time) entry; the [:, 0] below requires at least 2 dims.
    logp_seq = sum_nonbatch(dist.log_prob(sampled))

    # Strip the length-1 axis added above before handing results back.
    action = tree_map(lambda leaf: leaf[:, 0], sampled)
    extras = {"vpred": value_seq[:, 0], "logp": logp_seq[:, 0]}
    return action, next_state, extras