monobeast/minigrid/monobeast_amigo.py
import torch
import torch.nn.functional as F


def compute_policy_gradient_loss(logits, actions, advantages):
    # Main policy loss: cross-entropy between the policy and the actions
    # actually taken, weighted per timestep by the advantage estimates.
    # Flatten [T, B, num_actions] to [T*B, num_actions] so nll_loss
    # scores every timestep of every batch element at once.
    cross_entropy = F.nll_loss(
        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction="none",
    )
    # Restore the [T, B] layout to line up with the advantages.
    cross_entropy = cross_entropy.view_as(advantages)
    # Advantages are treated as constants: no gradient flows through them.
    advantages.requires_grad = False
    policy_gradient_loss_per_timestep = cross_entropy * advantages
    # Mean over the batch dimension (dim=1), then sum over time.
    return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1))
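
For context, the function expects the time-major tensors the monobeast learner produces: logits of shape [T, B, num_actions], actions of shape [T, B], and advantages of shape [T, B]. A minimal sketch of a call follows; the sizes and dummy data are illustrative assumptions, not values from the repo:

    import torch

    # Illustrative shapes only: T = unroll length, B = batch size.
    T, B, num_actions = 5, 4, 7

    logits = torch.randn(T, B, num_actions, requires_grad=True)
    actions = torch.randint(num_actions, (T, B))
    advantages = torch.randn(T, B)

    loss = compute_policy_gradient_loss(logits, actions, advantages)
    loss.backward()  # gradients reach `logits` only; advantages act as constants
    print(loss.item())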