in bring-your-own-container/fairseq_translation/fairseq/distributed_train.py [0:0]
# Imports needed by this function (run() and ErrorHandler are defined
# elsewhere in distributed_train.py).
import json

import torch


def main(args):
    port = 1112
    # Read the SageMaker resource config to discover the cluster topology.
    with open("/opt/ml/input/config/resourceconfig.json", "r") as f:
        resource_config = json.load(f)
    hosts = resource_config["hosts"]
    current_host = resource_config["current_host"]
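    # For reference, the resource config that SageMaker writes for a training
    # job typically has this shape (host names here are illustrative):
    #
    #   {
    #     "current_host": "algo-1",
    #     "hosts": ["algo-1", "algo-2"],
    #     "network_interface_name": "eth0"
    #   }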
    num_gpus_per_node = torch.cuda.device_count()
    world_size = len(hosts)

    # Every GPU on every host is one distributed worker; the first host in the
    # list serves as the rendezvous endpoint for the whole process group.
    args.distributed_backend = "gloo"
    args.distributed_init_method = "tcp://{host}:{port}".format(host=hosts[0], port=port)
    args.distributed_world_size = world_size * num_gpus_per_node

    mp = torch.multiprocessing.get_context("spawn")

    # Create a thread to listen for errors in the child processes.
    error_queue = mp.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Train with multiprocessing: one child process per local GPU.
    procs = []
    for i in range(num_gpus_per_node):
        # Global rank = (index of this host) * GPUs per host + local GPU index.
        args.distributed_rank = hosts.index(current_host) * num_gpus_per_node + i
        args.device_id = i
        procs.append(
            mp.Process(
                target=run,
                args=(args, error_queue),
                daemon=True,
            )
        )
        procs[i].start()
        error_handler.add_child(procs[i].pid)

    for p in procs:
        p.join()
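The launcher above hands each child process to run() and relies on ErrorHandler to surface failures from any worker; both are defined elsewhere in distributed_train.py and are not shown in this excerpt. Below is a minimal sketch of what they could look like, assuming fairseq's single-node entry point is importable as single_process_main and that distributed_utils.distributed_init(args) joins the process group and returns the rank (both import paths are assumptions, not taken from the excerpt):

import os
import signal
import threading
import traceback

from fairseq import distributed_utils          # assumed import path
from train import main as single_process_main  # assumed: fairseq's single-node entry point


def run(args, error_queue):
    try:
        # Join the process group for this rank, then run normal single-process training.
        args.distributed_rank = distributed_utils.distributed_init(args)
        single_process_main(args)
    except KeyboardInterrupt:
        pass  # killed by the parent process; nothing to report
    except Exception:
        # Forward the traceback to the parent so it can shut the whole job down.
        error_queue.put((args.distributed_rank, traceback.format_exc()))


class ErrorHandler(object):
    """Listens for exceptions in child processes and re-raises them in the parent."""

    def __init__(self, error_queue):
        self.error_queue = error_queue
        self.children_pids = []
        self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
        self.error_thread.start()
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        self.children_pids.append(pid)

    def error_listener(self):
        # Block until any child reports an error, then signal the main thread.
        rank, original_trace = self.error_queue.get()
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        # Stop the remaining children and re-raise the child's traceback here.
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        rank, original_trace = self.error_queue.get()
        raise Exception(original_trace)

Reporting the traceback through a queue instead of letting a child die silently is what allows the parent process, and therefore the SageMaker training job, to fail promptly when any single GPU worker errors out.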