in drqv2.py [0:0]
def update_actor(self, obs, step):
    """Run one actor (policy) optimization step.

    Samples an action from the current policy at the scheduled exploration
    stddev, scores it with the twin critic, and ascends the minimum of the
    two Q estimates (deterministic-policy-gradient style: the loss is -Q).

    Args:
        obs: batch of (encoded) observations fed to both actor and critic.
            NOTE(review): presumably already detached from the encoder so
            actor gradients do not flow into it — confirm against the caller.
        step: global environment step, used to anneal the exploration
            stddev via ``utils.schedule(self.stddev_schedule, step)``.

    Returns:
        dict of scalar training metrics; populated only when
        ``self.use_tb`` is set, otherwise empty.
    """
    metrics = dict()

    # Exploration stddev is annealed over training according to the schedule.
    stddev = utils.schedule(self.stddev_schedule, step)
    dist = self.actor(obs, stddev)
    # Sample with clipped noise (assumes dist supports a `clip` kwarg,
    # e.g. a truncated-normal policy head — TODO confirm).
    action = dist.sample(clip=self.stddev_clip)

    # Actor objective: maximize the pessimistic (min) twin-Q estimate.
    Q1, Q2 = self.critic(obs, action)
    Q = torch.min(Q1, Q2)
    actor_loss = -Q.mean()

    # optimize actor
    self.actor_opt.zero_grad(set_to_none=True)
    actor_loss.backward()
    self.actor_opt.step()

    if self.use_tb:
        # log_prob is only needed for logging, so compute it lazily here
        # rather than on every update (it contributes nothing to the loss).
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        metrics['actor_loss'] = actor_loss.item()
        metrics['actor_logprob'] = log_prob.mean().item()
        metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()

    return metrics