in minihack/agent/polybeast/polybeast_learner.py
def train(flags):
logging.info("Logging results to %s", flags.savedir)
if isinstance(flags, omegaconf.DictConfig):
flag_dict = omegaconf.OmegaConf.to_container(flags)
else:
flag_dict = vars(flags)
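    # File-based experiment logger: records the flag dict as metadata and, via the
    # learner threads, training metrics under flags.savedir.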
plogger = file_writer.FileWriter(xp_args=flag_dict, rootdir=flags.savedir)
if not flags.disable_cuda and torch.cuda.is_available():
logging.info("Using CUDA.")
learner_device = torch.device(flags.learner_device)
actor_device = torch.device(flags.actor_device)
else:
logging.info("Not using CUDA.")
learner_device = torch.device("cpu")
actor_device = torch.device("cpu")
if flags.max_learner_queue_size is None:
flags.max_learner_queue_size = flags.batch_size
# The queue the learner threads will get their data from.
# Setting `minimum_batch_size == maximum_batch_size`
# makes the batch size static. We could make it dynamic, but that
# requires a loss (and learning rate schedule) that's batch size
# independent.
learner_queue = libtorchbeast.BatchingQueue(
batch_dim=1,
minimum_batch_size=flags.batch_size,
maximum_batch_size=flags.batch_size,
check_inputs=True,
maximum_queue_size=flags.max_learner_queue_size,
)
# The "batcher", a queue for the inference call. Will yield
# "batch" objects with `get_inputs` and `set_outputs` methods.
# The batch size of the tensors will be dynamic.
inference_batcher = libtorchbeast.DynamicBatcher(
batch_dim=1,
minimum_batch_size=1,
maximum_batch_size=512,
timeout_ms=100,
check_outputs=True,
)
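    # Build one pipe address per actor, of the form "<pipes_basename>.<pipe_id>".
    # The environment servers listen on these pipes; the ActorPool below connects to them.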
addresses = []
connections_per_server = 1
pipe_id = 0
while len(addresses) < flags.num_actors:
for _ in range(connections_per_server):
addresses.append(f"{flags.pipes_basename}.{pipe_id}")
if len(addresses) == flags.num_actors:
break
pipe_id += 1
logging.info("Using model %s", flags.model)
model = create_model(flags, learner_device)
plogger.metadata["model_numel"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
logging.info(
"Number of model parameters: %i", plogger.metadata["model_numel"]
)
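    # A second copy of the model on the actor device, used only to serve inference
    # requests; it is initialized from the learner model and updated to follow it.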
actor_model = create_model(flags, actor_device)
    # The ActorPool that will run `flags.num_actors` actor loops.
actors = libtorchbeast.ActorPool(
unroll_length=flags.unroll_length,
learner_queue=learner_queue,
inference_batcher=inference_batcher,
env_server_addresses=addresses,
initial_agent_state=model.initial_state(),
)
def run():
try:
actors.run()
        except Exception:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise
actorpool_thread = threading.Thread(target=run, name="actorpool-thread")
optimizer = torch.optim.RMSprop(
model.parameters(),
lr=flags.learning_rate,
momentum=flags.momentum,
eps=flags.epsilon,
alpha=flags.alpha,
)
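    # Linearly anneal the learning rate to zero over flags.total_steps environment steps.
    # Each scheduler "epoch" is one optimizer update, which consumes
    # unroll_length * batch_size environment steps.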
def lr_lambda(epoch):
return (
1
- min(
epoch * flags.unroll_length * flags.batch_size,
flags.total_steps,
)
/ flags.total_steps
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
stats = {}
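    # Resume model, optimizer, scheduler and stats from an existing checkpoint, if any.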
if flags.checkpoint and os.path.exists(flags.checkpoint):
logging.info("Loading checkpoint: %s" % flags.checkpoint)
checkpoint_states = torch.load(
flags.checkpoint, map_location=flags.learner_device
)
model.load_state_dict(checkpoint_states["model_state_dict"])
optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
stats = checkpoint_states["stats"]
logging.info(f"Resuming preempted job, current stats:\n{stats}")
# Initialize actor model like learner model.
actor_model.load_state_dict(model.state_dict())
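    # Learner threads: each one dequeues batches from learner_queue and updates `model`
    # via the `learn` function (defined elsewhere).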
learner_threads = [
threading.Thread(
target=learn,
name="learner-thread-%i" % i,
args=(
learner_queue,
model,
actor_model,
optimizer,
scheduler,
stats,
flags,
plogger,
learner_device,
),
)
for i in range(flags.num_learner_threads)
]
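    # Inference threads: serve the actors' batched forward passes on actor_model
    # (on actor_device) through the dynamic inference_batcher.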
inference_threads = [
threading.Thread(
target=inference,
name="inference-thread-%i" % i,
args=(inference_batcher, actor_model, flags, actor_device),
)
for i in range(flags.num_inference_threads)
]
actorpool_thread.start()
for t in learner_threads + inference_threads:
t.start()
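    # Helper that snapshots model, optimizer, scheduler, stats and flags so a
    # preempted job can resume from flags.checkpoint.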
def checkpoint(checkpoint_path=None):
if flags.checkpoint:
if checkpoint_path is None:
checkpoint_path = flags.checkpoint
logging.info("Saving checkpoint to %s", checkpoint_path)
torch.save(
{
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"stats": stats,
"flags": vars(flags),
},
checkpoint_path,
)
def format_value(x):
return f"{x:1.5}" if isinstance(x, float) else str(x)
try:
train_start_time = timeit.default_timer()
train_time_offset = stats.get(
"train_seconds", 0
) # used for resuming training
last_checkpoint_time = timeit.default_timer()
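        # Save extra "dev" snapshots once training passes 0%, 25%, 50% and 75% of total_steps.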
dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75]
loop_start_time = timeit.default_timer()
loop_start_step = stats.get("step", 0)
while True:
if loop_start_step >= flags.total_steps:
break
time.sleep(5)
loop_end_time = timeit.default_timer()
loop_end_step = stats.get("step", 0)
stats["train_seconds"] = round(
loop_end_time - train_start_time + train_time_offset, 1
)
if loop_end_time - last_checkpoint_time > 10 * 60:
# Save every 10 min.
checkpoint()
last_checkpoint_time = loop_end_time
if len(dev_checkpoint_intervals) > 0:
step_percentage = loop_end_step / flags.total_steps
i = dev_checkpoint_intervals[0]
if step_percentage > i:
checkpoint(flags.checkpoint[:-4] + "_" + str(i) + ".tar")
dev_checkpoint_intervals = dev_checkpoint_intervals[1:]
logging.info(
"Step %i @ %.1f SPS. Inference batcher size: %i."
" Learner queue size: %i."
" Other stats: (%s)",
loop_end_step,
(loop_end_step - loop_start_step)
/ (loop_end_time - loop_start_time),
inference_batcher.size(),
learner_queue.size(),
", ".join(
f"{key} = {format_value(value)}"
for key, value in stats.items()
),
)
loop_start_time = loop_end_time
loop_start_step = loop_end_step
except KeyboardInterrupt:
pass # Close properly.
else:
logging.info("Learning finished after %i steps.", stats["step"])
checkpoint()
# Done with learning. Let's stop all the ongoing work.
inference_batcher.close()
learner_queue.close()
actorpool_thread.join()
for t in learner_threads + inference_threads:
t.join()
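
# Usage sketch (not part of the original file): `train` expects a flags object exposing
# the fields referenced above (savedir, batch_size, unroll_length, num_actors,
# pipes_basename, total_steps, learning_rate, ...). Since the function accepts an
# omegaconf.DictConfig, a hypothetical launch could look like the following; the config
# path is a placeholder, not the project's real one:
#
#     from omegaconf import OmegaConf
#
#     flags = OmegaConf.load("path/to/config.yaml")  # hypothetical config file
#     train(flags)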