def get_step_fn()

in understanding_rl_vision/rl_clarity/loading.py [0:0]


def get_step_fn(config, params, *, num_envs, full_resolution):
    """Build a stepping function for a trained policy and yield it.

    Creates ``num_envs`` environments from ``config``, builds the policy network
    in a fresh TensorFlow graph and session, loads ``params`` into it, and then
    yields a zero-argument ``step_fn``. This is a generator. Presumably the
    caller wraps it with ``contextlib.contextmanager`` or drives it like one;
    confirm at the call site. The session stays open, and ``step_fn`` stays
    usable, only until the generator is resumed or closed. On exit the
    environments are closed.

    Args:
        config: Environment/architecture config mapping. It is copied, never
            mutated. Any ``"num_envs"`` entry is ignored in favor of the
            ``num_envs`` argument. ``"library"`` selects the policy backend;
            it defaults to ``"baselines"``, the only supported value.
        params: Trained parameters, handed to ``load_params``.
        num_envs: Number of parallel environments to create.
        full_resolution: If true, each result also includes ``"ob_full"``,
            stacked from the ``"rgb"`` entry of every env's info dict.

    Yields:
        ``step_fn()``, which advances every environment by one policy action
        and returns a dict with:
          * ``"ob"``: the observations the action was chosen from;
          * ``"first"``: bool array. It is all true on the first call and
            afterwards holds the done flags from the previous ``venv.step``;
          * ``"ob_full"`` (only if ``full_resolution``): full-resolution
            frames matching ``"ob"``;
          * ``"ac"``: the actions returned by ``policy.step``;
          * ``"reward"`` and ``"info"``: as returned by ``venv.step``.

    Raises:
        ValueError: If ``config["library"]`` is not ``"baselines"``. It is
            raised before any environment or graph is created.
    """
    config = config.copy()
    # The explicit ``num_envs`` argument wins. Drop any stored value so it is
    # not also forwarded to ``create_env`` via ``**config``.
    config.pop("num_envs", None)
    library = config.get("library", "baselines")
    if library != "baselines":
        # Reject early so nothing needs cleaning up on this error path.
        raise ValueError(f"Unsupported library: {library}")

    venv = create_env(num_envs, **config)
    try:
        arch = get_arch(**config)

        with tf.Graph().as_default(), tf.Session() as sess:
            from baselines.common.policies import build_policy

            # AUTO_REUSE lets the variables match the "ppo2_model" scope that
            # the saved parameters were presumably trained under.
            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(nbatch=venv.num_envs, nsteps=1, sess=sess)

            # Mutable state shared with (and updated by) step_fn between calls.
            stepdata = {
                "ob": venv.reset(),
                "state": policy.initial_state,
                "first": np.ones((venv.num_envs,), bool),
            }
            if full_resolution:
                stepdata["ob_full"] = np.stack(
                    [info["rgb"] for info in venv.env.get_info()], axis=0
                )

            def step_fn():
                """Take one policy step in every env and return the transition."""
                result = {"ob": stepdata["ob"], "first": stepdata["first"].astype(bool)}
                if full_resolution:
                    result["ob_full"] = stepdata["ob_full"]
                # The "first" flags are passed as the mask M (as floats) so a
                # recurrent policy can reset its state at episode boundaries.
                result["ac"], _, stepdata["state"], _ = policy.step(
                    stepdata["ob"],
                    S=stepdata["state"],
                    M=stepdata["first"].astype(float),
                )
                (
                    stepdata["ob"],
                    result["reward"],
                    stepdata["first"],
                    result["info"],
                ) = venv.step(result["ac"])
                if full_resolution:
                    stepdata["ob_full"] = np.stack(
                        [info["rgb"] for info in result["info"]], axis=0
                    )
                return result

            load_params(params, sess=sess)

            yield step_fn
    finally:
        # Runs when the generator is closed or an exception is thrown into it,
        # so env resources are not leaked.
        venv.close()