in gala/storage.py [0:0]
def compute_returns(self,
                    next_value,
                    use_gae,
                    gamma,
                    gae_lambda,
                    use_proper_time_limits=True):
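    """Fill self.returns with return targets for each stored step.

    next_value bootstraps the value of the state following the last stored
    step. masks zero out bootstrapping across episode ends; when
    use_proper_time_limits is set, bad_masks additionally mark time-limit
    truncations so they are not treated as true terminations.
    """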
    if use_proper_time_limits:
        if use_gae:
            # Generalized Advantage Estimation, accumulated backwards from
            # the bootstrap value of the final state.
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = (self.rewards[step]
                         + gamma * self.value_preds[step + 1] * self.masks[step + 1]
                         - self.value_preds[step])
                gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
                # Reset the advantage at time-limit truncations (bad_masks == 0).
                gae = gae * self.bad_masks[step + 1]
                self.returns[step] = gae + self.value_preds[step]
        else:
            # Discounted returns; at time-limit truncations fall back to the
            # critic's value prediction instead of the bootstrapped return.
            self.returns[-1] = next_value
            for step in reversed(range(self.rewards.size(0))):
                self.returns[step] = (
                    (self.returns[step + 1] * gamma * self.masks[step + 1]
                     + self.rewards[step]) * self.bad_masks[step + 1]
                    + (1 - self.bad_masks[step + 1]) * self.value_preds[step])
    else:
        if use_gae:
            # GAE without special handling of time-limit truncations.
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = (self.rewards[step]
                         + gamma * self.value_preds[step + 1] * self.masks[step + 1]
                         - self.value_preds[step])
                gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
                self.returns[step] = gae + self.value_preds[step]
        else:
            # Plain discounted returns, bootstrapped from next_value.
            self.returns[-1] = next_value
            for step in reversed(range(self.rewards.size(0))):
                self.returns[step] = (self.returns[step + 1] * gamma
                                      * self.masks[step + 1]
                                      + self.rewards[step])
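
A minimal sketch of driving this method from a training loop, assuming the (num_steps [+ 1], num_processes, 1) tensor layout the indexing above implies; the stub class, tensor contents, and hyperparameters below are illustrative and not part of gala/storage.py:

import torch


class _Rollouts:
    """Illustrative stand-in exposing only the tensors compute_returns touches."""

    def __init__(self, num_steps, num_procs):
        self.rewards = torch.rand(num_steps, num_procs, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_procs, 1)
        self.masks = torch.ones(num_steps + 1, num_procs, 1)      # 0 after episode ends
        self.bad_masks = torch.ones(num_steps + 1, num_procs, 1)  # 0 after time-limit truncations
        self.returns = torch.zeros(num_steps + 1, num_procs, 1)

    compute_returns = compute_returns  # reuse the function above as a method


rollouts = _Rollouts(num_steps=5, num_procs=2)
next_value = torch.zeros(2, 1)  # critic's estimate for the state after the last stored step
rollouts.compute_returns(next_value, use_gae=True, gamma=0.99, gae_lambda=0.95)
advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]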