def train()

in understanding_rl_vision/rl_clarity/training.py


def train(comm=None, *, save_dir=None, **kwargs):
    """
    Train a model using Baselines' PPO2 and save a checkpoint file in the
    required format.

    There is one required kwarg: either env_name (for env_kind="procgen") or
    env_id (for env_kind="atari").

    Models for the paper were trained with 16 parallel MPI workers.

    Note: this code has not been well-tested.
    """
    kwargs.setdefault("env_kind", "procgen")
    kwargs.setdefault("num_envs", 64)
    kwargs.setdefault("learning_rate", 5e-4)
    kwargs.setdefault("entropy_coeff", 0.01)
    kwargs.setdefault("gamma", 0.999)
    kwargs.setdefault("lambda", 0.95)
    kwargs.setdefault("num_steps", 256)
    kwargs.setdefault("num_minibatches", 8)
    kwargs.setdefault("library", "baselines")
    kwargs.setdefault("save_all", False)
    kwargs.setdefault("ppo_epochs", 3)
    kwargs.setdefault("clip_range", 0.2)
    kwargs.setdefault("timesteps_per_proc", 1_000_000_000)
    kwargs.setdefault("cnn", "clear")
    kwargs.setdefault("use_lstm", 0)
    kwargs.setdefault("stack_channels", "16_32_32")
    kwargs.setdefault("emb_size", 256)
    kwargs.setdefault("epsilon_greedy", 0.0)
    kwargs.setdefault("reward_scale", 1.0)
    kwargs.setdefault("frame_stack", 1)
    kwargs.setdefault("use_sticky_actions", 0)
    kwargs.setdefault("clip_vf", 1)
    kwargs.setdefault("reward_processing", "none")
    kwargs.setdefault("save_interval", 10)

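    # Default to the world communicator; setup_mpi_gpus pins each MPI worker to a GPU.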
    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    setup_mpi_gpus()

    if save_dir is None:
        save_dir = tempfile.mkdtemp(prefix="rl_clarity_train_")

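    # Build the vectorized environment from the remaining kwargs (num_envs is consumed here).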
    create_env_kwargs = kwargs.copy()
    num_envs = create_env_kwargs.pop("num_envs")
    venv = create_env(num_envs, **create_env_kwargs)

    library = kwargs["library"]
    if library == "baselines":
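        # Optionally wrap the environment to clip or normalize rewards.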
        reward_processing = kwargs["reward_processing"]
        if reward_processing == "none":
            pass
        elif reward_processing == "clip":
            venv = VecClipReward(venv=venv)
        elif reward_processing == "normalize":
            venv = VecNormalize(venv=venv, ob=False, per_env=False)
        else:
            raise ValueError(f"Unsupported reward processing: {reward_processing}")

        scope = "ppo2_model"

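        # Checkpoint callback for ppo2.learn: rank 0 saves the model parameters every `save_interval` updates.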
        def update_fn(update, params=None):
            if rank == 0:
                save_interval = kwargs["save_interval"]
                if save_interval > 0 and update % save_interval == 0:
                    print("Saving...")
                    params = get_tf_params(scope)
                    save_path = save_data(
                        save_dir=save_dir,
                        args_dict=kwargs,
                        params=params,
                        step=(update if kwargs["save_all"] else None),
                    )
                    print(f"Saved to: {save_path}")

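        # Enter a TF session so it is the default while the model is built and trained.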
        sess = create_tf_session()
        sess.__enter__()

        if kwargs["use_lstm"]:
            raise ValueError("Recurrent networks not yet supported.")
        arch = get_arch(**kwargs)

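        # Run Baselines' PPO2 with the hyperparameters collected above.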
        from baselines.ppo2 import ppo2

        ppo2.learn(
            env=venv,
            network=arch,
            total_timesteps=kwargs["timesteps_per_proc"],
            save_interval=0,
            nsteps=kwargs["num_steps"],
            nminibatches=kwargs["num_minibatches"],
            lam=kwargs["lambda"],
            gamma=kwargs["gamma"],
            noptepochs=kwargs["ppo_epochs"],
            log_interval=1,
            ent_coef=kwargs["entropy_coeff"],
            mpi_rank_weight=1.0,
            clip_vf=bool(kwargs["clip_vf"]),
            comm=comm,
            lr=kwargs["learning_rate"],
            cliprange=kwargs["clip_range"],
            update_fn=update_fn,
            init_fn=None,
            vf_coef=0.5,
            max_grad_norm=0.5,
        )
    else:
        raise ValueError(f"Unsupported library: {library}")

    return save_dir
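
A minimal invocation might look like the sketch below. The environment name, save directory, script name, and reduced timestep budget are illustrative only; it assumes Baselines, MPI, and the understanding_rl_vision package are installed.

from understanding_rl_vision.rl_clarity.training import train

# Single-process run on a Procgen environment; env_name is the one required
# kwarg when env_kind="procgen" (use env_id instead for env_kind="atari").
# To match the paper's setting, launch 16 copies under MPI, e.g.
# `mpiexec -n 16 python train_coinrun.py`.
save_dir = train(
    env_name="coinrun",                   # illustrative Procgen environment
    save_dir="/tmp/rl_clarity_coinrun",   # illustrative checkpoint directory
    timesteps_per_proc=1_000_000,         # shortened from the 1e9 default
)
print("Checkpoints written to:", save_dir)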