in pyhanabi/r2d2.py [0:0]
def act(self, obs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Acts on the given obs with an epsilon-greedy policy.

    input (keys read from ``obs``):
        priv_s: leading dims [obsize, ibsize, num_player] when ``self.vdn``
            is set, otherwise [obsize, ibsize]; flattened into one batch
            dim before calling ``self.greedy_act``.
        legal_move: same leading dims as priv_s; its last dim is used as
            the sampling weights of the random (exploration) action.
        eps: same leading dims as priv_s; per-entry probability of taking
            the random action instead of the greedy one.
        h0, c0: recurrent state; the first two dims are flattened into a
            single batch dim before calling ``self.greedy_act``
            (presumably laid out as
            [obsize, ibsize * num_player, num_lstm_layer, hid_dim], the
            same layout as the returned h0/c0 -- TODO confirm).

    output: dict of detached CPU tensors
        'a': chosen actions (long), shape [obsize, ibsize], or
            [obsize, ibsize, num_player] when ``self.vdn`` is set.
        'greedy_a': greedy actions from ``self.greedy_act``, same shape
            as 'a'.
        'h0', 'c0': new recurrent state, contiguous, shape
            [obsize, ibsize * num_player, num_lstm_layer, hid_dim].
    """
    # Pre-declare as ints so both branches below bind the same names/types.
    obsize, ibsize, num_player = 0, 0, 0
    if self.vdn:
        obsize, ibsize, num_player = obs["priv_s"].size()[:3]
        flat_end = 2  # flatten [obsize, ibsize, num_player] -> batch
        action_shape = [obsize, ibsize, num_player]
    else:
        obsize, ibsize = obs["priv_s"].size()[:2]
        num_player = 1
        flat_end = 1  # flatten [obsize, ibsize] -> batch
        action_shape = [obsize, ibsize]

    priv_s = obs["priv_s"].flatten(0, flat_end)
    legal_move = obs["legal_move"].flatten(0, flat_end)
    eps = obs["eps"].flatten(0, flat_end)
    # Merge the first two dims into one batch dim, then swap it with the
    # next dim so the batch sits in dim 1; presumably the layout
    # greedy_act / the LSTM expects -- TODO confirm.
    hid = {
        "h0": obs["h0"].flatten(0, 1).transpose(0, 1).contiguous(),
        "c0": obs["c0"].flatten(0, 1).transpose(0, 1).contiguous(),
    }

    greedy_action, new_hid = self.greedy_act(priv_s, legal_move, hid)
    # One random action per row, sampled with legal_move as weights
    # (torch.multinomial requires a positive total weight in every row).
    # Sampled before torch.rand to keep the RNG consumption order.
    random_action = legal_move.multinomial(1).squeeze(1)
    # rand ~ U[0, 1), so `rand < eps` picks the random action with
    # probability eps for each entry.
    rand = torch.rand(greedy_action.size(), device=greedy_action.device)
    assert rand.size() == eps.size()
    explore = rand < eps
    action = torch.where(explore, random_action, greedy_action.long())

    action = action.view(action_shape)
    greedy_action = greedy_action.view(action_shape)

    # Undo the batch/dim-1 swap and restore the un-flattened leading dims.
    hid_shape = (
        obsize,
        ibsize * num_player,
        self.online_net.num_lstm_layer,
        self.online_net.hid_dim,
    )
    h0 = new_hid["h0"].transpose(0, 1).view(*hid_shape)
    c0 = new_hid["c0"].transpose(0, 1).view(*hid_shape)
    reply = {
        "a": action.detach().cpu(),
        "greedy_a": greedy_action.detach().cpu(),
        "h0": h0.contiguous().detach().cpu(),
        "c0": c0.contiguous().detach().cpu(),
    }
    return reply