def main()

in bring-your-own-container/fairseq_translation/fairseq/distributed_train.py [0:0]


# main() relies on module-level imports of json and torch, and on the run()
# helper and ErrorHandler class defined elsewhere in distributed_train.py.
import json

import torch


def main(args):

    # Fixed rendezvous port used for torch.distributed initialization.
    port = 1112

    # SageMaker describes the training cluster in resourceconfig.json:
    # the full host list and the name of the host this container runs on.
    with open("/opt/ml/input/config/resourceconfig.json", "r") as f:
        resource_config = json.load(f)
    hosts = resource_config["hosts"]
    current_host = resource_config["current_host"]

    num_gpus_per_node = torch.cuda.device_count()
    # Number of nodes, not the full process world size.
    world_size = len(hosts)

    args.distributed_backend = "gloo"

    # All workers rendezvous at the first host in the list.
    args.distributed_init_method = "tcp://{host}:{port}".format(host=hosts[0], port=port)

    # One training process per GPU across every node.
    args.distributed_world_size = world_size * num_gpus_per_node

    # Use the "spawn" start method so each worker gets its own CUDA context.
    mp = torch.multiprocessing.get_context("spawn")

    # Queue plus a watcher thread (ErrorHandler) that surfaces exceptions
    # raised in the child processes back in this parent process.
    error_queue = mp.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Train with multiprocessing: launch one process per local GPU.
    procs = []
    for i in range(num_gpus_per_node):
        # Global rank = node index * GPUs per node + local GPU index.
        args.distributed_rank = hosts.index(current_host) * num_gpus_per_node + i
        args.device_id = i

        # With the "spawn" context, args is pickled when the process starts,
        # so each child keeps the rank and device id set in this iteration.
        procs.append(
            mp.Process(
                target=run,
                args=(args, error_queue),
                daemon=True,
            )
        )
        procs[i].start()
        error_handler.add_child(procs[i].pid)

    # Block until every local worker has finished.
    for p in procs:
        p.join()
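
For context, the resource configuration that SageMaker mounts at /opt/ml/input/config/resourceconfig.json is a small JSON document describing the training cluster. The dictionary below is an illustrative sketch of what json.load(f) returns on a two-node job; the host names follow SageMaker's "algo-N" convention but are placeholders, not values taken from this repository.

# Illustrative result of json.load(f) for a two-node training job
# (host names are placeholders in SageMaker's "algo-N" style).
resource_config = {
    "current_host": "algo-1",
    "hosts": ["algo-1", "algo-2"],
    "network_interface_name": "eth0",
}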
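
main() assumes a run(args, error_queue) worker function and an ErrorHandler class defined elsewhere in this file. The snippet below is only a minimal sketch of what the worker is expected to do, modeled on fairseq's multiprocessing launcher; single_process_main is a stand-in for the per-process training entry point and is an assumption, not code from this file.

import traceback

def run(args, error_queue):
    # Run single-process training on the GPU selected via args.device_id and
    # report any crash, tagged with the worker's rank, back to the parent.
    try:
        single_process_main(args)  # assumed per-process training entry point
    except KeyboardInterrupt:
        pass  # the parent is responsible for shutting down the children
    except Exception:
        error_queue.put((args.distributed_rank, traceback.format_exc()))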
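
Finally, a hedged sketch of how main() would be driven from the script's entry point, assuming fairseq's standard training parser (fairseq.options.get_training_parser and options.parse_args_and_arch); the actual entry point in distributed_train.py may differ.

from fairseq import options

if __name__ == "__main__":
    # Parse the hyperparameters passed by SageMaker into a fairseq args
    # namespace; main() then overrides the distributed_* fields before
    # spawning one worker per GPU.
    parser = options.get_training_parser()
    main(options.parse_args_and_arch(parser))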