def train()

in distributed_training/src_dir/main_trainer.py [0:0]


def train(local_rank, args):
    best_acc1 = -1
    model_history = {}
    model_history = util.init_modelhistory(model_history)
    train_start = time.time()

    if local_rank is not None:
        args.local_rank = local_rank

    # Distributed setup
    if args.multigpus_distributed:
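        # dist_setting presumably initializes the process group and fills in
        # args.rank / args.world_size / args.device (helper defined in dis_util).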
        args = dis_util.dist_setting(args)

    # choose model from pytorch model_zoo
    model = util.torch_model(
        args.model_name,
        num_classes=args.num_classes,
        pretrained=True,
        local_rank=args.local_rank,
        model_parallel=args.model_parallel)  # e.g. num_classes=1000, model_name='resnext101_32x8d'
    criterion = nn.CrossEntropyLoss().cuda()

    model, args = dis_util.dist_model(model, args)

    # cuDNN benchmarks several convolution algorithms and picks the fastest one;
    # disable it when a seed is set so results stay reproducible.
    cudnn.benchmark = not args.seed

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.apex:
        model, optimizer, args = dis_util.apex_init(model, optimizer, args)
    elif args.model_parallel:
        model, optimizer, args = dis_util.smp_init(model, optimizer, args)
    elif args.data_parallel:
        model, optimizer, args = dis_util.sdp_init(model, optimizer, args)
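    # Note: at most one of the branches above runs; apex_init / smp_init / sdp_init presumably
    # wrap the model and optimizer for Apex AMP, SageMaker model parallel (smp), or
    # SageMaker data parallel (sdp), respectively.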

    train_loader, train_sampler = _get_train_data_loader(args, **args.kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(train_loader.sampler), len(train_loader.dataset),
        100. * len(train_loader.sampler) / len(train_loader.dataset)))

    test_loader = _get_test_data_loader(args, **args.kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of test data".format(
        len(test_loader.sampler), len(test_loader.dataset),
        100. * len(test_loader.sampler) / len(test_loader.dataset)))

    print(" local_rank : {}, local_batch_size : {}".format(
        args.local_rank, args.batch_size))

    for epoch in range(1, args.num_epochs + 1):
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(train_loader), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        model.train()
        end = time.time()

        # Set epoch count for DistributedSampler
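        # (set_epoch re-seeds the sampler's shuffle so each epoch sees a different data ordering)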
        if args.multigpus_distributed and not args.model_parallel:
            train_sampler.set_epoch(epoch)

        for batch_idx, (input, target) in enumerate(train_loader):
            input = input.to(args.device)
            target = target.to(args.device)
            batch_idx += 1

            if args.model_parallel:
                output, loss = dis_util.train_step(model, criterion, input,
                                                   target, args.scaler, args)
                # Model parallel (smp, a.k.a. "Rubik"): average the loss across microbatches.
                loss = loss.reduce_mean()

            else:
                output = model(input)
                loss = criterion(output, target)

            if not args.model_parallel:
                # Clear accumulated gradients before the backward pass
                optimizer.zero_grad()

            if args.apex:
                dis_util.apex_loss(loss, optimizer)
            elif not args.model_parallel:
                loss.backward()

            optimizer.step()

            if args.model_parallel:
                # Model parallel: clear gradients after the step
                # (the backward pass runs inside dis_util.train_step).
                optimizer.zero_grad()

            if args.rank == 0 and batch_idx % args.log_interval == 1:
                # Every log_interval iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to compute these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                if args.model_parallel:
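                    # The model-parallel step returns one output per microbatch;
                    # concatenate them before computing accuracy.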
                    output = torch.cat(output.outputs)

                # Measure accuracy
                prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), input.size(0))
                top1.update(util.to_python_float(prec1), input.size(0))
                top5.update(util.to_python_float(prec5), input.size(0))

                # Wait for outstanding GPU work to finish (CUDA ops run asynchronously by default)
                torch.cuda.synchronize()
                batch_time.update((time.time() - end) / args.log_interval)
                end = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}), '
                    'Train_Prec@1={top1.val:.3f}:({top1.avg:.3f}), '
                    'Train_Prec@5={top5.val:.3f}:({top5.avg:.3f})'.format(
                        epoch,
                        batch_idx,
                        len(train_loader),
                        args.world_size * args.batch_size / batch_time.val,
                        args.world_size * args.batch_size / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                        top5=top5))

        acc1 = validate(test_loader, model, criterion, epoch, model_history,
                        args)

        is_best = False

        if args.rank == 0:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        if not args.multigpus_distributed or (args.multigpus_distributed
                                              and not args.model_parallel
                                              and args.rank == 0):
            model_history['epoch'].append(epoch)
            model_history['batch_idx'].append(batch_idx)
            model_history['batch_time'].append(batch_time.val)
            model_history['losses'].append(losses.val)
            model_history['top1'].append(top1.val)
            model_history['top5'].append(top5.val)

            util.save_history(
                os.path.join(args.output_data_dir, 'model_history.p'),
                model_history)
            util.save_model(
                {
                    'epoch': epoch + 1,
                    'model_name': args.model_name,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'class_to_idx': train_loader.dataset.class_to_idx,
                }, is_best, args)
        elif args.model_parallel:
            if args.rank == 0:
                util.save_history(
                    os.path.join(args.output_data_dir, 'model_history.p'),
                    model_history)
            dis_util.smp_savemodel(model, optimizer, is_best, args)
            
    if args.model_parallel:
        dis_util.smp_barrier()

    if args.data_parallel:
        dis_util.sdp_barrier(args)

    return 1
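
A minimal sketch of how this entry point might be launched (not part of the listing above; it assumes a hypothetical args.num_gpus field and the usual torch.multiprocessing.spawn pattern, where the spawned process index becomes local_rank):

import torch.multiprocessing as mp

def launch(args):
    # Hypothetical launcher: spawn one train() process per local GPU when
    # multi-GPU distributed training is requested; otherwise run in-process.
    if args.multigpus_distributed:
        mp.spawn(train, args=(args,), nprocs=args.num_gpus, join=True)
    else:
        # train() keeps args.local_rank unchanged when local_rank is None
        train(None, args)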