in pyhanabi/tools/obl_model.py [0:0]
def __init__(self, device, in_dim, hid_dim, out_dim, num_lstm_layer):
    """Create the private/public feed-forward branches, the LSTM and the heads.

    Args:
        device: target passed to ``.to()`` for the LSTM only; every other
            sub-module is created without an explicit device.
        in_dim: either the legacy flat input size (an ``int``, which must equal
            783) or an indexable whose elements ``[1]`` and ``[2]`` give the
            private and public input sizes. Stored unchanged as ``self.in_dim``.
        hid_dim: width of every hidden Linear layer and of the LSTM state.
        out_dim: number of outputs of the ``fc_a`` head.
        num_lstm_layer: number of stacked LSTM layers.
    """
    super().__init__()

    # Backward compatibility: a plain int is the legacy flat input size, from
    # which the private size is in_dim - 125 and the public size in_dim - 250.
    # NOTE(review): what the 125-wide slices represent is not visible here.
    if not isinstance(in_dim, int):
        priv_dim, publ_dim = in_dim[1], in_dim[2]
    else:
        assert in_dim == 783
        priv_dim, publ_dim = in_dim - 125, in_dim - 2 * 125

    self.in_dim = in_dim
    self.priv_in_dim = priv_dim
    self.publ_in_dim = publ_dim
    self.hid_dim = hid_dim
    self.out_dim = out_dim
    self.num_ff_layer = 1  # fixed; the public branch has a single stage
    self.num_lstm_layer = num_lstm_layer

    # Private branch: input projection + two hid_dim -> hid_dim stages, each
    # followed by ReLU. Layers are created in the same order as before so the
    # random weight initialization sequence is unchanged.
    priv_layers = [nn.Linear(self.priv_in_dim, self.hid_dim), nn.ReLU()]
    for _ in range(2):
        priv_layers += [nn.Linear(self.hid_dim, self.hid_dim), nn.ReLU()]
    self.priv_net = nn.Sequential(*priv_layers)

    # Public branch: input projection, then (num_ff_layer - 1) extra
    # hid_dim -> hid_dim stages (none while num_ff_layer == 1).
    publ_layers = [nn.Linear(self.publ_in_dim, self.hid_dim), nn.ReLU()]
    for _ in range(self.num_ff_layer - 1):
        publ_layers += [nn.Linear(self.hid_dim, self.hid_dim), nn.ReLU()]
    self.publ_net = nn.Sequential(*publ_layers)

    # Only the LSTM is placed on ``device`` here; its weights are then
    # compacted into a contiguous buffer via flatten_parameters().
    lstm = nn.LSTM(self.hid_dim, self.hid_dim, num_layers=self.num_lstm_layer)
    self.lstm = lstm.to(device)
    self.lstm.flatten_parameters()

    # Output heads on top of the hid_dim features (presumably value and
    # advantage, judging by the names — confirm against forward()).
    self.fc_v = nn.Linear(self.hid_dim, 1)
    self.fc_a = nn.Linear(self.hid_dim, self.out_dim)

    # Auxiliary-task head with 5 * 3 outputs.
    self.pred_1st = nn.Linear(self.hid_dim, 5 * 3)