def enjoy_env_sess()

in coinrun/enjoy.py [0:0]


def enjoy_env_sess(sess):
    should_render = True
    should_eval = Config.TRAIN_EVAL or Config.TEST_EVAL
    rep_count = Config.REP

    if should_eval:
        env = utils.make_general_env(Config.NUM_EVAL)
        should_render = False
    else:
        env = utils.make_general_env(1)

    env = wrappers.add_final_wrappers(env)

    if should_render:
        from gym.envs.classic_control import rendering

    nenvs = env.num_envs

    agent = create_act_model(sess, env, nenvs)

    sess.run(tf.global_variables_initializer())
    loaded_params = utils.load_params_for_scope(sess, 'model')

    if not loaded_params:
        print('NO SAVED PARAMS LOADED')

    obs = env.reset()
    t_step = 0

    if should_render:
        viewer = rendering.SimpleImageViewer()

    should_render_obs = not Config.IS_HIGH_RES

    def maybe_render(info=None):
        if should_render and not should_render_obs:
            env.render()

    maybe_render()

    scores = np.array([0] * nenvs)
    score_counts = np.array([0] * nenvs)
    curr_rews = np.zeros((nenvs, 3))

    def should_continue():
        if should_eval:
            return np.sum(score_counts) < rep_count * nenvs

        return True

    state = agent.initial_state
    done = np.zeros(nenvs)

    while should_continue():
        action, values, state, _ = agent.step(obs, state, done)
        obs, rew, done, info = env.step(action)

        if should_render and should_render_obs:
            if np.shape(obs)[-1] % 3 == 0:
                ob_frame = obs[0,:,:,-3:]
            else:
                ob_frame = obs[0,:,:,-1]
                ob_frame = np.stack([ob_frame] * 3, axis=2)
            viewer.imshow(ob_frame)

        curr_rews[:,0] += rew

        for i, d in enumerate(done):
            if d:
                if score_counts[i] < rep_count:
                    score_counts[i] += 1

                    if 'episode' in info[i]:
                        scores[i] += info[i].get('episode')['r']

        if t_step % 100 == 0:
            mpi_print('t', t_step, values[0], done[0], rew[0], curr_rews[0], np.shape(obs))

        maybe_render(info[0])

        t_step += 1

        if should_render:
            time.sleep(.02)

        if done[0]:
            if should_render:
                mpi_print('ep_rew', curr_rews)

            curr_rews[:] = 0

    result = 0

    if should_eval:
        mean_score = np.mean(scores) / rep_count
        max_idx = np.argmax(scores)
        mpi_print('scores', scores / rep_count)
        print('mean_score', mean_score)
        mpi_print('max idx', max_idx)

        mpi_mean_score = utils.mpi_average([mean_score])
        mpi_print('mpi_mean', mpi_mean_score)

        result = mean_score

    return result