# monobeast/minigrid/monobeast_amigo.py
def train(flags):
"""Full training loop."""
if flags.xpid is None:
flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
plogger = file_writer.FileWriter(
xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
)
checkpointpath = os.path.expandvars(
os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
)
if flags.num_buffers is None: # Set sensible default for num_buffers.
flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")
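    # Shorthand used throughout: T is the unroll length, B the batch size.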
T = flags.unroll_length
B = flags.batch_size
flags.device = None
if not flags.disable_cuda and torch.cuda.is_available():
logging.info("Using CUDA.")
flags.device = torch.device("cuda")
else:
logging.info("Not using CUDA.")
flags.device = torch.device("cpu")
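    # Build one environment up front to read observation/action shapes and grid size for the models.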
env = create_env(flags)
#env = wrappers.FullyObsWrapper(env)
if flags.num_input_frames > 1:
env = FrameStack(env, flags.num_input_frames)
    generator_model = Generator(env.observation_space.shape, env.width, env.height,
                                num_input_frames=flags.num_input_frames)
    model = Net(env.observation_space.shape, env.action_space.n,
                state_embedding_dim=flags.state_embedding_dim,
                num_input_frames=flags.num_input_frames,
                use_lstm=flags.use_lstm, num_lstm_layers=flags.num_lstm_layers)
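    # Goal-count tensor (11 bins), made a module-level global so other functions here can update it.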
global goal_count_dict
goal_count_dict = torch.zeros(11).float().to(device=flags.device)
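    # With flags.inner, goal logits cover only interior cells, excluding the one-cell wall border.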
if flags.inner:
logits_size = (env.width-2)*(env.height-2)
else:
logits_size = env.width * env.height
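    # Shared rollout buffers; logits_size presumably sizes the generator's per-cell goal logits.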
buffers = create_buffers(env.observation_space.shape, model.num_actions, flags, env.width, env.height, logits_size)
model.share_memory()
generator_model.share_memory()
# Add initial RNN state.
initial_agent_state_buffers = []
for _ in range(flags.num_buffers):
state = model.initial_state(batch_size=1)
for t in state:
t.share_memory_()
initial_agent_state_buffers.append(state)
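    # Start one actor process per flags.num_actors; the two queues pass buffer indices back and forth.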
actor_processes = []
ctx = mp.get_context("fork")
free_queue = ctx.SimpleQueue()
full_queue = ctx.SimpleQueue()
for i in range(flags.num_actors):
actor = ctx.Process(
target=act,
args=(i, free_queue, full_queue, model, generator_model, buffers,
initial_agent_state_buffers, flags))
actor.start()
actor_processes.append(actor)
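    # The learner keeps its own copies on flags.device; the actor models above stay in shared memory.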
    learner_model = Net(env.observation_space.shape, env.action_space.n,
                        state_embedding_dim=flags.state_embedding_dim,
                        num_input_frames=flags.num_input_frames,
                        use_lstm=flags.use_lstm,
                        num_lstm_layers=flags.num_lstm_layers).to(device=flags.device)
    learner_generator_model = Generator(env.observation_space.shape, env.width, env.height,
                                        num_input_frames=flags.num_input_frames).to(device=flags.device)
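    # One RMSprop optimizer per model; only the learning rates differ between the two.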
optimizer = torch.optim.RMSprop(
learner_model.parameters(),
lr=flags.learning_rate,
momentum=flags.momentum,
eps=flags.epsilon,
alpha=flags.alpha,
)
    generator_model_optimizer = torch.optim.RMSprop(
        learner_generator_model.parameters(),
        lr=flags.generator_learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )
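    # Linear decay to zero: each scheduler step accounts for one batch of T * B frames.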
def lr_lambda(epoch):
return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
generator_scheduler = torch.optim.lr_scheduler.LambdaLR(generator_model_optimizer, lr_lambda)
logger = logging.getLogger("logfile")
stat_keys = [
"total_loss",
"mean_episode_return",
"pg_loss",
"baseline_loss",
"entropy_loss",
"gen_rewards",
"gg_loss",
"generator_entropy_loss",
"generator_baseline_loss",
"mean_intrinsic_rewards",
"mean_episode_steps",
"ex_reward",
"generator_current_target",
]
logger.info("# Step\t%s", "\t".join(stat_keys))
frames, stats = 0, {}
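    # Learner thread body; the default-argument lock is intentionally shared by all threads.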
def batch_and_learn(i, lock=threading.Lock()):
"""Thread target for the learning process."""
nonlocal frames, stats
timings = prof.Timings()
while frames < flags.total_frames:
timings.reset()
            batch, agent_state = get_batch(flags, free_queue, full_queue, buffers,
                                           initial_agent_state_buffers, timings)
            stats = learn(model, learner_model, generator_model, learner_generator_model,
                          batch, agent_state, optimizer, generator_model_optimizer,
                          scheduler, generator_scheduler, flags, env.max_steps)
timings.time("learn")
with lock:
to_log = dict(frames=frames)
to_log.update({k: stats[k] for k in stat_keys})
plogger.log(to_log)
frames += T * B
if i == 0:
logging.info("Batch and learn: %s", timings.summary())
for m in range(flags.num_buffers):
free_queue.put(m)
threads = []
for i in range(flags.num_threads):
thread = threading.Thread(
target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
)
thread.start()
threads.append(thread)
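    # Checkpoints bundle both models with their optimizers and schedulers so training can resume.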
def checkpoint():
if flags.disable_checkpoint:
return
logging.info("Saving checkpoint to %s", checkpointpath)
torch.save(
{
"model_state_dict": model.state_dict(),
"generator_model_state_dict": generator_model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"generator_model_optimizer_state_dict": generator_model_optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"generator_scheduler_state_dict": generator_scheduler.state_dict(),
"flags": vars(flags),
},
checkpointpath,
)
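    # Main thread: poll every 5 s, checkpoint every 10 min, and log FPS plus the latest stats.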
timer = timeit.default_timer
try:
last_checkpoint_time = timer()
while frames < flags.total_frames:
start_frames = frames
start_time = timer()
time.sleep(5)
if timer() - last_checkpoint_time > 10 * 60: # Save every 10 min.
checkpoint()
last_checkpoint_time = timer()
fps = (frames - start_frames) / (timer() - start_time)
if stats.get("episode_returns", None):
mean_return = (
"Return per episode: %.1f. " % stats["mean_episode_return"]
)
else:
mean_return = ""
total_loss = stats.get("total_loss", float("inf"))
logging.info(
"After %i frames: loss %f @ %.1f fps. %sStats:\n%s",
frames,
total_loss,
fps,
mean_return,
pprint.pformat(stats),
)
except KeyboardInterrupt:
return # Try joining actors then quit.
else:
for thread in threads:
thread.join()
logging.info("Learning finished after %d frames.", frames)
finally:
for _ in range(flags.num_actors):
free_queue.put(None)
for actor in actor_processes:
actor.join(timeout=1)
checkpoint()
plogger.close()