in rlkit/torch/td3/td3.py [0:0]
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
"""
Critic operations.
"""
next_actions = self.target_policy(next_obs)
noise = ptu.randn(next_actions.shape) * self.target_policy_noise
noise = torch.clamp(
noise,
-self.target_policy_noise_clip,
self.target_policy_noise_clip
)
noisy_next_actions = next_actions + noise
    # Clipped double-Q: use the minimum of the two target critics to reduce
    # overestimation bias, then form the detached Bellman target.
    target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
    target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
    target_q_values = torch.min(target_q1_values, target_q2_values)

    q_target = self.reward_scale * rewards \
        + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()
    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target) ** 2
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target) ** 2
    qf2_loss = bellman_errors_2.mean()
"""
Update Networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward()
self.qf1_optimizer.step()
self.qf2_optimizer.zero_grad()
qf2_loss.backward()
self.qf2_optimizer.step()
    policy_actions = policy_loss = None
    # Delayed updates: the policy and the target networks are only updated
    # every `policy_and_target_update_period` critic updates.
    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        policy_actions = self.policy(obs)
        q_output = self.qf1(obs, policy_actions)
        policy_loss = - q_output.mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Polyak-average the online networks into their targets.
        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        if policy_loss is None:
            # The policy was not updated on this step; recompute its loss
            # purely for logging.
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 1',
            ptu.get_numpy(bellman_errors_1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 2',
            ptu.get_numpy(bellman_errors_2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))

    self._n_train_steps_total += 1