def train_and_run()

in understanding_rl_vision/rl_clarity/example.py [0:0]


def train_and_run(env_name_or_id, *, base_path=None):
    """Train a model for a few timesteps, then generate an interface for it.

    Checkpoints are written under ``<base_path>/training`` and the interface
    under ``<base_path>/interface``. The checkpoint path and the interface
    URL are printed; nothing is returned.

    Args:
        env_name_or_id: Either an entry of ``PROCGEN_ENV_NAMES`` (or the
            string ``"coinrun_old"``), which is passed to training as
            ``env_name``; or a key of ``ATARI_ENV_DICT``, which is passed as
            ``env_id`` together with ``env_kind="atari"``.
        base_path: Root output location. When ``None``, a fresh temporary
            directory with prefix ``rl_clarity_example_`` is created. A value
            containing ``"://"`` is treated as non-local: no directories are
            created for it and it is printed as the URL unchanged
            (presumably a remote/URL-style path handled by ``rl_clarity``
            itself -- confirm against that module).

    Raises:
        ValueError: If ``env_name_or_id`` matches neither collection.
    """
    from pathlib import Path

    # Validate the environment first so that an unsupported name does not
    # leave a stray temporary directory or empty output folders behind.
    if env_name_or_id in PROCGEN_ENV_NAMES + ["coinrun_old"]:
        env_kwargs = {"env_name": env_name_or_id}
    elif env_name_or_id in ATARI_ENV_DICT:
        env_kwargs = {"env_id": env_name_or_id, "env_kind": "atari"}
    else:
        raise ValueError(f"Unsupported env {env_name_or_id}")

    if base_path is None:
        base_path = tempfile.mkdtemp(prefix="rl_clarity_example_")
    training_dir = os.path.join(base_path, "training")
    interface_dir = os.path.join(base_path, "interface")
    # Only local paths get their directories created here.
    if "://" not in base_path:
        os.makedirs(training_dir, exist_ok=True)
        os.makedirs(interface_dir, exist_ok=True)

    # train for very few timesteps, to demonstrate
    # note: training code has not been well-tested
    rl_clarity.train(
        num_envs=8,
        num_steps=16,
        timesteps_per_proc=8 * 16 * 2,
        save_interval=2,
        save_dir=training_dir,
        **env_kwargs,
    )
    checkpoint_path = os.path.join(training_dir, "checkpoint.jd")
    print(f"Checkpoint saved to: {checkpoint_path}")

    print("Generating interface...")
    # generate a small interface, to demonstrate
    rl_clarity.run(
        checkpoint_path,
        output_dir=interface_dir,
        trajectories_kwargs={"num_envs": 8, "num_steps": 16},
        observations_kwargs={"num_envs": 8, "num_obs": 4, "obs_every": 4},
        layer_kwargs={"name_contains_one_of": ["2b"]},
    )

    interface_path = os.path.join(interface_dir, "interface.html")
    if "://" in interface_path:
        # Already URL-like; print it as given.
        interface_url = interface_path
    else:
        # A file URL needs an absolute path: "file://" + "out/x.html" would
        # make "out" the host component. as_uri() also percent-encodes
        # reserved characters and handles Windows drive letters.
        interface_url = Path(os.path.abspath(interface_path)).as_uri()
    print(f"Interface URL: {interface_url}")