in torchbeast/monobeast.py [0:0]
def test(flags, num_episodes: int = 10):
    """Evaluate a saved checkpoint for ``num_episodes`` episodes and log returns.

    The checkpoint is read from ``<savedir>/<xpid>/model.tar`` (with ``~`` and
    environment variables expanded), or from ``./latest/model.tar`` when
    ``flags.xpid`` is None. It is loaded onto the CPU, the model is switched to
    eval mode, and the environment is stepped with the model's actions until
    ``num_episodes`` episodes have finished. When ``flags.mode`` is
    ``"test_render"`` the environment is rendered before every step.

    Args:
        flags: Parsed command-line flags. This function reads ``xpid``,
            ``savedir``, ``use_lstm`` and ``mode`` (plus whatever
            ``create_env`` reads).
        num_episodes: Number of completed episodes to evaluate. With a value
            below 1 no episodes are run and no average is computed.
    """
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
        )

    # Load the checkpoint before creating the environment so a missing or
    # corrupt file fails fast without opening an env.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpointpath, map_location="cpu")

    gym_env = create_env(flags)
    env = environment.Environment(gym_env)
    returns = []
    try:
        model = Net(
            gym_env.observation_space.shape, gym_env.action_space.n, flags.use_lstm
        )
        model.eval()
        model.load_state_dict(checkpoint["model_state_dict"])

        observation = env.initial()
        # Carry the recurrent core state from step to step so LSTM models get
        # a valid state instead of none at all.
        # NOTE(review): relies on Net.initial_state(batch_size) and
        # Net.forward(inputs, core_state) -- confirm against Net's definition.
        agent_state = model.initial_state(batch_size=1)
        while len(returns) < num_episodes:
            if flags.mode == "test_render":
                env.gym_env.render()
            # Pure inference: without no_grad, every forward pass would build an
            # autograd graph, and the carried agent_state would chain them
            # together across the whole run.
            with torch.no_grad():
                policy_outputs, agent_state = model(observation, agent_state)
            observation = env.step(policy_outputs["action"])
            if observation["done"].item():
                returns.append(observation["episode_return"].item())
                logging.info(
                    "Episode ended after %d steps. Return: %.1f",
                    observation["episode_step"].item(),
                    observation["episode_return"].item(),
                )
    finally:
        # Always release the environment, even if loading or stepping failed.
        env.close()

    if not returns:
        # Guard against ZeroDivisionError when num_episodes < 1.
        logging.info("No episodes were evaluated (num_episodes=%i).", num_episodes)
        return
    logging.info(
        "Average returns over %i episodes: %.1f",
        len(returns),
        sum(returns) / len(returns),
    )