in pyhanabi/r2d2.py [0:0]
def loss(self, batch, pred_weight, stat):
    # Per-step TD errors plus the recurrent activations needed by the auxiliary task.
    err, lstm_o = self.td_error(
        batch.obs,
        batch.h0,
        batch.action,
        batch.reward,
        batch.terminal,
        batch.bootstrap,
        batch.seq_len,
        stat,
    )
    # Smooth-L1 (Huber) loss on the TD errors, summed over the time dimension
    # so that each sequence in the batch contributes a single scalar.
    rl_loss = nn.functional.smooth_l1_loss(
        err, torch.zeros_like(err), reduction="none"
    )
    rl_loss = rl_loss.sum(0)
    stat["rl_loss"].feed((rl_loss / batch.seq_len).mean().item())

    # Per-step |TD error| is used directly as the replay priority; the
    # per-sequence aggregation below is left disabled.
    priority = err.abs()
    # priority = self.aggregate_priority(p, batch.seq_len)

    # Optional auxiliary own-hand prediction loss, weighted by pred_weight.
    if pred_weight > 0:
        if self.vdn:
            pred_loss1 = self.aux_task_vdn(
                lstm_o,
                batch.obs["own_hand"],
                batch.obs["temperature"],
                batch.seq_len,
                rl_loss.size(),
                stat,
            )
            loss = rl_loss + pred_weight * pred_loss1
        else:
            pred_loss = self.aux_task_iql(
                lstm_o,
                batch.obs["own_hand"],
                batch.seq_len,
                rl_loss.size(),
                stat,
            )
            loss = rl_loss + pred_weight * pred_loss
    else:
        loss = rl_loss

    return loss, priority
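
The shape convention assumed above can be checked in isolation. The sketch below is not part of r2d2.py; the tensor names and sizes are illustrative stand-ins. It reproduces the core computation: `err` is a [seq_len, batch] tensor of TD errors, the smooth-L1 loss is summed over the time dimension to give one scalar per sequence, and the per-step absolute TD error is kept as the replay priority.

import torch
from torch import nn

seq_len, batch_size = 80, 32
err = torch.randn(seq_len, batch_size)                    # stand-in for per-step TD errors
valid_steps = torch.full((batch_size,), float(seq_len))   # stand-in for batch.seq_len

rl_loss = nn.functional.smooth_l1_loss(
    err, torch.zeros_like(err), reduction="none"
).sum(0)                                                   # [batch]: one loss per sequence
priority = err.abs()                                       # [seq_len, batch]: per-step priorities

print((rl_loss / valid_steps).mean().item())               # the quantity fed to stat["rl_loss"]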