in pyhanabi/r2d2.py [0:0]
def td_error(self, obs, hid, action, reward, terminal, bootstrap, seq_len, stat):
    max_seq_len = obs["priv_s"].size(0)

    # under VDN, flatten the per-player dimension into the batch dimension
    bsize, num_player = 0, 1
    if self.vdn:
        bsize, num_player = self.flat_4d(obs)
        self.flat_4d(action)

    priv_s = obs["priv_s"]
    legal_move = obs["legal_move"]
    action = action["a"]

    # start from an empty hidden state; this only works because the
    # trajectories are padded, i.e. there is no terminal in the middle
    hid = {}

    # online net: Q(s_t, a_t) for the taken actions, plus the greedy actions
    online_qa, greedy_a, _, lstm_o = self.online_net(
        priv_s, legal_move, action, hid
    )

    # double Q-learning: evaluate the online net's greedy actions with the target net
    with torch.no_grad():
        target_qa, _, _, _ = self.target_net(priv_s, legal_move, greedy_a, hid)
        assert online_qa.size() == target_qa.size()
    # VDN: the joint Q-value is the sum of per-player Q-values
    if self.vdn:
        online_qa = online_qa.view(max_seq_len, bsize, num_player).sum(-1)
        target_qa = target_qa.view(max_seq_len, bsize, num_player).sum(-1)
        lstm_o = lstm_o.view(max_seq_len, bsize, num_player, -1)

    terminal = terminal.float()
    bootstrap = bootstrap.float()
    # roll target_qa forward by multi_step so that row t holds the value at
    # step t + multi_step; the last multi_step rows have no bootstrap value
    target_qa = torch.cat(
        [target_qa[self.multi_step :], target_qa[: self.multi_step]], 0
    )
    target_qa[-self.multi_step :] = 0

    # n-step return target: r_t^(n) + bootstrap * gamma^n * Q_target(s_{t+n}, a_{t+n})
    assert target_qa.size() == reward.size()
    target = reward + bootstrap * (self.gamma ** self.multi_step) * target_qa
    # mask out timesteps beyond each trajectory's true length (padding)
    mask = torch.arange(0, max_seq_len, device=seq_len.device)
    mask = (mask.unsqueeze(1) < seq_len.unsqueeze(0)).float()
    err = (target.detach() - online_qa) * mask
    return err, lstm_o
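
# ---------------------------------------------------------------------------
# Standalone illustration (not part of r2d2.py; the shapes and values below are
# invented for demonstration): how the multi_step roll of target_qa and the
# seq_len padding mask used in td_error behave on toy tensors.
import torch

multi_step = 3
max_seq_len, bsize = 6, 2

# pretend per-step target Q-values, shape [max_seq_len, bsize]
target_qa = torch.arange(max_seq_len * bsize, dtype=torch.float32).view(max_seq_len, bsize)

# roll forward by multi_step so row t holds the value at step t + multi_step;
# the wrapped-around last multi_step rows are then zeroed (no bootstrap there)
shifted = torch.cat([target_qa[multi_step:], target_qa[:multi_step]], 0)
shifted[-multi_step:] = 0

# mask[t, b] = 1 while timestep t is inside trajectory b's true length
seq_len = torch.tensor([6, 4])
mask = torch.arange(0, max_seq_len)
mask = (mask.unsqueeze(1) < seq_len.unsqueeze(0)).float()

print(shifted)  # rows 0-2 hold rows 3-5 of target_qa, rows 3-5 are zero
print(mask)     # second column is zero at timesteps 4 and 5 (padding)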