in understanding_rl_vision/rl_clarity/loading.py [0:0]
def get_step_fn(config, params, *, num_envs, full_resolution):
    """Build a policy and vectorized environment, then yield a stepping function.

    This is a generator that yields exactly once. The yield happens inside the
    TensorFlow graph/session holding the policy, so the yielded ``step_fn`` is
    only usable until the generator is resumed or closed. NOTE(review): it is
    presumably wrapped with ``contextlib.contextmanager`` where it is defined --
    confirm against the decorator/callers.

    Args:
        config: Mapping of options forwarded to ``create_env`` and
            ``get_arch``. It is copied, so the caller's mapping is not
            modified. Any ``"num_envs"`` entry is dropped in favour of the
            ``num_envs`` argument. ``"library"`` selects the policy
            implementation and defaults to ``"baselines"``.
        params: Passed to ``load_params`` after the policy graph is built.
        num_envs: Number of parallel environments passed to ``create_env``.
        full_resolution: If true, each step result also carries
            ``"ob_full"``: the ``info["rgb"]`` entries of the per-env info
            dicts, stacked along a new leading axis.

    Yields:
        ``step_fn()``, a zero-argument callable. Each call runs one policy step
        across all environments and returns a dict with:

        - ``"ob"``, ``"first"`` (and ``"ob_full"`` if requested): the state
          *before* the action is taken;
        - ``"ac"``: the actions chosen by the policy;
        - ``"reward"``, ``"info"``: what the environment returned for
          those actions.

    Raises:
        ValueError: If ``config["library"]`` is not ``"baselines"``.
    """
    config = config.copy()
    # The explicit ``num_envs`` argument wins. Tolerate configs that lack the
    # key instead of failing with KeyError.
    config.pop("num_envs", None)
    library = config.get("library", "baselines")
    venv = create_env(num_envs, **config)
    # Everything after env creation is guarded so the environment is always
    # released: on normal exit after the yield, and on any error raised while
    # building the policy, including the unsupported-library ValueError.
    try:
        arch = get_arch(**config)
        with tf.Graph().as_default(), tf.Session() as sess:
            if library != "baselines":
                raise ValueError(f"Unsupported library: {library}")

            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(nbatch=venv.num_envs, nsteps=1, sess=sess)

            def stack_rgb(infos):
                # Assumes every per-env info dict has an "rgb" array of equal
                # shape -- TODO confirm against create_env.
                return np.stack([info["rgb"] for info in infos], axis=0)

            # Mutable state carried between calls of step_fn. Every env starts
            # a fresh episode, so "first" is initially all True.
            stepdata = {
                "ob": venv.reset(),
                "state": policy.initial_state,
                "first": np.ones((venv.num_envs,), bool),
            }
            if full_resolution:
                stepdata["ob_full"] = stack_rgb(venv.env.get_info())

            def step_fn():
                # astype(bool) normalizes the dtype of the flags that
                # venv.step returned (and returns a fresh array).
                result = {
                    "ob": stepdata["ob"],
                    "first": stepdata["first"].astype(bool),
                }
                if full_resolution:
                    result["ob_full"] = stepdata["ob_full"]
                # The "first" flags are passed as the float mask M.
                # NOTE(review): baselines presumably uses M to reset
                # recurrent state at episode boundaries -- confirm.
                result["ac"], _, stepdata["state"], _ = policy.step(
                    stepdata["ob"],
                    S=stepdata["state"],
                    M=stepdata["first"].astype(float),
                )
                # The third value from venv.step (presumably the done flags)
                # becomes the "first" flags for the next call.
                (
                    stepdata["ob"],
                    result["reward"],
                    stepdata["first"],
                    result["info"],
                ) = venv.step(result["ac"])
                if full_resolution:
                    stepdata["ob_full"] = stack_rgb(result["info"])
                return result

            # Load parameters only after build_policy has created the graph.
            # NOTE(review): load_params presumably assigns into the policy's
            # variables, which must exist first -- confirm.
            load_params(params, sess=sess)
            yield step_fn
    finally:
        venv.close()