in reagent/training/sac_trainer.py
def train_step_gen(self, training_batch: rlt.PolicyNetworkInput, batch_idx: int):
"""
IMPORTANT: the actions in the input batch are assumed to lie in the
same range as the actor network's output.
"""
assert isinstance(training_batch, rlt.PolicyNetworkInput)
state = training_batch.state
action = training_batch.action
reward = training_batch.reward
discount = torch.full_like(reward, self.gamma)
not_done_mask = training_batch.not_terminal
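# discount broadcasts gamma to the reward's shape, and not_done_mask zeroes the
# bootstrapped term below for terminal transitions.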
#
# First, optimize the Q networks: minimize the MSE between
# Q(s, a) and r + discount * V'(next_s)
#
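# The target is y = r + gamma * V'(s'), where V'(s') comes from the target value
# network if one exists, and otherwise from min_i Q_i'(s', a') - alpha * log pi(a'|s')
# with a' drawn from the current actor.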
if self.value_network is not None:
next_state_value = self.value_network_target(training_batch.next_state)
else:
next_state_actor_output = self.actor_network(training_batch.next_state)
next_state_actor_action = (
training_batch.next_state,
rlt.FeatureData(next_state_actor_output.action),
)
next_state_value = self.q1_network_target(*next_state_actor_action)
if self.q2_network is not None:
target_q2_value = self.q2_network_target(*next_state_actor_action)
next_state_value = torch.min(next_state_value, target_q2_value)
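# Clipped double-Q: taking the element-wise min of the two target critics
# helps mitigate Q-value overestimation.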
log_prob_a = self.actor_network.get_log_prob(
training_batch.next_state, next_state_actor_output.action
).clamp(LOG_PROB_MIN, LOG_PROB_MAX)
next_state_value -= self.entropy_temperature * log_prob_a
if self.gamma > 0.0:
target_q_value = (
reward + discount * next_state_value * not_done_mask.float()
)
else:
# With gamma == 0 the target reduces to the immediate reward;
# this is useful when debugging instability issues
target_q_value = reward
q1_value = self.q1_network(state, action)
q1_loss = F.mse_loss(q1_value, target_q_value)
yield q1_loss
if self.q2_network:
q2_value = self.q2_network(state, action)
q2_loss = F.mse_loss(q2_value, target_q_value)
yield q2_loss
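# Each yielded loss is expected to be consumed by the surrounding training loop,
# which steps the optimizer that corresponds to it (q1, then q2 if present).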
# Second, optimize the actor: minimize the KL divergence between the
# policy and the softmax of the Q-values. With the reparameterization
# trick this reduces to entropy_temperature * log_prob(actor_action) - Q(s, actor_action)
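# actor_output.action is assumed to be a reparameterized sample, so gradients from
# min_q_actor_value below flow back into the actor's parameters.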
actor_output = self.actor_network(state)
state_actor_action = (state, rlt.FeatureData(actor_output.action))
q1_actor_value = self.q1_network(*state_actor_action)
min_q_actor_value = q1_actor_value
if self.q2_network:
q2_actor_value = self.q2_network(*state_actor_action)
min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)
actor_log_prob = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
if not self.backprop_through_log_prob:
actor_log_prob = actor_log_prob.detach()
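# CRR-style branch: instead of the SAC objective, weight the actor's log-probability
# by a function of the advantage (Q of the actor's action minus the state value).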
if self.crr_config is not None:
cur_value = self.value_network(training_batch.state)
advantage = (min_q_actor_value - cur_value).detach()
# pyre-fixme[16]: `Optional` has no attribute `get_weight_from_advantage`.
crr_weight = self.crr_config.get_weight_from_advantage(advantage)
assert (
actor_log_prob.shape == crr_weight.shape
), f"{actor_log_prob.shape} != {crr_weight.shape}"
actor_loss = -(actor_log_prob * crr_weight.detach())
else:
actor_loss = self.entropy_temperature * actor_log_prob - min_q_actor_value
# Do this in two steps so we can log a histogram of the actor loss
actor_loss_mean = actor_loss.mean()
if self.add_kld_to_loss:
if self.apply_kld_on_mean:
action_batch_m = torch.mean(actor_output.squashed_mean, axis=0)
action_batch_v = torch.var(actor_output.squashed_mean, axis=0)
else:
action_batch_m = torch.mean(actor_output.action, axis=0)
action_batch_v = torch.var(actor_output.action, axis=0)
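# Closed-form KL divergence between a diagonal Gaussian fitted to the action batch
# (mean action_batch_m, variance action_batch_v) and the prior Gaussian
# N(action_emb_mean, action_emb_variance), summed over action dimensions.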
kld = (
0.5
* (
(action_batch_v + (action_batch_m - self.action_emb_mean) ** 2)
/ self.action_emb_variance
- 1
+ self.action_emb_variance.log()
- action_batch_v.log()
).sum()
)
actor_loss_mean += self.kld_weight * kld
yield actor_loss_mean
# Optimize Alpha
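# Standard SAC temperature update: the loss drives log_alpha up when the policy's
# entropy is below target_entropy (so the entropy bonus grows) and down otherwise.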
if self.alpha_optimizer is not None:
alpha_loss = -(
(
self.log_alpha
* (
actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
+ self.target_entropy
).detach()
).mean()
)
yield alpha_loss
self.entropy_temperature = self.log_alpha.exp()
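# Refresh the temperature from log_alpha so subsequent terms (e.g. the value
# target below) use the current value.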
#
# Lastly, if applicable, optimize the value network: minimize the MSE between
# V(s) and E_{a~pi(.|s)}[ Q(s, a) - entropy_temperature * log pi(a|s) ]
#
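# Note: with logged_action_uniform_prior the entropy term is dropped and the
# target reduces to min_i Q_i(s, a_pi) for the actor's sampled action a_pi.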
if self.value_network is not None:
state_value = self.value_network(state)
if self.logged_action_uniform_prior:
log_prob_a = torch.zeros_like(min_q_actor_value)
target_value = min_q_actor_value
else:
log_prob_a = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
target_value = min_q_actor_value - self.entropy_temperature * log_prob_a
value_loss = F.mse_loss(state_value, target_value.detach())
yield value_loss
self.logger.log_metrics(
{
"td_loss": q1_loss,
"logged_rewards": reward.mean(),
"model_values_on_logged_actions": q1_value.mean(),
"q1_value": q1_value.mean(),
"entropy_temperature": self.entropy_temperature,
"log_prob_a": log_prob_a.mean(),
"next_state_value": next_state_value.mean(),
"target_q_value": target_q_value.mean(),
"min_q_actor_value": min_q_actor_value.mean(),
"actor_output_log_prob": actor_output.log_prob.mean(),
"actor_loss": actor_loss.mean(),
},
step=self.all_batches_processed,
)
if self.q2_network:
self.logger.log_metrics(
{"q2_value": q2_value.mean()},
step=self.all_batches_processed,
)
if self.value_network:
self.logger.log_metrics(
{"target_state_value": target_value.mean()},
step=self.all_batches_processed,
)
if self.add_kld_to_loss:
self.logger.log_metrics(
{
"action_batch_mean": action_batch_m.mean(),
"action_batch_var": action_batch_v.mean(),
# pyre-fixme[61]: `kld` may not be initialized here.
"kld": kld,
},
step=self.all_batches_processed,
)
# Use the soft update rule to update the target networks
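# Soft (Polyak) update: each target network moves a small step toward its online
# counterpart, target <- tau * online + (1 - tau) * target, with tau the configured
# soft-update rate.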
result = self.soft_update_result()
self.log("td_loss", q1_loss, prog_bar=True)
yield result