def train()

in code/train.py [0:0]

Entry point for data-parallel fine-tuning of the protein classifier: each process shards the training data by rank, trains with the Lamb optimizer under a linear learning-rate schedule, and rank 0 additionally evaluates after each epoch and saves the final model.

def train(args):
    use_cuda = args.num_gpus > 0
    device = torch.device("cuda" if use_cuda else "cpu")
    
    # ranks and sizes come from the distributed backend: rank is the global
    # process index, local_rank the GPU index on this node
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = dist.get_local_rank()
    
    # set the seed for generating random numbers
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    train_loader = _get_train_data_loader(args.batch_size, args.data_dir)
    if rank == 0:
        test_loader = _get_test_data_loader(args.test_batch_size, args.test)
        print("Max length of sequence: ", MAX_LEN)
        print("Freezing {} layers".format(args.frozen_layers))
        print("Model used: ", PRE_TRAINED_MODEL_NAME)

    logger.debug(
        "Rank {} processes {}/{} ({:.0f}%) of the train data".format(
            rank,
            len(train_loader.sampler),
            len(train_loader.dataset),
            100.0 * len(train_loader.sampler) / len(train_loader.dataset),
        )
    )

    model = ProteinClassifier(
        args.num_labels  # the number of output labels
    )
    freeze(model, args.frozen_layers)
    if use_cuda:
        # pin this process to its GPU before moving the model and wrapping it in DDP
        torch.cuda.set_device(local_rank)
    model = DDP(model.to(device), broadcast_buffers=False)
    
    # note: Lamb is not part of torch.optim; it is assumed to come from an
    # external package such as torch_optimizer (imported here as optim)
    optimizer = optim.Lamb(
        model.parameters(),
        lr=args.lr * world_size,  # scale the learning rate with the number of workers
        betas=(0.9, 0.999),
        eps=args.epsilon,
        weight_decay=args.weight_decay,
    )
    
    # one optimizer step per batch, so the schedule runs for batches * epochs steps
    total_steps = len(train_loader) * args.epochs

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # no warmup phase
        num_training_steps=total_steps,
    )
    
    loss_fn = nn.CrossEntropyLoss().to(device)
    
    for epoch in range(1, args.epochs + 1):
        model.train()
        for step, batch in enumerate(train_loader):
            b_input_ids = batch['input_ids'].to(device)
            b_input_mask = batch['attention_mask'].to(device)
            b_labels = batch['targets'].to(device)

            outputs = model(b_input_ids, attention_mask=b_input_mask)
            loss = loss_fn(outputs, b_labels)

            loss.backward()
            # clip gradients to avoid exploding updates
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # update parameters based on their gradients and the current learning rate
            optimizer.step()
            scheduler.step()  # advance the linear LR schedule once per optimizer step
            optimizer.zero_grad()
            
            if step % args.log_interval == 0 and rank == 0:
                logger.info(
                    "Collecting data from Master Node: \n Train Epoch: {} [{}/{} ({:.0f}%)] Training Loss: {:.6f}".format(
                        epoch,
                        step * len(batch["input_ids"]) * world_size,
                        len(train_loader.dataset),
                        100.0 * step / len(train_loader),
                        loss.item(),
                    )
                )
            if args.verbose:
                print("Batch", step, "from rank", rank)
        if rank == 0:
            test(model, test_loader, device)
    if rank == 0:
        model_save = model.module if hasattr(model, "module") else model
        save_model(model_save, args.model_dir)
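
The helpers and module-level names used above (dist, DDP, optim.Lamb, logger, MAX_LEN, PRE_TRAINED_MODEL_NAME, ProteinClassifier, freeze, _get_train_data_loader, test, save_model) are defined elsewhere in code/train.py. The sketches below are illustrative reconstructions, not the file's actual contents. First, a minimal import block consistent with the calls in train(), assuming SageMaker's SMDataParallel as the distributed backend and the torch_optimizer package as the source of Lamb (both assumptions, since the real file header is not shown):

import logging
import os

import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.data.distributed
import torch_optimizer as optim  # assumption: Lamb is not in torch.optim
import smdistributed.dataparallel.torch.distributed as dist  # assumption: SMDataParallel
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
from transformers import BertModel, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)

# module-level constants referenced in train(); the values here are placeholders
PRE_TRAINED_MODEL_NAME = "Rostlab/prot_bert"  # assumption
MAX_LEN = 512  # assumption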
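
ProteinClassifier is referenced but not shown. A plausible sketch, assuming a pre-trained BERT encoder with a linear classification head (the bert attribute is an assumption that the freeze() sketch below also relies on):

class ProteinClassifier(nn.Module):
    # Hypothetical definition: BERT encoder plus a linear classification head.
    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(output.pooler_output)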
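
freeze() limits fine-tuning to the upper layers of the network. A sketch assuming the model exposes its encoder as model.bert and that args.frozen_layers counts encoder blocks from the bottom:

def freeze(model, frozen_layers):
    # disable gradients for the embeddings and the first `frozen_layers`
    # encoder blocks; only the remaining blocks and the head are fine-tuned
    modules = [model.bert.embeddings, *model.bert.encoder.layer[:frozen_layers]]
    for module in modules:
        for param in module.parameters():
            param.requires_grad = False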
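
The debug log in train() compares len(train_loader.sampler) against len(train_loader.dataset), which implies the loader shards the data across ranks with a DistributedSampler. A sketch under that assumption (ProteinSequenceDataset is a hypothetical stand-in for the real dataset class):

def _get_train_data_loader(batch_size, data_dir):
    dataset = ProteinSequenceDataset(data_dir)  # hypothetical dataset class
    # each rank draws only its own shard, so len(sampler) < len(dataset)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)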
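
test() runs only on rank 0 after each epoch. A minimal evaluation loop consistent with the batch layout used in training; reporting average loss and accuracy is an illustrative choice, as the real metric set is not shown:

def test(model, test_loader, device):
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    correct, seen, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["targets"].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            total_loss += loss_fn(outputs, targets).item()
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)
    logger.info(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            total_loss / len(test_loader), correct, seen, 100.0 * correct / seen
        )
    )
    model.train()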
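
save_model() receives the unwrapped module (model.module) on rank 0. A sketch following the usual SageMaker convention of writing the weights into model_dir (the model.pth filename is an assumption):

def save_model(model, model_dir):
    path = os.path.join(model_dir, "model.pth")  # filename is an assumption
    torch.save(model.state_dict(), path)
    logger.info("Saved model to {}".format(path))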