in rlkit/torch/ddpg/ddpg.py
def train_from_torch(self, batch):
rewards = batch['rewards']
terminals = batch['terminals']
obs = batch['observations']
actions = batch['actions']
next_obs = batch['next_observations']
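# The batch is a dict of torch tensors sampled from the replay buffer:
# rewards and terminals are (batch_size, 1) columns, observations and
# actions are (batch_size, dim).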
"""
Policy operations.
"""
if self.policy_pre_activation_weight > 0:
policy_actions, pre_tanh_value = self.policy(
obs, return_preactivations=True,
)
pre_activation_policy_loss = (
(pre_tanh_value**2).sum(dim=1).mean()
)
q_output = self.qf(obs, policy_actions)
raw_policy_loss = - q_output.mean()
policy_loss = (
raw_policy_loss +
pre_activation_policy_loss * self.policy_pre_activation_weight
)
else:
policy_actions = self.policy(obs)
q_output = self.qf(obs, policy_actions)
raw_policy_loss = policy_loss = - q_output.mean()
"""
Critic operations.
"""
next_actions = self.target_policy(next_obs)
# stop gradients from flowing back through the target networks
# (note: Tensor.detach() is not in-place, so the result must be reassigned)
next_actions = next_actions.detach()
target_q_values = self.target_qf(
next_obs,
next_actions,
)
q_target = rewards + (1. - terminals) * self.discount * target_q_values
q_target = q_target.detach()
q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
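# The clamp above optionally bounds the regression target; with
# min_q_value = -inf and max_q_value = +inf it is a no-op.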
q_pred = self.qf(obs, actions)
bellman_errors = (q_pred - q_target) ** 2
raw_qf_loss = self.qf_criterion(q_pred, q_target)
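# qf_criterion is typically nn.MSELoss(), in which case raw_qf_loss equals
# bellman_errors.mean(); bellman_errors itself is kept only for logging below.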
if self.qf_weight_decay > 0:
reg_loss = self.qf_weight_decay * sum(
torch.sum(param ** 2)
for param in self.qf.regularizable_parameters()
)
qf_loss = raw_qf_loss + reg_loss
else:
qf_loss = raw_qf_loss
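# The optional L2 penalty above only touches qf.regularizable_parameters(),
# which is assumed to yield just the parameters meant for weight decay
# (e.g. weight matrices rather than biases).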
"""
Update Networks
"""
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
self.qf_optimizer.zero_grad()
qf_loss.backward()
self.qf_optimizer.step()
self._update_target_networks()
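# _update_target_networks() typically applies either Polyak averaging
# (soft update) or a periodic hard copy of the online weights, depending
# on how the trainer was configured.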
"""
Save some statistics for eval using just one batch.
"""
if self._need_to_update_eval_statistics:
self._need_to_update_eval_statistics = False
self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
policy_loss
))
self.eval_statistics['Raw Policy Loss'] = np.mean(ptu.get_numpy(
raw_policy_loss
))
self.eval_statistics['Preactivation Policy Loss'] = (
self.eval_statistics['Policy Loss'] -
self.eval_statistics['Raw Policy Loss']
)
self.eval_statistics.update(create_stats_ordered_dict(
'Q Predictions',
ptu.get_numpy(q_pred),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q Targets',
ptu.get_numpy(q_target),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Bellman Errors',
ptu.get_numpy(bellman_errors),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Policy Action',
ptu.get_numpy(policy_actions),
))
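# _n_train_steps_total counts gradient updates; it is used for logging and,
# when hard target updates are enabled, to decide when to copy weights.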
self._n_train_steps_total += 1