in examples/svi_horovod.py [0:0]
def main(args):
    """Fit an ``AutoNormal`` guide to synthetic data with SVI, optionally
    distributing training across workers with Horovod.

    :param args: parsed command-line namespace. Attributes read here are
        ``seed``, ``size``, ``horovod``, ``cuda``, ``learning_rate``,
        ``batch_size``, ``num_epochs`` and ``outfile``.
    """
    # Create a model, synthetic data, and a guide.
    pyro.set_rng_seed(args.seed)
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)

    if args.horovod:
        # Initialize Horovod and set PyTorch globals.
        import horovod.torch as hvd

        hvd.init()
        torch.set_num_threads(1)
        if args.cuda:
            torch.cuda.set_device(hvd.local_rank())
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    # A throwaway tensor reveals the current default device; batches are
    # moved there before each SVI step.
    device = torch.tensor(0).device
    # Only the rank=0 worker prints progress; look the rank up once.
    is_root = not args.horovod or hvd.rank() == 0

    if args.horovod:
        # Initialize parameters and broadcast to all workers. Feed the same
        # device the training loop uses so params are created consistently.
        guide(covariates[:1].to(device), data[:1].to(device))
        hvd.broadcast_parameters(guide.state_dict(), root_rank=0)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Create an ELBO loss and a Pyro optimizer.
    elbo = Trace_ELBO()
    optim = Adam({"lr": args.learning_rate})
    if args.horovod:
        # Wrap the basic optimizer in a distributed optimizer.
        optim = HorovodOptimizer(optim)

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    if args.horovod:
        # Horovod requires a distributed sampler.
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, hvd.size(), hvd.rank())
    else:
        sampler = torch.utils.data.RandomSampler(dataset)
    config = {"batch_size": args.batch_size, "sampler": sampler}
    if args.cuda:
        config["num_workers"] = 1
        config["pin_memory"] = True
        # Try to use forkserver to spawn workers instead of fork.
        if (hasattr(mp, "_supports_context") and mp._supports_context and
                "forkserver" in mp.get_all_start_methods()):
            config["multiprocessing_context"] = "forkserver"
    dataloader = torch.utils.data.DataLoader(dataset, **config)

    # Run stochastic variational inference.
    svi = SVI(model, guide, optim, elbo)
    for epoch in range(args.num_epochs):
        if args.horovod:
            # Set rng seeds on distributed samplers. This is required.
            sampler.set_epoch(epoch)
        for step, (covariates_batch, data_batch) in enumerate(dataloader):
            loss = svi.step(covariates_batch.to(device), data_batch.to(device))
            if args.horovod:
                # Optionally average loss metric across workers.
                # You can do this with arbitrary torch.Tensors.
                # ``name`` must be passed by keyword: the second positional
                # parameter of hvd.allreduce is ``average``, not ``name``.
                loss = hvd.allreduce(torch.tensor(loss), name="loss").item()
            # Print only on the rank=0 worker.
            if step % 100 == 0 and is_root:
                print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))

    if args.horovod:
        # After we're done with the distributed parts of the program,
        # we can shutdown all but the rank=0 worker. The rank must be read
        # before shutdown(): Horovod's rank() raises once shut down.
        rank = hvd.rank()
        hvd.shutdown()
        if rank != 0:
            return
    if args.outfile:
        print("saving to {}".format(args.outfile))
        torch.save({"model": model, "guide": guide}, args.outfile)