in torchbeast/polybeast_learner.py [0:0]
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )

    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )
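
    # Device placement: the learner model trains on the first GPU and the
    # actor-side model used for inference runs on the second GPU, if CUDA
    # is available; otherwise everything falls back to the CPU.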
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda:0")
        flags.actor_device = torch.device("cuda:1")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size
    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )
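
    # Build one environment-server address per actor. With
    # `connections_per_server == 1`, each actor gets its own pipe,
    # named `<pipes_basename>.<pipe_id>`.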
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1
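
    # Two copies of the network: `model` is trained on the learner device,
    # while `actor_model` serves inference requests on the actor device and
    # is handed to the learner threads so its weights can be kept up to date.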
    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )
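
    # Run the actor pool in a background thread. Wrapping `actors.run()`
    # ensures any exception is logged before being re-raised, so failures
    # in the actor loops are visible.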
    def run():
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )
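
    # Linearly anneal the learning rate to zero over `total_steps`. Each
    # scheduler "epoch" corresponds to one optimizer step, which consumes
    # `unroll_length * batch_size` environment steps.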
    def lr_lambda(epoch):
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
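
    # Learner threads consume batches from `learner_queue` and update `model`;
    # inference threads answer the actors' requests from `inference_batcher`
    # using `actor_model`.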
    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()
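
    # Snapshot model, optimizer, scheduler and stats so a preempted job can
    # resume via the checkpoint-loading block above.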
    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)
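
    # Main monitoring loop: wake up every 5 seconds, log steps-per-second
    # and queue sizes, checkpoint every 10 minutes, and stop once the
    # learner threads have reached `total_steps`.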
    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()
    for t in learner_threads + inference_threads:
        t.join()
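

# A minimal invocation sketch (an illustrative assumption, not part of the
# file above): in the real module the flags come from its argument parser;
# here a `types.SimpleNamespace` stands in, setting only the flags that
# `train()` reads, with hypothetical values. Kept commented out so the
# module stays import-safe.
#
# import types
#
# flags = types.SimpleNamespace(
#     xpid=None,
#     savedir="~/logs/torchbeast",
#     disable_cuda=False,
#     disable_checkpoint=False,
#     pipes_basename="unix:/tmp/polybeast",
#     num_actors=4,
#     num_learner_threads=2,
#     num_inference_threads=2,
#     max_learner_queue_size=None,
#     batch_size=8,
#     unroll_length=32,
#     total_steps=100000,
#     num_actions=6,
#     use_lstm=False,
#     learning_rate=0.0004,
#     momentum=0,
#     epsilon=0.01,
#     alpha=0.99,
# )
# train(flags)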