in rlalgos/deprecated/ppo/discrete_ppo.py [0:0]
def get_loss(self, trajectories):
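"""Replay the model over complete, equal-length trajectories and return the
PPO surrogate, value, and entropy terms as a DictTensor; the terms are not
weighted or combined here."""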
device = self.config["learner_device"]
trajectories = trajectories.to(device)
max_length = trajectories.lengths.max().item()
assert trajectories.lengths.eq(max_length).all()
actions = trajectories["action"]
actions_probabilities = trajectories["action_probabilities"]
reward = trajectories["_reward"]
frame = trajectories["frame"]
last_action = trajectories["last_action"]
done = trajectories["_done"].float()
# Recompute action scores and values by replaying the model over the trajectories
n_action_scores = []
n_values = []
hidden_state = trajectories["agent_state"][:, 0]
for T in range(max_length):
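# Reset the recurrent state at episode boundaries: masked_tensor selects,
# per batch element, between the carried hidden_state and the stored
# agent_state according to the initial_state mask at step T.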
hidden_state = masked_tensor(
hidden_state,
trajectories["agent_state"][:, T],
trajectories["initial_state"][:, T],
)
_as, _v, hidden_state = self.learning_model(
hidden_state, frame[:, T], last_action[:, T]
)
n_action_scores.append(_as.unsqueeze(1))
n_values.append(_v.unsqueeze(1))
n_action_scores = torch.cat(n_action_scores, dim=1)
n_values = torch.cat(
[*n_values, torch.zeros(trajectories.n_elems(), 1, 1).to(device)], dim=1
).squeeze(-1)
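# A zero column is appended so n_values covers max_length + 1 steps; the
# extra slot is filled below with the bootstrap value of the final state.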
# Compute the bootstrap value for the final state of each trajectory
_idx = torch.arange(trajectories.n_elems()).to(device)
_hidden_state = hidden_state.detach()  # alternative: trajectories["_agent_state"][_idx, trajectories.lengths - 1]
_frame = trajectories["_frame"][_idx, trajectories.lengths - 1]
_last_action = trajectories["_last_action"][_idx, trajectories.lengths - 1]
_, _v, _ = self.learning_model(_hidden_state, _frame, _last_action)
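# Write the bootstrap value into the extra slot (index = trajectory length).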
n_values[_idx, trajectories.lengths] = _v.squeeze(-1)
advantage = self.get_gae(
trajectories,
n_values,
discount_factor=self.config["discount_factor"],
_lambda=self.config["gae_lambda"],
)
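# The squared GAE advantage is used directly as the critic (value) loss.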
value_loss = advantage ** 2
avg_value_loss = value_loss.mean()
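# Importance ratios between the current policy and the behaviour policy that
# generated the trajectories, computed as exponentiated log-prob differences.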
n_action_probabilities = torch.softmax(n_action_scores, dim=2)
n_action_distribution = torch.distributions.Categorical(n_action_probabilities)
log_a = torch.distributions.Categorical(actions_probabilities).log_prob(actions)
log_na = n_action_distribution.log_prob(actions)
ratios = torch.exp(log_na - log_a)
# The advantage is detached so that the policy term does not backpropagate
# into the critic.
surr1 = ratios * advantage.detach()
surr2 = (
torch.clamp(
ratios, 1 - self.config["eps_clip"], 1 + self.config["eps_clip"]
)
* advantage.detach()
)
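# PPO clipped surrogate objective; returned as-is (not negated into a loss).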
ppo_loss = torch.min(surr1, surr2)
avg_ppo_loss = ppo_loss.mean()
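# Entropy of the current policy distribution, returned as a separate term
# (typically weighted as an exploration bonus).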
entropy_loss = n_action_distribution.entropy()
avg_entropy_loss = entropy_loss.mean()
dt = DictTensor(
{
"entropy_loss": avg_entropy_loss,
"ppo_loss": avg_ppo_loss,
"value_loss": avg_value_loss,
}
)
return dt