import argparse
from mpi4py import MPI
from . import ppg
from . import torch_util as tu
from .impala_cnn import ImpalaEncoder
from . import logger
from .envs import get_venv

def train_fn(
        env_name="coinrun",
        distribution_mode="hard",
        arch="dual",
        interacts_total=100_000_000,
        num_envs=64,
        nstep=256,
        n_epoch_pi=1,
        n_epoch_vf=1,
        gamma=0.999,
        lambda_=0.95,
        aux_lr=5e-4,
        aux_beta1=0.9,
        aux_beta2=0.999,
        lr=5e-4,
        beta1=0.9,
        beta2=0.999,
        nminibatch=8,
        aux_mbsize=4,
        clip_param=0.2,
        kl_penalty=0.0,
        kl_ewma_decay=None,
        n_aux_epochs=6,
        n_pi=32,
        beta_clone=1.0,
        vf_true_weight=1.0,
        adv_ewma_decay=0.0,
        log_dir='/tmp/ppg',
        log_new_eps=False,
        comm=None,
        staleness=0,
        staleness_loss='decoupled',
        imp_samp_max=100.0):
    """Set up distributed state, build the env and model, and run ``ppg.learn``.

    The steps run in this order: distributed setup through ``torch_util``,
    optional logger configuration, environment creation via ``get_venv``,
    construction of a ``ppg.PhasicValueModel`` with an ``ImpalaEncoder``,
    parameter sync, and finally the ``ppg.learn`` call.

    Args:
        env_name, distribution_mode, num_envs: Forwarded to ``get_venv``.
        arch: One of ``'shared'``, ``'detach'`` or ``'dual'``; forwarded to
            ``ppg.PhasicValueModel``.
            ``'shared'`` = shared policy and value networks.
            ``'dual'`` = separate policy and value networks.
            ``'detach'`` = shared policy and value networks, but with the
            value function gradient detached during the policy phase to
            avoid interference.
        interacts_total: Forwarded to ``ppg.learn``.
        lr, beta1, beta2, nstep, gamma, lambda_, nminibatch, n_epoch_vf,
        n_epoch_pi, clip_param, kl_penalty, adv_ewma_decay, staleness,
        staleness_loss, imp_samp_max: Packed into the ``ppo_hps`` dict given
            to ``ppg.learn`` (``gamma`` and ``lambda_`` are stored under the
            keys ``γ`` and ``λ``).
        aux_lr, aux_beta1, aux_beta2, aux_mbsize, n_aux_epochs, n_pi,
        kl_ewma_decay: Forwarded to ``ppg.learn`` as top-level keywords.
        beta_clone: Stored in ``name2coef`` under ``"pol_distance"``.
        vf_true_weight: Stored in ``name2coef`` under ``"vf_true"``.
        log_dir: Directory handed to ``logger.configure``; when ``None`` the
            logger is not configured here.
        log_new_eps: Stored in ``log_save_opts`` under ``"log_new_eps"``.
        comm: MPI communicator; ``MPI.COMM_WORLD`` is used when ``None``.
    """
    comm = MPI.COMM_WORLD if comm is None else comm
    tu.setup_dist(comm=comm)
    tu.register_distributions_for_tree_util()

    if log_dir is not None:
        # Only rank 0 gets output formats; the other ranks receive an empty list.
        if comm.Get_rank() == 0:
            format_strs = ['csv', 'stdout']
        else:
            format_strs = []
        logger.configure(comm=comm, dir=log_dir, format_strs=format_strs)

    venv = get_venv(
        num_envs=num_envs,
        env_name=env_name,
        distribution_mode=distribution_mode,
    )

    def enc_fn(obtype):
        # Encoder factory passed to the model; sized from the observation shape.
        return ImpalaEncoder(obtype.shape, outsize=256, chans=(16, 32, 32))

    model = ppg.PhasicValueModel(venv.ob_space, venv.ac_space, enc_fn, arch=arch)
    model.to(tu.dev())
    logger.log(tu.format_model(model))
    # NOTE(review): presumably makes every rank start from identical weights;
    # confirm against torch_util.sync_params.
    tu.sync_params(model.parameters())

    # Hyperparameters grouped under ``ppo_hps`` for ppg.learn.
    ppo_hps = dict(
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        nstep=nstep,
        γ=gamma,
        λ=lambda_,
        nminibatch=nminibatch,
        n_epoch_vf=n_epoch_vf,
        n_epoch_pi=n_epoch_pi,
        clip_param=clip_param,
        kl_penalty=kl_penalty,
        adv_ewma_decay=adv_ewma_decay,
        log_save_opts={"save_mode": "last", "log_new_eps": log_new_eps},
        staleness=staleness,
        staleness_loss=staleness_loss,
        imp_samp_max=imp_samp_max,
    )
    name2coef = {"pol_distance": beta_clone, "vf_true": vf_true_weight}

    ppg.learn(
        venv=venv,
        model=model,
        interacts_total=interacts_total,
        ppo_hps=ppo_hps,
        aux_lr=aux_lr,
        aux_beta1=aux_beta1,
        aux_beta2=aux_beta2,
        aux_mbsize=aux_mbsize,
        n_aux_epochs=n_aux_epochs,
        n_pi=n_pi,
        kl_ewma_decay=kl_ewma_decay,
        name2coef=name2coef,
        comm=comm,
    )

def main():
    """Parse command-line options and launch training via ``train_fn``.

    Every option defined on the parser is forwarded to ``train_fn``; any
    ``train_fn`` parameter not exposed here keeps its default. Training runs
    on ``MPI.COMM_WORLD``.
    """
    parser = argparse.ArgumentParser(description='Process PPG training arguments.')
    parser.add_argument('--env_name', type=str, default='coinrun')
    parser.add_argument('--num_envs', type=int, default=64)
    parser.add_argument('--n_epoch_pi', type=int, default=1)
    parser.add_argument('--n_epoch_vf', type=int, default=1)
    parser.add_argument('--n_aux_epochs', type=int, default=6)
    parser.add_argument('--n_pi', type=int, default=32)
    parser.add_argument('--clip_param', type=float, default=0.2)
    parser.add_argument('--kl_penalty', type=float, default=0.0)
    parser.add_argument('--arch', type=str, default='dual') # 'shared', 'detach', or 'dual'

    args = parser.parse_args()

    comm = MPI.COMM_WORLD

    train_fn(
        env_name=args.env_name,
        num_envs=args.num_envs,
        n_epoch_pi=args.n_epoch_pi,
        n_epoch_vf=args.n_epoch_vf,
        n_aux_epochs=args.n_aux_epochs,
        n_pi=args.n_pi,
        # These two were parsed but never forwarded, so the command-line
        # values were silently ignored in favour of train_fn's defaults.
        clip_param=args.clip_param,
        kl_penalty=args.kl_penalty,
        arch=args.arch,
        comm=comm)

# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
