in svg/agent.py:
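
# Critic update for the soft actor-critic objective, with an optional
# model-based value expansion (MVE) target after warmup. This excerpt
# relies on agent.py's module-level imports: torch and
# torch.nn.functional as F.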
def update_critic(self, xs, xps, us, rs, not_done, logger, step):
    assert xs.ndimension() == 2
    n_batch, _ = xs.size()

    # Flatten the rewards and the not-done mask to 1-D so they broadcast
    # against the 1-D Q-value targets computed below.
    rs = rs.squeeze()
    not_done = not_done.squeeze()
    with torch.no_grad():
        if not self.critic_target_mve or step < self.warmup_steps:
            # One-step soft Bellman target:
            #   y = r + gamma * not_done * (min(Q1', Q2') - alpha * log pi(a'|s'))
            # where the next action a' is sampled from the current actor.
            mu, target_us, log_pi = self.actor.forward(
                xps, compute_pi=True, compute_log_pi=True)
            log_pi = log_pi.squeeze(1)

            target_Q1, target_Q2 = [
                Q.squeeze(1) for Q in self.critic_target(xps, target_us)]
            target_Q = torch.min(target_Q1, target_Q2) \
                - self.temp.alpha.detach() * log_pi
            assert target_Q.size() == rs.size()
            assert target_Q.ndimension() == 1
            target_Q = rs + not_done * self.discount * target_Q
            target_Q = target_Q.detach()
        else:
            # Model-based value expansion target: expand_Q (defined
            # elsewhere in the class) returns a multi-step target estimate
            # together with the log-probability of the first sampled
            # action, which supplies the entropy correction here.
            target_Q, first_log_p, total_log_p_us = self.expand_Q(
                xps, self.critic_target, sample=True, discount=True)
            target_Q = target_Q - self.temp.alpha.detach() * first_log_p
            target_Q = rs + not_done * self.discount * target_Q
            target_Q = target_Q.detach()
    # Both critic heads regress onto the same stopped-gradient target.
    current_Q1, current_Q2 = [Q.squeeze(1) for Q in self.critic(xs, us)]
    assert current_Q1.size() == target_Q.size()
    assert current_Q2.size() == target_Q.size()
    Q_loss = F.mse_loss(current_Q1, target_Q) + \
        F.mse_loss(current_Q2, target_Q)
    logger.log('train_critic/Q_loss', Q_loss, step)

    current_Q = torch.min(current_Q1, current_Q2)
    logger.log('train_critic/value', current_Q.mean(), step)
    self.critic_opt.zero_grad()
    Q_loss.backward()
    self.critic_opt.step()

    self.critic.log(logger, step)
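
For context, `expand_Q` is defined elsewhere in `svg/agent.py`; the sketch
below shows one way an H-step model-based value expansion target can be
assembled, not the repository's actual implementation. The names `self.dyn`
(learned dynamics model), `self.rew` (learned reward model), and
`self.horizon` are hypothetical, and the sketch omits the per-step entropy
terms the full version may include.

def expand_Q_sketch(self, x, critic_target):
    # Roll the learned dynamics model forward from x, accumulating
    # discounted model rewards, then bootstrap with the target critic.
    discount = 1.0
    total_reward = torch.zeros(x.size(0), device=x.device)
    total_log_p = torch.zeros(x.size(0), device=x.device)
    first_log_p = None
    for t in range(self.horizon):
        _, u, log_p = self.actor.forward(
            x, compute_pi=True, compute_log_pi=True)
        log_p = log_p.squeeze(1)
        if t == 0:
            first_log_p = log_p
        total_log_p = total_log_p + log_p
        total_reward = total_reward + discount * self.rew(x, u).squeeze(1)
        x = self.dyn(x, u)  # model's prediction of the next state
        discount *= self.discount
    # Bootstrap the tail of the return with the twin target critics at the
    # final imagined state.
    _, u, _ = self.actor.forward(x, compute_pi=True, compute_log_pi=True)
    Q1, Q2 = [Q.squeeze(1) for Q in critic_target(x, u)]
    target_Q = total_reward + discount * torch.min(Q1, Q2)
    return target_Q, first_log_p, total_log_p

A hypothetical call site in the training loop (the replay buffer, its
`sample` method, and `batch_size` are assumptions, not the repo's API)
would look like:

xs, us, rs, xps, not_done = replay_buffer.sample(batch_size)
agent.update_critic(xs, xps, us, rs, not_done, logger, step)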