in core/vtrace.py [0:0]
def from_logits(
    behavior_policy_logits,
    target_policy_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
):
    """V-trace for softmax policies.

    Derives log importance weights from the logits of a behavior policy and
    a target policy, then delegates the V-trace computation itself to
    `from_importance_weights`.

    Args:
        behavior_policy_logits: Logits of the policy that generated
            `actions`; passed to `action_log_probs`.
        target_policy_logits: Logits of the policy being evaluated; passed
            to `action_log_probs`.
        actions: Actions whose log-probabilities are taken under both sets
            of logits.
        discounts: Forwarded unchanged to `from_importance_weights`.
        rewards: Forwarded unchanged to `from_importance_weights`.
        values: Forwarded unchanged to `from_importance_weights`.
        bootstrap_value: Forwarded unchanged to `from_importance_weights`.
        clip_rho_threshold: Forwarded unchanged to
            `from_importance_weights`. Defaults to 1.0.
        clip_pg_rho_threshold: Forwarded unchanged to
            `from_importance_weights`. Defaults to 1.0.

    Returns:
        A `VTraceFromLogitsReturns` holding `log_rhos`,
        `behavior_action_log_probs` and `target_action_log_probs`, plus
        every field of the result of `from_importance_weights`.

    NOTE(review): argument shapes/dtypes are not checked here; whatever
    `action_log_probs` and `from_importance_weights` expect applies.
    """
    # Log-probability of the taken actions under each policy.
    target_lp = action_log_probs(target_policy_logits, actions)
    behavior_lp = action_log_probs(behavior_policy_logits, actions)

    # log(pi_target / pi_behavior), computed as a difference in log space.
    log_ratios = target_lp - behavior_lp

    # The result is expanded to a dict of its fields so they can be merged
    # into the richer return type below.
    vtrace_fields = from_importance_weights(
        log_rhos=log_ratios,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
    )._asdict()

    return VTraceFromLogitsReturns(
        log_rhos=log_ratios,
        behavior_action_log_probs=behavior_lp,
        target_action_log_probs=target_lp,
        **vtrace_fields,
    )