def run_spawn()

in baseline_model/run_tree_transformer_multi_gpu.py [0:0]


def run_spawn(gpu, args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging):
    # Each spawned worker owns a single GPU; `gpu` doubles as the process rank.
    print(gpu)
    if args.dist_gpu:
        torch.cuda.set_device(gpu)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.n_dist_gpu,
            rank=gpu
        )
    device = torch.device(gpu)
    best_valid_loss = float('inf')
    torch.backends.cudnn.deterministic = True
    torch.cuda.set_device(gpu)

    # Build the per-rank loaders. No DistributedSampler is attached, so every
    # rank iterates over the full split; shuffling is disabled.
    collate = text_data_collator(trn)
    train_iterator = DataLoader(trn, batch_size=args.bsz, collate_fn=collate, num_workers=args.num_workers, shuffle=False)
    collate = text_data_collator(vld)
    valid_iterator = DataLoader(vld, batch_size=args.bsz, collate_fn=collate, num_workers=args.num_workers, shuffle=False)


    INPUT_DIM = len(SRC.vocab)

    # Graph encoders over the assembly and AST inputs; both use the source
    # vocabulary size as the node annotation size.
    gnn_asm = Graph_NN(annotation_size=len(SRC.vocab),
                       out_feats=args.hid_dim,
                       n_steps=args.n_gnn_layers,
                       device=device)

    gnn_ast = Graph_NN(annotation_size=len(SRC.vocab),
                       out_feats=args.n_gnn_layers if False else args.hid_dim,
                       n_steps=args.n_gnn_layers,
                       device=device)

    enc = Encoder(INPUT_DIM,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  device,
                  args.mem_dim,
                  embedding_flag=args.embedding_flag,
                  max_length=max_len_src)

    dec = Decoder_AST(args.output_dim,
                      args.hid_dim,
                      args.n_layers,
                      args.n_heads,
                      args.pf_dim,
                      args.dropout,
                      device,
                      max_length=max_len_trg)

    # Index 0 is treated as the padding token for both source and target sequences.
    SRC_PAD_IDX = 0
    TRG_PAD_IDX = 0

    model = Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device,
                        gnn=gnn_ast, gnn_asm=gnn_asm).to(device)
    # Initialize weights before wrapping in DDP so that the parameters broadcast
    # from rank 0 at construction are the ones every rank actually trains.
    model.apply(initialize_weights)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu], output_device=gpu)
    criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX, reduction='sum').cuda(gpu)
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    if args.training and not args.eval:
        for epoch in range(args.n_epoch):
            start_time = time.time()
            train_loss, train_acc = train_eval_tree(args, model, train_iterator, optimizer,
                                                    device, criterion, max_len_trg, train_flag=True)
            torch.distributed.barrier()

            # Validation runs on rank 0 only; the remaining ranks wait at the barrier below.
            if gpu == 0:
                valid_loss, valid_acc = train_eval_tree(args, model, valid_iterator, None,
                                                        device, criterion, max_len_trg, train_flag=False)
            torch.distributed.barrier()
            end_time = time.time()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Only rank 0 computed valid_loss, so check the rank before comparing it.
            if gpu == 0 and (args.checkpoint_path is not None) and valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'model_multi_gpu.pt'))

            if gpu == 0:
                logging('Epoch: %d | Time: %dm %ds | learning rate %.3f'
                        % (epoch, epoch_mins, epoch_secs, optimizer._rate * 10000))
                print_performances('Training', train_loss, train_acc, start_time, logging=logging)
                print_performances('Validation', valid_loss, valid_acc, start_time, logging=logging)

            torch.distributed.barrier()

    if gpu == 0:
        # Reload the best checkpoint and report final metrics. Note that this
        # evaluation runs over valid_iterator; the tst split passed in is unused here.
        model.load_state_dict(torch.load(os.path.join(args.checkpoint_path, 'model_multi_gpu.pt')))
        start_time = time.time()
        test_loss, test_acc = test_tree(args, model, valid_iterator, TRG_PAD_IDX, device,
                                        args.label_smoothing, criterion, args.clip)
        print_performances('Test', test_loss, test_acc, start_time, logging=logging)
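
For reference, a minimal sketch of how this worker is typically launched with torch.multiprocessing.spawn. The launcher is not part of this file, so the launch() name and the MASTER_ADDR/MASTER_PORT values below are assumptions; only args.n_dist_gpu and the run_spawn signature come from the code above.

# Hypothetical launcher (not in this file): one process per GPU, rank == gpu index.
# dist.init_process_group(init_method='env://') above expects MASTER_ADDR and
# MASTER_PORT to be present in the environment before the workers start.
import os
import torch.multiprocessing as mp

def launch(args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging):
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    mp.spawn(run_spawn,
             args=(args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging),
             nprocs=args.n_dist_gpu,
             join=True)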