in baseline_model/run_tree_transformer_multi_gpu.py [0:0]
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader

# Project-local helpers (text_data_collator, Graph_NN, Encoder, Decoder_AST,
# Transformer, NoamOpt, initialize_weights, train_eval_tree, test_tree,
# epoch_time, print_performances) are assumed to be imported elsewhere in
# this file.


def run_spawn(gpu, args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging):
    print(gpu)
    if args.dist_gpu:
        torch.cuda.set_device(gpu)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.n_dist_gpu,
            rank=gpu
        )
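        # With world_size and rank passed explicitly, the 'env://' rendezvous
        # only needs the master address in the environment, e.g.
        # (illustrative values, set by the launcher):
        #
        #   os.environ['MASTER_ADDR'] = 'localhost'
        #   os.environ['MASTER_PORT'] = '12355'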
    device = torch.device(gpu)
    best_valid_loss = float('inf')
    torch.backends.cudnn.deterministic = True

    # One DataLoader per split; the test split (`tst`) gets its own loader so
    # the final evaluation below can run on it rather than the validation set.
    collate = text_data_collator(trn)
    train_iterator = DataLoader(trn, batch_size=args.bsz, collate_fn=collate,
                                num_workers=args.num_workers, shuffle=False)
    collate = text_data_collator(vld)
    valid_iterator = DataLoader(vld, batch_size=args.bsz, collate_fn=collate,
                                num_workers=args.num_workers, shuffle=False)
    collate = text_data_collator(tst)
    test_iterator = DataLoader(tst, batch_size=args.bsz, collate_fn=collate,
                               num_workers=args.num_workers, shuffle=False)
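    # Note: no DistributedSampler is attached, so every rank iterates the full
    # training set each epoch. The usual DDP sharding idiom would be (a sketch,
    # not what this script does):
    #
    #   from torch.utils.data.distributed import DistributedSampler
    #   sampler = DistributedSampler(trn, num_replicas=args.n_dist_gpu, rank=gpu)
    #   train_iterator = DataLoader(trn, batch_size=args.bsz, collate_fn=collate,
    #                               num_workers=args.num_workers, sampler=sampler)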
    INPUT_DIM = len(SRC.vocab)

    # One graph encoder per input view: the assembly graph and the AST.
    gnn_asm = Graph_NN(annotation_size=len(SRC.vocab),
                       out_feats=args.hid_dim,
                       n_steps=args.n_gnn_layers,
                       device=device)
    gnn_ast = Graph_NN(annotation_size=len(SRC.vocab),
                       out_feats=args.hid_dim,
                       n_steps=args.n_gnn_layers,
                       device=device)
    enc = Encoder(INPUT_DIM,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  device,
                  args.mem_dim,
                  embedding_flag=args.embedding_flag,
                  max_length=max_len_src)
    dec = Decoder_AST(args.output_dim,
                      args.hid_dim,
                      args.n_layers,
                      args.n_heads,
                      args.pf_dim,
                      args.dropout,
                      device,
                      max_length=max_len_trg)
    SRC_PAD_IDX = 0
    TRG_PAD_IDX = 0

    model = Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device,
                        gnn=gnn_ast, gnn_asm=gnn_asm).to(device)
    # Initialize weights *before* wrapping in DDP: DDP broadcasts rank 0's
    # parameters at construction, so initializing afterwards would leave each
    # rank with different weights.
    model.apply(initialize_weights)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu],
                                                output_device=gpu)

    # reduction='sum' accumulates the loss over tokens; the train/eval helpers
    # are expected to normalize by token count when reporting.
    criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX,
                                    reduction='sum').cuda(gpu)
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))
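    # NoamOpt wraps Adam with the warmup-then-inverse-sqrt schedule from
    # "Attention Is All You Need". A sketch of the rate it steps through,
    # assuming the standard formulation:
    #
    #   rate(step) = lr_ratio * hid_dim ** -0.5 \
    #                * min(step ** -0.5, step * warmup ** -1.5)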
    if args.training and not args.eval:
        for epoch in range(args.n_epoch):
            start_time = time.time()
            train_loss, train_acc = train_eval_tree(args, model, train_iterator,
                                                    optimizer, device, criterion,
                                                    max_len_trg, train_flag=True)
            torch.distributed.barrier()

            # Validation (and checkpointing) run on rank 0 only.
            if gpu == 0:
                valid_loss, valid_acc = train_eval_tree(args, model, valid_iterator,
                                                        None, device, criterion,
                                                        max_len_trg, train_flag=False)
            torch.distributed.barrier()
            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Check gpu == 0 first: valid_loss is only defined on rank 0, so
            # reading it on other ranks would raise a NameError.
            if gpu == 0 and valid_loss < best_valid_loss and args.checkpoint_path is not None:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(),
                           os.path.join(args.checkpoint_path, 'model_multi_gpu.pt'))
            if gpu == 0:
                logging('Epoch: %d | Time: %dm %ds | learning rate %.3f'
                        % (epoch, epoch_mins, epoch_secs, optimizer._rate * 10000))
                print_performances('Training', train_loss, train_acc, start_time,
                                   logging=logging)
                print_performances('Validation', valid_loss, valid_acc, start_time,
                                   logging=logging)
            torch.distributed.barrier()
    if gpu == 0:
        # Reload the best checkpoint and evaluate on the held-out test split.
        model.load_state_dict(torch.load(
            os.path.join(args.checkpoint_path, 'model_multi_gpu.pt'),
            map_location=device))
        start_time = time.time()
        test_loss, test_acc = test_tree(args, model, test_iterator, TRG_PAD_IDX,
                                        device, args.label_smoothing, criterion,
                                        args.clip)
        print_performances('Test', test_loss, test_acc, start_time, logging=logging)
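
# A minimal launch sketch (illustration only, not part of the original file),
# assuming the caller has already built `args`, the dataset splits, `SRC`,
# and the max lengths. mp.spawn passes the process index as the first
# argument, which becomes `gpu` above:
#
#   import torch.multiprocessing as mp
#
#   os.environ.setdefault('MASTER_ADDR', 'localhost')
#   os.environ.setdefault('MASTER_PORT', '12355')
#   mp.spawn(run_spawn,
#            args=(args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging),
#            nprocs=args.n_dist_gpu)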