in core/vtrace.py
import collections

import torch

VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")


def from_importance_weights(log_rhos,
                            discounts,
                            rewards,
                            values,
                            bootstrap_value,
                            clip_rho_threshold=1.0,
                            clip_pg_rho_threshold=1.0):
    """Computes V-trace targets from log importance weights.

    Implements the V-trace targets of the IMPALA paper
    (Espeholt et al., 2018):

        v_s = V(x_s) + delta_s + discount_s * c_s * (v_{s+1} - V(x_{s+1}))
        delta_s = rho_s * (r_s + discount_s * V(x_{s+1}) - V(x_s))

    where rho_s = min(clip_rho_threshold, pi/mu) and c_s = min(1, pi/mu).
    All inputs are time-major: log_rhos, discounts, rewards and values
    have shape [T, B], bootstrap_value has shape [B], and discounts
    should already incorporate the discount factor
    (e.g. gamma * (1 - done)).
    """
    with torch.no_grad():
        rhos = torch.exp(log_rhos)
        if clip_rho_threshold is not None:
            clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
        else:
            clipped_rhos = rhos

        # c_bar is fixed at 1.0, as recommended in the IMPALA paper.
        cs = torch.clamp(rhos, max=1.0)

        # Append the bootstrapped value to get [V(x_1), ..., V(x_T), V(x_{T+1})].
        values_t_plus_1 = torch.cat(
            [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
        deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

        # Backward recursion: acc_t = delta_t + discount_t * c_t * acc_{t+1}.
        acc = torch.zeros_like(bootstrap_value)
        result = []
        for t in range(discounts.shape[0] - 1, -1, -1):
            acc = deltas[t] + discounts[t] * cs[t] * acc
            result.append(acc)
        result.reverse()
        vs_minus_v_xs = torch.stack(result)

        # Add V(x_s) to get v_s.
        vs = vs_minus_v_xs + values

        # Advantage for the policy gradient: bootstraps from v_{s+1},
        # not from V(x_{s+1}).
        vs_t_plus_1 = torch.cat(
            [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
        if clip_pg_rho_threshold is not None:
            clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
        else:
            clipped_pg_rhos = rhos
        pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)

        # Returning inside torch.no_grad() ensures no gradients are
        # backpropagated through the returned targets.
        return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
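For reference, a minimal usage sketch follows. The time-major [T, B] shapes, the gamma * (1 - done) discount construction, and all variable names below are illustrative assumptions, not part of core/vtrace.py:

import torch

T, B, gamma = 20, 8, 0.99  # hypothetical unroll length, batch size, discount

# Log importance weights: log pi(a|x) - log mu(a|x) for the taken actions.
target_log_probs = torch.randn(T, B)
behavior_log_probs = torch.randn(T, B)
log_rhos = target_log_probs - behavior_log_probs

dones = torch.zeros(T, B)                  # 1.0 where an episode ended
discounts = gamma * (1.0 - dones)          # per-step discount, zeroed at terminals
rewards = torch.randn(T, B)
values = torch.randn(T, B)                 # critic estimates V(x_1), ..., V(x_T)
bootstrap_value = torch.randn(B)           # critic estimate V(x_{T+1})

vtrace_returns = from_importance_weights(
    log_rhos, discounts, rewards, values, bootstrap_value)
assert vtrace_returns.vs.shape == (T, B)
assert vtrace_returns.pg_advantages.shape == (T, B)

In a training loop, vs would serve as the regression target for the value loss and pg_advantages would weight the policy-gradient term; both are already detached from the graph by the torch.no_grad() context.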