def compute_policy_gradient_loss()

in monobeast/minigrid/monobeast_amigo.py


import torch
import torch.nn.functional as F


def compute_policy_gradient_loss(logits, actions, advantages):
    # Main policy loss: negative log-likelihood of the taken actions under the
    # policy, weighted by the advantages.
    cross_entropy = F.nll_loss(
        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction="none",
    )
    # Reshape from [T * B] back to the [T, B] layout of the advantages.
    cross_entropy = cross_entropy.view_as(advantages)
    # Treat the advantages as constants so no gradient flows through them.
    advantages.requires_grad = False
    policy_gradient_loss_per_timestep = cross_entropy * advantages
    # Average over the batch dimension, then sum over the unroll.
    return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1))
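
F.nll_loss applied to F.log_softmax(logits) is the cross-entropy -log pi(a_t | s_t), so the returned value is the advantage-weighted negative log-likelihood, averaged over the batch dimension and summed over time. The following is a minimal usage sketch; the shapes (T unroll steps, B batch size, num_actions) and the random tensors are illustrative assumptions, not values taken from the repository.

import torch

T, B, num_actions = 80, 32, 7                   # hypothetical unroll/batch sizes
logits = torch.randn(T, B, num_actions)         # policy logits per timestep
actions = torch.randint(num_actions, (T, B))    # actions taken by the actors
advantages = torch.randn(T, B)                  # e.g. V-trace policy-gradient advantages
loss = compute_policy_gradient_loss(logits, actions, advantages)
print(loss.item())                              # scalar loss to backpropagate through the logits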