def test()

in run_exp.py [0:0]

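Evaluation entry point: rolls out num_eps episodes with either a trained
model restored from a checkpoint or a uniform-random baseline agent, then
logs average return, win rate, policy entropy, and episode length.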

import os
import random
import time
import logging

import torch
import torch.nn.functional as F

# Net (the model/environment factory) and environment (the env wrapper)
# are defined at module level elsewhere in run_exp.py.


def test(flags, num_eps: int = 1000):
    from rtfm import featurizer as X
    gym_env = Net.create_env(flags)
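    # In interactive render mode, append a Terminal featurizer so each frame
    # can also be drawn to the terminal.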
    if flags.mode == 'test_render':
        gym_env.featurizer = X.Concat([gym_env.featurizer, X.Terminal()])
    env = environment.Environment(gym_env)

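    # Unless we are evaluating the random baseline, rebuild the model and
    # restore its weights from the saved checkpoint.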
    if not flags.random_agent:
        model = Net.make(flags, gym_env)
        model.eval()
        if flags.xpid is None:
            checkpointpath = './results_latest/model.tar'
        else:
            checkpointpath = os.path.expandvars(
                os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
                                                 'model.tar')))
        checkpoint = torch.load(checkpointpath, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

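    # Roll out episodes until num_eps have completed, collecting per-episode
    # returns, wins, policy entropies, and lengths. env.initial() is called
    # once; the environment wrapper is expected to reset itself between
    # episodes.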
    observation = env.initial()
    returns = []
    won = []
    entropy = []
    ep_len = []
    while len(won) < num_eps:
        done = False
        steps = 0
        while not done:
            if flags.random_agent:
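                # Random baseline: sample a uniformly random action.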
                action = torch.zeros(1, 1, dtype=torch.int32)
                action[0][0] = random.randint(0, gym_env.action_space.n - 1)
                observation = env.step(action)
            else:
                agent_outputs = model(observation)
                observation = env.step(agent_outputs['action'])
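                # Diagnostic: entropy of the policy distribution at this step.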
                policy = F.softmax(agent_outputs['policy_logits'], dim=-1)
                log_policy = F.log_softmax(agent_outputs['policy_logits'], dim=-1)
                e = -torch.sum(policy * log_policy, dim=-1)
                entropy.append(e.mean(0).item())

            steps += 1
            done = observation['done'].item()
            if done:
                returns.append(observation['episode_return'].item())
                won.append(observation['reward'][0][0].item() > 0.5)
                ep_len.append(steps)
                # logging.info('Episode ended after %d steps. Return: %.1f',
                #              observation['episode_step'].item(),
                #              observation['episode_return'].item())
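            # In render mode, pause between frames so the episode is watchable;
            # the DELAY environment variable overrides the default 0.3 s.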
            if flags.mode == 'test_render':
                sleep_seconds = os.environ.get('DELAY', '0.3')
                time.sleep(float(sleep_seconds))

                if done:
                    print('Done: {}'.format('You won!!' if won[-1] else 'You lost!!'))
                    print('Episode steps: {}'.format(observation['episode_step']))
                    print('Episode return: {}'.format(observation['episode_return']))
                    done_seconds = os.environ.get('DONE', None)
                    if done_seconds is None:
                        print('Press Enter to continue')
                        input()
                    else:
                        time.sleep(float(done_seconds))

    env.close()
    logging.info(
        'Average returns over %i episodes: %.2f. Win rate: %.2f. Entropy: %.2f. Len: %.2f',
        num_eps,
        sum(returns) / len(returns),
        sum(won) / len(won),
        sum(entropy) / max(1, len(entropy)),
        sum(ep_len) / len(ep_len),
    )
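
A minimal invocation sketch for reference. test itself only reads flags.mode,
flags.random_agent, flags.xpid, and flags.savedir; Net.create_env and Net.make
will typically consume additional model- and environment-specific flags, so in
practice flags should come from the argument parser in run_exp.py. The values
below are illustrative assumptions (the random baseline avoids needing a
checkpoint):

if __name__ == '__main__':
    import argparse

    # Hypothetical flag values; real runs should build flags with the
    # parser defined in run_exp.py.
    flags = argparse.Namespace(
        mode='test',          # 'test_render' would draw frames and pause
        random_agent=True,    # skip checkpoint loading entirely
        xpid=None,
        savedir='./results',
    )
    test(flags, num_eps=100)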