def training()

in src/sm_augmentation_train-script.py [0:0]
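
Runs a single training job: sets up (optionally distributed) PyTorch training, builds the data-loading/augmentation pipeline (native PyTorch transforms or NVIDIA DALI), trains the selected ResNet variant, and prints a summary of the run configuration.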


def training(args):
    num_gpus = args.num_gpus
    hosts = args.hosts
    current_host = args.current_host
    backend = args.backend
    seed = args.seed

    is_distributed = len(hosts) > 1 and backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = num_gpus > 0
    logger.debug("Number of gpus available - {}".format(num_gpus))
    device = torch.device("cuda" if use_cuda else "cpu")

    # One process per host: world size is the number of hosts, rank is this host's index
    world_size = len(hosts)
    os.environ['WORLD_SIZE'] = str(world_size)
    host_rank = hosts.index(current_host)

    if is_distributed:
        # Initialize the distributed environment.
        dist.init_process_group(backend=backend, rank=host_rank, world_size=world_size)
        logger.info('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
            backend, dist.get_world_size()) + 'Current host rank is {}. Number of gpus: {}'.format(
            dist.get_rank(), num_gpus))
    # set the seed for generating random numbers
    torch.manual_seed(seed)

    if use_cuda:
        torch.cuda.manual_seed(seed)

    # Loading training and validation data
    batch_size = args.batch_size
    train_dir = args.train_dir

    # Use the available CPU count for data-loading workers when training on GPU.
    # Very large worker counts on large-CPU instances can hit file-system concurrency limits.
    workers = os.cpu_count() if use_cuda else 0

    # Factor by which the augmentation work is repeated, used to vary the data-loading bottleneck
    aug_load_factor = args.aug_load_factor

    # Deciding on the augmentation approach: native PyTorch on CPU or DALI on CPU;
    # any other value of args.aug falls through to the DALI pipeline on GPU
    USE_PYTORCH = False
    USE_DALI_CPU = False
    if args.aug == 'pytorch-cpu':
        USE_PYTORCH = True
    elif args.aug == 'dali-cpu':
        USE_DALI_CPU = True

    if USE_PYTORCH:
        dataloaders, dataset_sizes = augmentation_pytorch(train_dir,
                                                          batch_size,
                                                          workers,
                                                          is_distributed,
                                                          use_cuda,
                                                          aug_load_factor)
    else:
        dataloaders, dataset_sizes = augmentation_dali(train_dir,
                                                       batch_size,
                                                       workers,
                                                       host_rank,
                                                       world_size,
                                                       seed,
                                                       aug_load_factor,
                                                       dali_cpu=USE_DALI_CPU)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Deciding on the model to use
    if args.model_type == 'RESNET18':
        model_ft = models.resnet18(pretrained=False)
    elif args.model_type == 'RESNET50':
        model_ft = models.resnet50(pretrained=False)
    elif args.model_type == 'RESNET152':
        model_ft = models.resnet152(pretrained=False)
    else:
        sys.exit('Requested Model not found')

    model_ft = model_ft.to(device)

    if is_distributed and use_cuda:
        # Multi-node GPU training: average gradients across hosts
        model_ft = torch.nn.parallel.DistributedDataParallel(model_ft)
    else:
        # Single-node training: replicate across local GPUs (pass-through on CPU)
        model_ft = torch.nn.DataParallel(model_ft)

    num_epochs = args.epochs
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), args.lr, args.momentum)

    # Run model training and time it
    since = time.time()

    # Not using the trained model or accuracy score for this experiment
    model_ft, best_acc = run_training_epochs(model_ft,
                                             num_epochs,
                                             criterion,
                                             optimizer_ft,
                                             dataloaders,
                                             dataset_sizes,
                                             device,
                                             USE_PYTORCH)
    time_elapsed = time.time() - since

    print('-' * 25)
    print("Model - ", args.model_type)
    print("Augmentation Approach - ", args.aug)
    print("Batch Size - ", batch_size)
    print("Augmentation Load factor - ", aug_load_factor)
    print("Training time - {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print('-' * 25)
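
For context, a minimal sketch of how training(args) is typically wired up in the entry point of a SageMaker PyTorch script. The argument names, defaults, and the 'train' channel name are assumptions inferred from the attributes read above; only the SM_* environment variables are standard SageMaker conventions.

import argparse
import json
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Hyperparameters (names and defaults are assumptions, not taken from the source)
    parser.add_argument('--model_type', type=str, default='RESNET50')
    parser.add_argument('--aug', type=str, default='pytorch-cpu')
    parser.add_argument('--aug_load_factor', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--backend', type=str, default=None)  # e.g. 'gloo' or 'nccl' for multi-node runs

    # Values SageMaker injects via environment variables
    parser.add_argument('--hosts', default=json.loads(os.environ['SM_HOSTS']))
    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
    parser.add_argument('--num-gpus', type=int, default=int(os.environ['SM_NUM_GPUS']))
    parser.add_argument('--train-dir', type=str, default=os.environ['SM_CHANNEL_TRAIN'])  # assumes a channel named 'train'

    training(parser.parse_args())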