in drqv2.py
def update_critic(self, obs, action, reward, discount, next_obs, step):
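    """One gradient step on the critic (and encoder) from a replay batch."""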
    metrics = dict()

    with torch.no_grad():
        # sample the next action from the actor at the scheduled exploration stddev
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(next_obs, stddev)
        next_action = dist.sample(clip=self.stddev_clip)
        # clipped double-Q bootstrap: take the min over the two target-critic heads
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
        target_V = torch.min(target_Q1, target_Q2)
        target_Q = reward + (discount * target_V)

    # current Q estimates; both heads regress onto the shared target
    Q1, Q2 = self.critic(obs, action)
    critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

    if self.use_tb:
        metrics['critic_target_q'] = target_Q.mean().item()
        metrics['critic_q1'] = Q1.mean().item()
        metrics['critic_q2'] = Q2.mean().item()
        metrics['critic_loss'] = critic_loss.item()

    # optimize encoder and critic
    self.encoder_opt.zero_grad(set_to_none=True)
    self.critic_opt.zero_grad(set_to_none=True)
    critic_loss.backward()
    self.critic_opt.step()
    self.encoder_opt.step()

    return metrics
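
For reference, below is a minimal, self-contained sketch of the clipped double-Q target and critic loss computed above, run on toy tensors. It is not part of drqv2.py; the batch size, shapes, and the random stand-ins for the critic and target-critic outputs are assumptions chosen purely for illustration.

import torch
import torch.nn.functional as F

batch_size = 4
reward = torch.randn(batch_size, 1)
discount = torch.full((batch_size, 1), 0.99)          # per-sample (n-step) discount from the buffer
target_Q1 = torch.randn(batch_size, 1)                # stand-ins for the two target-critic heads
target_Q2 = torch.randn(batch_size, 1)
Q1 = torch.randn(batch_size, 1, requires_grad=True)   # stand-ins for the online critic heads
Q2 = torch.randn(batch_size, 1, requires_grad=True)

# clipped double-Q TD target, as in update_critic above
target_V = torch.min(target_Q1, target_Q2)
target_Q = reward + discount * target_V

# both online heads are regressed onto the same target
critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
critic_loss.backward()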