in baselines/gail/run_mujoco.py [0:0]
def main(args):
    """Set up a Gym environment and run the requested GAIL task.

    Depending on ``args.task`` this either trains (``'train'``) by calling
    ``train`` with an expert dataset and a ``TransitionClassifier`` reward
    giver, or evaluates (``'evaluate'``) by calling ``runner`` on the model
    at ``args.load_model_path``.

    Args:
        args: Namespace of run options. Fields read here: ``seed``,
            ``env_id``, ``policy_hidden_size``, ``checkpoint_dir``,
            ``log_dir``, ``task``; for training also ``expert_path``,
            ``traj_limitation``, ``adversary_hidden_size``,
            ``adversary_entcoeff``, ``algo``, ``g_step``, ``d_step``,
            ``policy_entcoeff``, ``num_timesteps``, ``save_per_iter``,
            ``pretrained``, ``BC_max_iter``; for evaluation also
            ``load_model_path``, ``stochastic_policy``, ``save_sample``.
            ``args.checkpoint_dir`` and ``args.log_dir`` are overwritten in
            place with the task name appended as a sub-directory.

    Raises:
        NotImplementedError: If ``args.task`` is neither ``'train'`` nor
            ``'evaluate'``.
    """
    # The session context is entered and never exited in this function, so
    # it remains active for everything that follows.
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    # Everything after environment creation is guarded so the environment is
    # closed even when setup, training or evaluation raises.
    try:
        def policy_fn(name, ob_space, ac_space, reuse=False):
            # Two hidden layers; the width comes from the command-line options.
            return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                        reuse=reuse, hid_size=args.policy_hidden_size,
                                        num_hid_layers=2)

        # When the logger has no directory the falsy value is passed straight
        # through to Monitor instead of a joined path.
        log_dir = logger.get_dir()
        env = bench.Monitor(env, log_dir and osp.join(log_dir, "monitor.json"))
        env.seed(args.seed)
        gym.logger.setLevel(logging.WARN)

        task_name = get_task_name(args)
        # Callers may inspect these afterwards, so keep mutating args in place.
        args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
        args.log_dir = osp.join(args.log_dir, task_name)

        if args.task == 'train':
            dataset = Mujoco_Dset(expert_path=args.expert_path,
                                  traj_limitation=args.traj_limitation)
            reward_giver = TransitionClassifier(env, args.adversary_hidden_size,
                                                entcoeff=args.adversary_entcoeff)
            train(env,
                  args.seed,
                  policy_fn,
                  reward_giver,
                  dataset,
                  args.algo,
                  args.g_step,
                  args.d_step,
                  args.policy_entcoeff,
                  args.num_timesteps,
                  args.save_per_iter,
                  args.checkpoint_dir,
                  args.log_dir,
                  args.pretrained,
                  args.BC_max_iter,
                  task_name
                  )
        elif args.task == 'evaluate':
            # Batch size and trajectory count are fixed for evaluation runs.
            runner(env,
                   policy_fn,
                   args.load_model_path,
                   timesteps_per_batch=1024,
                   number_trajs=10,
                   stochastic_policy=args.stochastic_policy,
                   save=args.save_sample
                   )
        else:
            raise NotImplementedError(
                "unsupported task {!r}; expected 'train' or 'evaluate'".format(args.task))
    finally:
        env.close()