in rlkit/torch/sac/uwac_dropout.py [0:0]
def train_from_torch(self, batch):
self._current_epoch += 1
rewards = batch['rewards']
terminals = batch['terminals']
obs = batch['observations']
actions = batch['actions']
next_obs = batch['next_observations']
"""
Behavior clone a policy
"""
recon, mean, std = self.vae(obs, actions)
recon_loss = self.qf_criterion(recon, actions)
kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
vae_loss = recon_loss + 0.5 * kl_loss
self.vae_optimizer.zero_grad()
vae_loss.backward()
self.vae_optimizer.step()
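# The VAE trained above is a generative model of the behavior policy, as in
# BCQ/BEAR: its decoded samples anchor the MMD constraint in the actor update
# further down. The 0.5 weight on the KL term follows the BEAR implementation.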
"""
Critic Training
"""
with torch.no_grad():
# Duplicate state 10 times (10 is a hyperparameter chosen by BCQ)
state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(next_obs.shape[0]*10, next_obs.shape[1])
# Compute value of candidate actions sampled from the current policy
# (not VAE-perturbed actions as in BCQ)
action_rep = self.policy(state_rep)[0]
target_qf1, target_qf1_var = self.target_qf1.multiple(state_rep, action_rep, with_var=True)
target_qf2, target_qf2_var = self.target_qf2.multiple(state_rep, action_rep, with_var=True)
# Soft Clipped Double Q-learning
target_Q = 0.75 * torch.min(target_qf1, target_qf2) + 0.25 * torch.max(target_qf1, target_qf2)
target_Q = target_Q.view(next_obs.shape[0], -1).max(1)[0].view(-1, 1)
target_Q = self.reward_scale * rewards + (1.0 - terminals) * self.discount * target_Q
target_Q_var = (target_qf1_var + target_qf2_var).view(next_obs.shape[0], -1).max(1)[0].view(-1, 1)
weight = self._get_weight(target_Q_var, self.beta)
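# The max over the 10 candidate actions and the clipped double-Q backup above
# are the standard BCQ/BEAR target; the UWAC addition is the uncertainty weight.
# `_get_weight` is assumed to map the summed MC-dropout variance of the target
# critics to a bounded inverse-variance factor, roughly along the lines of
#   weight ~ clamp(beta / target_Q_var, max=some_cap)
# (illustrative sketch only; the exact form lives in `_get_weight`).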
qf1_pred = self.qf1.sample(obs, actions)
qf2_pred = self.qf2.sample(obs, actions)
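# `sample` / `multiple` are assumed to be helpers of the dropout Q-networks:
# `sample` draws a single stochastic (dropout) forward pass, while `multiple`
# averages several passes and also returns their variance (used for the targets
# above). This is inferred from the call sites, not a documented API.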
if self.use_exp_penalty:
qf1_loss = (
((qf1_pred - target_Q.detach()) * weight.detach()).pow(2).mean()
+ self.q_penalty * (torch.nn.functional.relu(qf1_pred) * torch.exp(target_Q_var.data)).mean()
)
qf2_loss = (
((qf2_pred - target_Q.detach()) * weight.detach()).pow(2).mean()
+ self.q_penalty * (torch.nn.functional.relu(qf2_pred) * torch.exp(target_Q_var.data)).mean()
)
else:
qf1_loss = (
((qf1_pred - target_Q.detach()) * weight.detach()).pow(2).mean()
+ self.q_penalty * (torch.nn.functional.relu(qf1_pred) * target_Q_var.data).mean()
)
qf2_loss = (
((qf2_pred - target_Q.detach()) * weight.detach()).pow(2).mean()
+ self.q_penalty * (torch.nn.functional.relu(qf2_pred) * target_Q_var.data).mean()
)
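# Both branches above add a penalty that pushes down positive Q predictions in
# proportion to the target variance (exponentiated when `use_exp_penalty` is
# set), discouraging optimistic values on uncertain state-action pairs.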
"""
Actor Training
"""
sampled_actions, raw_sampled_actions = self.vae.decode_multiple(obs, num_decode=self.num_samples_mmd_match)
actor_samples, _, _, _, _, _, _, raw_actor_actions = self.policy(
obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match, 1).view(-1, obs.shape[1]), return_log_prob=True)
actor_samples = actor_samples.view(obs.shape[0], self.num_samples_mmd_match, actions.shape[1])
raw_actor_actions = raw_actor_actions.view(obs.shape[0], self.num_samples_mmd_match, actions.shape[1])
if self.kernel_choice == 'laplacian':
mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions, raw_actor_actions, sigma=self.mmd_sigma)
elif self.kernel_choice == 'gaussian':
mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions, raw_actor_actions, sigma=self.mmd_sigma)
action_divergence = ((sampled_actions - actor_samples)**2).sum(-1)
raw_action_divergence = ((raw_sampled_actions - raw_actor_actions)**2).sum(-1)
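# The MMD above is computed between "raw" (presumably pre-squashing) actions of
# the VAE, acting as a behavior-policy proxy, and of the current policy, as in
# BEAR; the two divergence terms are used later only for logging.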
q_val1, q_val1_var = self.qf1.multiple(obs, actor_samples[:, 0, :], with_var=True)
q_val2, q_val2_var = self.qf2.multiple(obs, actor_samples[:, 0, :], with_var=True)
if self.policy_update_style == '0':
policy_loss = torch.min(q_val1, q_val2)[:, 0]
elif self.policy_update_style == '1':
# torch.mean() does not accept two tensors; average the two critics explicitly
policy_loss = (0.5 * (q_val1 + q_val2))[:, 0]
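# policy_update_style '0' takes the conservative minimum over the two critics,
# '1' their average; only the first of the num_samples_mmd_match policy samples
# per state is scored here.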
with torch.no_grad():
q_var = q_val1_var + q_val2_var
if self.var_Pi:
weight = self._get_weight(q_var, self.beta).squeeze()
else:
# keep this a tensor so that weight.detach() below is valid in every branch
weight = torch.ones_like(policy_loss)
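# If `var_Pi` is set, the same kind of uncertainty weighting used for the
# critic targets is also applied to the actor objective; otherwise the weight
# is constant.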
if self._n_train_steps_total >= 40000:
# Now we can update the policy
if self.mode == 'auto':
policy_loss = (-policy_loss * weight.detach() + self.log_alpha.exp() * (mmd_loss - self.target_mmd_thresh)).mean()
else:
policy_loss = (-policy_loss * weight.detach() + 100 * mmd_loss).mean()
else:
if self.mode == 'auto':
policy_loss = (self.log_alpha.exp() * (mmd_loss - self.target_mmd_thresh)).mean()
else:
policy_loss = 100 * mmd_loss.mean()
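# For the first 40000 gradient steps the actor is trained on the MMD term alone
# (a BEAR-style warm start keeping the policy close to the data); afterwards the
# uncertainty-weighted Q term is added. In 'auto' mode, log_alpha acts as a
# Lagrange multiplier on the MMD constraint defined by target_mmd_thresh.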
"""
Update Networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward()
self.qf1_optimizer.step()
self.qf2_optimizer.zero_grad()
qf2_loss.backward()
self.qf2_optimizer.step()
self.policy_optimizer.zero_grad()
if self.mode == 'auto':
policy_loss.backward(retain_graph=True)
else:
# the fixed-multiplier mode also needs a backward pass before stepping
policy_loss.backward()
self.policy_optimizer.step()
if self.mode == 'auto':
self.alpha_optimizer.zero_grad()
(-policy_loss).backward()
self.alpha_optimizer.step()
self.log_alpha.data.clamp_(min=-5.0, max=10.0)
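# The 'auto' branch above is the dual update: policy_loss contains
# +alpha * (mmd_loss - target_mmd_thresh), so minimising -policy_loss w.r.t.
# log_alpha raises alpha when the constraint is violated and lowers it
# otherwise; the clamp keeps the multiplier in a reasonable range.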
"""
Update networks
"""
if self._n_train_steps_total % self.target_update_period == 0:
ptu.soft_update_from_to(
self.qf1, self.target_qf1, self.soft_target_tau
)
ptu.soft_update_from_to(
self.qf2, self.target_qf2, self.soft_target_tau
)
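# The soft updates above perform Polyak averaging of the target critics:
# target <- tau * source + (1 - tau) * target, every `target_update_period` steps.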
"""
Some statistics for eval
"""
if self._need_to_update_eval_statistics:
self._need_to_update_eval_statistics = False
"""
Eval should set this to None.
This way, these statistics are only computed for one batch.
"""
self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
self.eval_statistics['Num Policy Updates'] = self._num_policy_update_steps
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
policy_loss
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q1 Predictions',
ptu.get_numpy(qf1_pred),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q2 Predictions',
ptu.get_numpy(qf2_pred),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q Targets',
ptu.get_numpy(target_Q),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q Targets Variance',
ptu.get_numpy(target_Q_var),
))
self.eval_statistics.update(create_stats_ordered_dict(
'MMD Loss',
ptu.get_numpy(mmd_loss)
))
self.eval_statistics.update(create_stats_ordered_dict(
'Action Divergence',
ptu.get_numpy(action_divergence)
))
self.eval_statistics.update(create_stats_ordered_dict(
'Raw Action Divergence',
ptu.get_numpy(raw_action_divergence)
))
if self.mode == 'auto':
self.eval_statistics['Alpha'] = self.log_alpha.exp().item()
self._n_train_steps_total += 1