in understanding_rl_vision/rl_clarity/interface.py [0:0]
def compute_gae(trajectories, *, gae_gamma, gae_lambda):
    """Compute generalized advantage estimates for a batch of trajectories.

    Axis 1 of every array is treated as the step axis. For each step ``t``
    except the last, this evaluates

        delta_t = r_t + gae_gamma * (1 - d_t) * V_{t+1} - V_t
        A_t = delta_t + gae_gamma * gae_lambda * (1 - d_t) * A_{t+1}

    sweeping backwards from the final step, whose advantage is left at zero.

    Args:
        trajectories: Mapping holding ``"values"`` and ``"rewards"`` arrays
            plus either ``"dones"`` or ``"firsts"``. Only columns ``[:, :-1]``
            of ``"rewards"`` / ``"dones"`` are read. When ``"dones"`` is
            absent, ``"firsts"[:, t + 1]`` stands in for ``d_t``. The flags
            are used as the multiplicative mask ``1 - d``, so they are
            presumably 0/1 -- confirm against callers.
        gae_gamma: Discount applied to the next step's value and advantage.
        gae_lambda: Extra decay applied to the next step's advantage.

    Returns:
        Array with the same shape and dtype as ``"values"``; its last
        column along axis 1 is always zero.

    Raises:
        AssertionError: (when assertions are enabled) if the sliced values,
            rewards and done flags do not share one shape.
    """
    values = trajectories["values"]
    bootstrap = values[:, 1:]
    step_rewards = trajectories["rewards"][:, :-1]
    try:
        terminal = trajectories["dones"][:, :-1]
    except KeyError:
        # Step t ended an episode exactly when step t + 1 begins a new one.
        terminal = trajectories["firsts"][:, 1:]
    assert bootstrap.shape == step_rewards.shape == terminal.shape

    # Mask that zeroes out bootstrapping across episode boundaries.
    alive = 1 - terminal
    deltas = step_rewards + alive * gae_gamma * bootstrap - values[:, :-1]
    decay = alive * gae_gamma * gae_lambda

    # The final column is never written, so its advantage stays zero.
    advantages = np.zeros(values.shape, values.dtype)
    for t in range(values.shape[1] - 2, -1, -1):
        advantages[:, t] = deltas[:, t] + decay[:, t] * advantages[:, t + 1]
    return advantages