in run_exp.py [0:0]
def test(flags, num_eps: int = 1000):
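    """Evaluate a trained model (or a random agent) for num_eps episodes and
    log the average return, win rate, policy entropy, and episode length."""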
    from rtfm import featurizer as X
    gym_env = Net.create_env(flags)
    if flags.mode == 'test_render':
        # Add a terminal featurizer so the episode can be rendered in the console.
        gym_env.featurizer = X.Concat([gym_env.featurizer, X.Terminal()])
    env = environment.Environment(gym_env)
    if not flags.random_agent:
        # Build the model and restore weights from the saved checkpoint.
        model = Net.make(flags, gym_env)
        model.eval()
        if flags.xpid is None:
            checkpointpath = './results_latest/model.tar'
        else:
            checkpointpath = os.path.expandvars(
                os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
                                                 'model.tar')))
        checkpoint = torch.load(checkpointpath, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    observation = env.initial()
    # Per-episode statistics.
    returns = []
    won = []
    entropy = []
    ep_len = []
    while len(won) < num_eps:
        done = False
        steps = 0
        while not done:
            if flags.random_agent:
                # Sample a uniformly random action.
                action = torch.zeros(1, 1, dtype=torch.int32)
                action[0][0] = random.randint(0, gym_env.action_space.n - 1)
                observation = env.step(action)
            else:
                agent_outputs = model(observation)
                observation = env.step(agent_outputs['action'])
                # Track the policy's entropy as a diagnostic.
                policy = F.softmax(agent_outputs['policy_logits'], dim=-1)
                log_policy = F.log_softmax(agent_outputs['policy_logits'], dim=-1)
                e = -torch.sum(policy * log_policy, dim=-1)
                entropy.append(e.mean(0).item())
            steps += 1
            done = observation['done'].item()
            if done:
                # Record episode statistics; a reward above 0.5 counts as a win.
                returns.append(observation['episode_return'].item())
                won.append(observation['reward'][0][0].item() > 0.5)
                ep_len.append(steps)
                # logging.info('Episode ended after %d steps. Return: %.1f',
                #              observation['episode_step'].item(),
                #              observation['episode_return'].item())
            if flags.mode == 'test_render':
                # Pause between frames so the rendered episode is watchable;
                # the delay is configurable via the DELAY environment variable.
                sleep_seconds = os.environ.get('DELAY', '0.3')
                time.sleep(float(sleep_seconds))
                if done:
                    print('Done: {}'.format('You won!!' if won[-1] else 'You lost!!'))
                    print('Episode steps: {}'.format(observation['episode_step']))
                    print('Episode return: {}'.format(observation['episode_return']))
                    # Wait for the user unless a DONE delay (in seconds) is set.
                    done_seconds = os.environ.get('DONE', None)
                    if done_seconds is None:
                        print('Press Enter to continue')
                        input()
                    else:
                        time.sleep(float(done_seconds))
    env.close()
    logging.info(
        'Average returns over %i episodes: %.2f. Win rate: %.2f. Entropy: %.2f. Len: %.2f',
        num_eps,
        sum(returns) / len(returns),
        sum(won) / len(returns),
        sum(entropy) / max(1, len(entropy)),
        sum(ep_len) / len(ep_len))
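A minimal sketch of how test() might be driven, assuming it is called from within run_exp.py where Net and environment are already defined at module level. The SimpleNamespace below merely stands in for whatever argument parser the script actually uses, and only the flags that test() reads directly (mode, random_agent, xpid, savedir) are shown; Net.create_env(flags) will typically require additional environment-construction flags not listed here.

if __name__ == '__main__':
    # Hypothetical invocation, not the repo's actual CLI entry point.
    from types import SimpleNamespace
    demo_flags = SimpleNamespace(
        mode='test',          # 'test_render' would also enable terminal rendering
        random_agent=True,    # skip checkpoint loading and act uniformly at random
        xpid=None,            # with xpid=None, the checkpoint path defaults to ./results_latest/model.tar
        savedir='./results',
        # Net.create_env(flags) may expect further flags (e.g. the environment name).
    )
    test(demo_flags, num_eps=10)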