in pyhanabi/utils.py [0:0]
def load_op_model(method, idx1, idx2, device):
    """Load up to two pre-trained OP agents for the given training method.

    OP models were trained for the 2-player setting only. Weights are
    expected at ``<root>/models/op/<method>/M<idx>.pthw``, where ``<root>``
    is the parent of the directory containing this file.

    Args:
        method: name of the sub-folder under ``models/op`` to load from.
        idx1: index of the first model, or ``None`` to skip it.
        idx2: index of the second model, or ``None`` to skip it.
        device: device the agents are moved to and the weights loaded on.

    Returns:
        A list with one ``r2d2.R2D2Agent`` per non-``None`` index, in the
        order ``idx1``, ``idx2``. Empty when both indices are ``None``.

    Raises:
        FileNotFoundError: if the weight file for a requested index does
            not exist.
    """
    root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # assume model saved in root/models/op
    folder = os.path.join(root, "models", "op", method)
    agents = []
    for idx in (idx1, idx2):
        if idx is None:
            continue

        # The model index selects the network variant (number of fc layers
        # and whether a skip connection is used). Any index outside [0, 9)
        # falls through to the last variant, exactly as before.
        if 0 <= idx < 3:
            num_fc = 1
            skip_connect = False
        elif 3 <= idx < 6:
            num_fc = 1
            skip_connect = True
        elif 6 <= idx < 9:
            num_fc = 2
            skip_connect = False
        else:
            num_fc = 2
            skip_connect = True

        weight_file = os.path.join(folder, f"M{idx}.pthw")
        if not os.path.exists(weight_file):
            # Raise explicitly instead of `assert False`: asserts are removed
            # under `python -O`, which would silently drop this check.
            print(f"Cannot find weight at: {weight_file}")
            raise FileNotFoundError(f"Cannot find weight at: {weight_file}")

        # The state dict is read here only to recover layer shapes, so load
        # it on CPU; this also works for checkpoints saved from CUDA tensors
        # on hosts without a GPU. The real weights are placed on `device` by
        # load_weight() below.
        state_dict = torch.load(weight_file, map_location="cpu")
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]

        # NOTE(review): the positional literals below are fixed settings whose
        # meaning is defined by R2D2Agent's signature (not visible here);
        # confirm against r2d2.R2D2Agent before changing any of them.
        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            hid_dim,
            output_dim,
            2,
            5,
            False,
            num_fc_layer=num_fc,
            skip_connect=skip_connect,
        ).to(device)
        load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents