in src_code/controllers/basic_controller_interactive.py [0:0]
def _build_agents(self, input_shape, input_alone_shape):
    """Create ``self.agent`` and optionally warm-start its ``agent_alone`` sub-module.

    The agent class is looked up in ``agent_REGISTRY`` under ``self.args.agent``
    and constructed with ``(input_shape, input_alone_shape, self.args)``.

    When ``self.args.pretrained`` is truthy, the checkpoint
    ``<single_model_name>/<"-".join(env_args['map_name'])>/agent.th`` is loaded.
    Every entry whose key starts with ``"agent_alone."`` is re-keyed without
    that prefix. Entries whose re-keyed name exists in
    ``self.agent.agent_alone.state_dict()`` overwrite the current values.
    Parameters absent from the checkpoint keep their current values.

    Args:
        input_shape: forwarded to the agent constructor.
        input_alone_shape: forwarded to the agent constructor.
    """
    import os  # local import keeps this method self-contained

    self.agent = agent_REGISTRY[self.args.agent](input_shape, input_alone_shape, self.args)
    if not self.args.pretrained:
        return

    print('Loading pretrained model')
    # NOTE(review): if map_name is a plain string, "-".join inserts '-' between
    # its characters ("3m" -> "3-m"); presumably it is a list of map names here.
    map_dir = "-".join(self.args.env_args['map_name'])
    ckpt_path = os.path.join(self.args.single_model_name, map_dir, "agent.th")
    # Loading onto CPU lets a GPU-saved checkpoint load on any machine;
    # load_state_dict below copies the tensors into the module's own parameters.
    checkpoint = th.load(ckpt_path, map_location="cpu")

    model_dict = self.agent.agent_alone.state_dict()
    prefix = "agent_alone."
    # 1. Keep only the agent_alone sub-module's entries. Re-key them relative to
    #    the sub-module and restrict them to names the current module has.
    #    startswith (not a substring test) guarantees the slice removes exactly
    #    the prefix.
    pretrained_dict = {
        k[len(prefix):]: v
        for k, v in checkpoint.items()
        if k.startswith(prefix) and k[len(prefix):] in model_dict
    }
    if not pretrained_dict:
        # A silent no-op load is easy to miss; surface it without crashing.
        print('Warning: no agent_alone parameters matched in ' + ckpt_path)
    # 2. Overwrite the matching entries in the existing state dict.
    model_dict.update(pretrained_dict)
    # 3. Load the merged state dict back into the sub-module.
    self.agent.agent_alone.load_state_dict(model_dict)