understanding_rl_vision/rl_clarity/example.py (51 lines of code) (raw):

import argparse
import os
import pathlib
import tempfile

from understanding_rl_vision import rl_clarity
from understanding_rl_vision.rl_clarity.training import (
    PROCGEN_ENV_NAMES,
    ATARI_ENV_DICT,
)


def _is_remote(path):
    """Return True if *path* looks like a URL (contains ``://``) rather than a local path."""
    return "://" in path


def _path_to_url(path):
    """Return a URL that points at *path*.

    Paths that already contain ``://`` are returned unchanged. Local paths are
    first made absolute and then converted with ``pathlib.Path.as_uri``, so a
    relative ``--path`` still yields a valid ``file:///...`` URL (a bare
    ``"file://" + relative`` would make the first path component a hostname).
    """
    if _is_remote(path):
        return path
    return pathlib.Path(os.path.abspath(path)).as_uri()


def train_and_run(env_name_or_id, *, base_path=None):
    """Train a tiny agent on one environment, then generate an interface for it.

    This is a demonstration only: training runs for very few timesteps and
    the generated interface is deliberately small.

    Args:
        env_name_or_id: a name in ``PROCGEN_ENV_NAMES``, ``"coinrun_old"``,
            or a key of ``ATARI_ENV_DICT``.
        base_path: directory under which ``training/`` and ``interface/``
            are placed. Defaults to a fresh temporary directory. If it
            contains ``://``, no local directories are created.

    Raises:
        ValueError: if ``env_name_or_id`` is not a supported environment
            (raised before any training starts).
    """
    if base_path is None:
        base_path = tempfile.mkdtemp(prefix="rl_clarity_example_")
    training_dir = os.path.join(base_path, "training")
    interface_dir = os.path.join(base_path, "interface")
    if not _is_remote(base_path):
        # Only local directories are created here; remote locations are
        # presumably handled by rl_clarity's writers -- TODO confirm.
        os.makedirs(training_dir, exist_ok=True)
        os.makedirs(interface_dir, exist_ok=True)

    # Membership test that works whether PROCGEN_ENV_NAMES is a list or a tuple.
    if env_name_or_id == "coinrun_old" or env_name_or_id in PROCGEN_ENV_NAMES:
        env_kwargs = {"env_name": env_name_or_id}
    elif env_name_or_id in ATARI_ENV_DICT:
        env_kwargs = {"env_id": env_name_or_id, "env_kind": "atari"}
    else:
        raise ValueError(f"Unsupported env {env_name_or_id}")

    # train for very few timesteps, to demonstrate
    # note: training code has not been well-tested
    rl_clarity.train(
        num_envs=8,
        num_steps=16,
        timesteps_per_proc=8 * 16 * 2,
        save_interval=2,
        save_dir=training_dir,
        **env_kwargs,
    )
    checkpoint_path = os.path.join(training_dir, "checkpoint.jd")
    print(f"Checkpoint saved to: {checkpoint_path}")

    print("Generating interface...")
    # generate a small interface, to demonstrate
    rl_clarity.run(
        checkpoint_path,
        output_dir=interface_dir,
        trajectories_kwargs={"num_envs": 8, "num_steps": 16},
        observations_kwargs={"num_envs": 8, "num_obs": 4, "obs_every": 4},
        layer_kwargs={"name_contains_one_of": ["2b"]},
    )
    interface_path = os.path.join(interface_dir, "interface.html")
    interface_url = _path_to_url(interface_path)
    print(f"Interface URL: {interface_url}")


def main():
    """Command-line entry point: ``example.py [ENV] [-p PATH]``."""
    parser = argparse.ArgumentParser(
        description="Train a tiny RL agent and generate an rl_clarity interface for it."
    )
    parser.add_argument(
        "env",
        nargs="?",
        default="coinrun_old",
        help="Procgen env name, 'coinrun_old', or an Atari env id (default: coinrun_old)",
    )
    parser.add_argument(
        "-p",
        "--path",
        help="base output directory (default: a new temporary directory)",
    )
    args = parser.parse_args()
    train_and_run(args.env, base_path=args.path)


if __name__ == "__main__":
    main()