def compute_gae()

in understanding_rl_vision/rl_clarity/interface.py [0:0]


def compute_gae(trajectories, *, gae_gamma, gae_lambda):
    """Compute Generalized Advantage Estimation (GAE) advantages.

    Axis 1 of every array is treated as time: successor values are taken
    from ``values[:, 1:]`` and the recursion runs backwards along axis 1.
    Axis 0 presumably indexes parallel environments / trajectories --
    confirm against callers.

    Args:
        trajectories: Mapping with array entries ``"values"``, ``"rewards"``
            and either ``"dones"`` or, as a fallback, ``"firsts"``. A nonzero
            ``dones[:, t]`` (equivalently ``firsts[:, t + 1]``) stops the
            advantage at step ``t`` from bootstrapping off step ``t + 1``.
        gae_gamma: Discount factor applied to the successor value.
        gae_lambda: GAE smoothing factor applied to the recursive term.

    Returns:
        Array with the same shape and dtype as ``trajectories["values"]``.
        The final time step is always 0 because there is no successor value
        to bootstrap from; the last reward and last done entries are unused.

    Raises:
        KeyError: If ``"values"`` or ``"rewards"`` is missing, or if both
            ``"dones"`` and ``"firsts"`` are missing.
        ValueError: If the sliced successor values, rewards and done flags
            do not have identical shapes.
    """
    values = trajectories["values"]
    next_values = values[:, 1:]
    rewards = trajectories["rewards"][:, :-1]
    try:
        dones = trajectories["dones"][:, :-1]
    except KeyError:
        # Step t ends an episode exactly when step t + 1 starts a new one.
        dones = trajectories["firsts"][:, 1:]
    # Explicit check instead of ``assert`` so it still runs under
    # ``python -O``; a mismatch would otherwise broadcast silently below.
    if not (next_values.shape == rewards.shape == dones.shape):
        raise ValueError(
            "compute_gae: shape mismatch between next values "
            f"{next_values.shape}, rewards {rewards.shape} and "
            f"dones {dones.shape}"
        )
    # Mask that zeroes out bootstrapping across episode boundaries.
    not_dones = 1 - dones
    deltas = rewards + not_dones * gae_gamma * next_values - values[:, :-1]
    # Last column stays 0: it seeds the backward recursion.
    result = np.zeros(values.shape, values.dtype)
    for step in reversed(range(values.shape[1] - 1)):
        result[:, step] = (
            deltas[:, step]
            + not_dones[:, step] * gae_gamma * gae_lambda * result[:, step + 1]
        )
    return result