in svg/agent.py [0:0]
def update_actor_and_alpha(self, xs, logger, step):
    assert xs.ndimension() == 2
    n_batch, _ = xs.size()
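
    # Fall back to standard model-free SAC updates while the dynamics model is
    # still warming up, when no rollout horizon is used, when actor MVE is
    # disabled, or when the rolling dynamics-model loss is above its threshold.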
    do_model_free_update = (
        step < self.warmup_steps
        or self.horizon == 0
        or not self.actor_mve
        or (self.actor_dx_threshold is not None
            and self.rolling_dx_loss is not None
            and self.rolling_dx_loss > self.actor_dx_threshold)
    )

    if do_model_free_update:
        # Do vanilla SAC updates while the model warms up.
        # i.e., fit to just the Q function
        _, pi, first_log_p = self.actor(xs)
        actor_Q1, actor_Q2 = self.critic(xs, pi)
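        # Clipped double-Q: take the minimum of the twin critic estimates to
        # reduce value overestimation.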
        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.temp.alpha.detach() * first_log_p - actor_Q).mean()
    else:
        # Switch to the model-based updates.
        # i.e., fit to the controller's sequence cost
        rewards, first_log_p, total_log_p_us = self.expand_Q(
            xs, self.critic, sample=True, discount=True)
        assert total_log_p_us.size() == rewards.size()
        actor_loss = -(rewards / self.horizon).mean()
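        # Averaging the expanded return over the horizon keeps the loss scale
        # comparable to the single-step objective above; a hypothetical sketch
        # of what expand_Q computes follows this function.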

    logger.log('train_actor/loss', actor_loss, step)
    logger.log('train_actor/entropy', -first_log_p.mean(), step)

    self.actor_opt.zero_grad()
    actor_loss.backward()
    self.actor_opt.step()

    self.actor.log(logger, step)

    self.temp.update(first_log_p, logger, step)
    logger.log('train_alpha/value', self.temp.alpha, step)
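

# NOTE: expand_Q is defined elsewhere in svg/agent.py and is not shown in this
# excerpt. The function below is a minimal, hypothetical sketch of a
# short-horizon value expansion that is consistent with how expand_Q is called
# above (it returns the expanded return, the first-step log-probability, and
# the summed action log-probabilities). The dynamics and reward_fn callables,
# the standalone-function form, and the argument names are illustrative
# assumptions, not the repository's actual API; the sample=False (mean-action)
# path is omitted from the sketch.
import torch

def expand_Q_sketch(xs, actor, critic, dynamics, reward_fn, horizon, alpha,
                    gamma=0.99, sample=True, discount=True):
    x = xs
    total_rewards = 0.0
    total_log_p_us = 0.0
    first_log_p = None
    scale = 1.0
    for t in range(horizon):
        # Sample an action and its log-probability from the current policy.
        _, u, log_p_u = actor(x)
        if t == 0:
            first_log_p = log_p_u
        total_log_p_us = total_log_p_us + log_p_u
        # Accumulate the entropy-regularized reward from the reward model.
        total_rewards = total_rewards + scale * (reward_fn(x, u) - alpha * log_p_u)
        # Step the learned dynamics model forward.
        x = dynamics(x, u)
        if discount:
            scale = scale * gamma
    # Bootstrap the tail with the entropy-regularized twin-Q estimate.
    _, u, log_p_u = actor(x)
    q1, q2 = critic(x, u)
    total_rewards = total_rewards + scale * (torch.min(q1, q2) - alpha * log_p_u)
    return total_rewards, first_log_p, total_log_p_us

# Differentiating such an expanded return through the reparameterized actions
# and the learned model is what yields the stochastic value gradient used for
# the model-based actor update.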