mlsh_code/main.py (69 lines of code) (raw):

"""Launch MLSH training (or replay of a saved run) across MPI workers.

Example:

    python main.py --task MovementBandits-v0 --num_subs 2 --macro_duration 10 --num_rollouts 1000 --warmup_time 60 --train_time 1 --replay True test
"""
import argparse
import tensorflow as tf


def str2bool(v):
    """Parse a yes/no style command-line string into a bool.

    Used as an argparse ``type=`` converter, so an unrecognised value is
    reported by argparse as a normal usage error instead of a traceback.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised boolean word.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser()
parser.add_argument('savename', type=str)
parser.add_argument('--task', type=str)
parser.add_argument('--num_subs', type=int)
parser.add_argument('--macro_duration', type=int)
parser.add_argument('--num_rollouts', type=int)
parser.add_argument('--warmup_time', type=int)
parser.add_argument('--train_time', type=int)
parser.add_argument('--force_subpolicy', type=int)
# Converted straight to bool; defaults to False so omitting the flag no longer
# crashes (previously str2bool(None) raised AttributeError).
parser.add_argument('--replay', type=str2bool, default=False)
parser.add_argument('-s', action='store_true')
parser.add_argument('--continue_iter', type=str)
# NOTE: arguments are parsed before the imports below (MPI, rl_algs, gym, ...).
args = parser.parse_args()

from mpi4py import MPI
from rl_algs.common import set_global_seeds, tf_util as U
import os.path as osp
import gym
import logging
import numpy as np
from collections import deque
from gym import spaces
import misc_util
import sys
import shutil
import subprocess
import master

# Module-level alias read by callback().
replay = args.replay

RELPATH = args.savename
# NOTE(review): osp.join discards the prefix when savename is absolute, and
# main() rmtree()s LOGDIR on rank 0 -- do not pass an absolute savename.
LOGDIR = osp.join('/root/results' if sys.platform.startswith('linux') else '/tmp', RELPATH)


def _checkpoint_path(label):
    """Return the checkpoint path for this run (``savename``) and tag ``label``."""
    return osp.join('savedir', args.savename, 'checkpoints', label)


def callback(it):
    """Hook passed to ``master.start``; ``it`` is presumably the iteration index.

    Rank 0 saves a checkpoint when ``it`` is a multiple of 5 greater than 3,
    unless running in replay mode. When ``it == 0`` and ``--continue_iter`` was
    given, every rank restores that checkpoint: each MPI process opens its own
    TF session in ``train``, so each one needs the weights.
    """
    if MPI.COMM_WORLD.Get_rank() == 0 and it % 5 == 0 and it > 3 and not replay:
        # Saved under the same directory the restore branch reads from, so a
        # run can be resumed with --continue_iter <5-digit zero-padded iteration>.
        U.save_state(_checkpoint_path('%.5i' % it))
    if it == 0 and args.continue_iter is not None:
        U.load_state(_checkpoint_path(str(args.continue_iter)))


def train():
    """Seed this worker, build its MPI sub-communicator and run ``master.start``."""
    seed = 1401
    # Ranks with equal (rank % num_groups) are placed in the same communicator.
    num_groups = 10
    rank = MPI.COMM_WORLD.Get_rank()
    workerseed = seed + 1000 * rank
    # The session is the default for everything below (seeding and training)
    # and is closed once master.start returns or raises.
    with U.single_threaded_session():
        set_global_seeds(workerseed)
        world_group = MPI.COMM_WORLD.Get_group()
        mygroup = rank % num_groups
        theta_group = world_group.Incl(
            [x for x in range(MPI.COMM_WORLD.size) if x % num_groups == mygroup])
        comm = MPI.COMM_WORLD.Create(theta_group)
        comm.Barrier()
        master.start(callback, args=args, workerseed=workerseed, rank=rank, comm=comm)


def main():
    """Clear this run's log directory on rank 0, synchronise all ranks, then train."""
    if MPI.COMM_WORLD.Get_rank() == 0 and osp.exists(LOGDIR):
        shutil.rmtree(LOGDIR)
    # No rank proceeds until rank 0 has finished removing the old directory.
    MPI.COMM_WORLD.Barrier()
    train()


if __name__ == '__main__':
    main()