in svg/agent.py [0:0]
def update_critic_mve(self, first_xs, first_us, first_rs, next_xs, first_not_dones, logger, step):
""" MVE critic loss from Feinberg et al (2015) """
    assert first_xs.dim() == 2
    assert first_us.dim() == 2
    assert first_rs.dim() == 2
    assert next_xs.dim() == 2
    assert first_not_dones.dim() == 2
    n_batch = next_xs.size(0)
    # unroll policy, concatenate obs and actions
    pred_us, log_p_us, pred_xs = self.dx.unroll_policy(
        next_xs, self.actor, sample=True, detach_xt=self.actor_detach_rho)
    all_obs = torch.cat((first_xs.unsqueeze(0), next_xs.unsqueeze(0), pred_xs))
    all_us = torch.cat([first_us.unsqueeze(0), pred_us])
    xu = torch.cat([all_obs, all_us], dim=2)
    horizon_len = all_obs.size(0) - 1  # H
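    # all_obs: [H+1, n_batch, obs_dim] with x_0 = first_xs, x_1 = next_xs and x_2..x_H
    # predicted by the dynamics model self.dx; all_us: [H+1, n_batch, act_dim] with
    # u_0 = first_us and u_1..u_H sampled from the policy; xu stacks each state with its action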
    # get immediate rewards
    pred_rs = self.rew(xu[1:-1])  # t from 0 to H - 1
    rewards = torch.cat([first_rs.unsqueeze(0), pred_rs]).squeeze(2)
    rewards = rewards.unsqueeze(1).expand(-1, horizon_len, -1)
    log_p_us = log_p_us.unsqueeze(1).expand(-1, horizon_len, -1)
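    # rewards, log_p_us: [H, H, n_batch]; identical copies along dim 1 so that each of the
    # H value-expansion targets (one per rollout state) can be built from the same rollout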
    # get not dones factor matrix, rows --> t, cols --> k
    first_not_dones = first_not_dones.unsqueeze(0)
    init_not_dones = torch.ones_like(first_not_dones)  # we know the first states are not terminal
    pred_not_dones = 1. - self.done(xu[2:]).sigmoid()  # t from 1 to H
    not_dones = torch.cat([init_not_dones, first_not_dones, pred_not_dones]).squeeze(2)
    not_dones = not_dones.unsqueeze(1).repeat(1, horizon_len, 1)
    triu_rows, triu_cols = torch.triu_indices(
        row=horizon_len + 1, col=horizon_len, offset=1, device=not_dones.device)
    not_dones[triu_rows, triu_cols, :] = 1.
    not_dones = not_dones.cumprod(dim=0).detach()
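    # not_dones: [H+1, H, n_batch] cumulative not-done mask along the rollout; entries above the
    # diagonal are pinned to 1 before the cumprod, and the result is detached so the termination
    # predictor self.done receives no gradient from the critic loss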
    # get lower-triangular reward discount factor matrix
    discount = torch.tensor(self.discount, device=rewards.device)
    discount_exps = torch.stack(
        [torch.arange(-i, -i + horizon_len) for i in range(horizon_len)], dim=1)
    r_discounts = discount ** discount_exps.to(rewards.device)
    r_discounts = r_discounts.tril().unsqueeze(-1)
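    # r_discounts: [H, H, 1] lower-triangular matrix with entry (k, t) = discount^(k - t)
    # for k >= t and zero above the diagonal, broadcast over the batch dimension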
    # get discounted sums of soft rewards (t from -1 to H - 1 (k from t to H - 1))
    alpha = self.temp.alpha.detach()
    soft_rewards = (not_dones[:-1] * rewards) - (discount * alpha * not_dones[1:] * log_p_us)
    soft_rewards = (r_discounts * soft_rewards).sum(0)
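    # soft_rewards: [H, n_batch]; discounted sum of model rewards plus the maximum-entropy
    # bonus -alpha * log pi(u|x), one partial return per rollout state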
    # get target q-values, final critic targets
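    # clipped double Q: bootstrap with the minimum of the two target critics at the final
    # rollout pair (x_H, u_H)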
    target_q1, target_q2 = self.critic_target(all_obs[-1], all_us[-1])
    target_qs = torch.min(target_q1, target_q2).squeeze(-1).expand(horizon_len, -1)
    q_discounts = discount ** torch.arange(horizon_len, 0, step=-1).to(target_qs.device)
    target_qs = target_qs * (not_dones[-1] * q_discounts.unsqueeze(-1))
    critic_targets = (soft_rewards + target_qs).detach()
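    # critic_targets: [H, n_batch]; for each rollout state x_t the target is the discounted sum
    # of model rewards out to the horizon plus discount^(H - t) * min(Q1', Q2')(x_H, u_H), with
    # contributions past a predicted termination zeroed out and gradients blocked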
    # get predicted q-values
    with utils.eval_mode(self.critic):
        q1, q2 = self.critic(all_obs[:-1].flatten(end_dim=-2),
                             all_us[:-1].flatten(end_dim=-2))
    q1, q2 = q1.reshape(horizon_len, n_batch), q2.reshape(horizon_len, n_batch)
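    # q1, q2: [H, n_batch] critic predictions for every rollout pair (x_t, u_t), t = 0..H-1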
    assert q1.size() == critic_targets.size()
    assert q2.size() == critic_targets.size()
    # update critics
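    # the per-step squared errors are masked with not_dones[:-1, 0] so rollout states at or
    # after a predicted termination do not contribute to the loss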
    q1_loss = (not_dones[:-1, 0] * (q1 - critic_targets).pow(2)).mean()
    q2_loss = (not_dones[:-1, 0] * (q2 - critic_targets).pow(2)).mean()
    Q_loss = q1_loss + q2_loss
    logger.log('train_critic/Q_loss', Q_loss, step)
    current_Q = torch.min(q1, q2)
    logger.log('train_critic/value', current_Q.mean(), step)
    self.critic_opt.zero_grad()
    Q_loss.backward()
    self.critic_opt.step()
    self.critic.log(logger, step)