def augmentation_pytorch()

in src/sm_augmentation_train-script.py


import os

import torch
from torchvision import datasets, transforms


def augmentation_pytorch(train_dir, batch_size, workers, is_distributed, use_cuda, aug_load_factor):
    """Build train/val DataLoaders that run image augmentation on the CPU."""
    print("Image augmentation using PyTorch DataLoaders on CPUs")

    # Augmentation ops applied only to the training split.
    aug_ops = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(5)
    ]
    # Cropping, tensor conversion, and ImageNet normalization for both splits.
    crop_norm_ops = [
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]

    # Repeat the augmentation ops aug_load_factor times so the CPU does
    # proportionally more preprocessing work per sample.
    train_aug_ops = aug_ops * aug_load_factor

    data_transforms = {
        'train': transforms.Compose(train_aug_ops + crop_norm_ops),
        'val': transforms.Compose(crop_norm_ops),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(train_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val']}
    # Shard the training set across processes when training is distributed.
    train_sampler = (torch.utils.data.distributed.DistributedSampler(image_datasets['train'])
                     if is_distributed else None)
    dataloaders = {x: torch.utils.data.DataLoader(dataset=image_datasets[x],
                                                  batch_size=batch_size,
                                                  # shuffle and sampler are mutually exclusive:
                                                  # shuffle only when no sampler is supplied.
                                                  shuffle=(x == 'train' and train_sampler is None),
                                                  sampler=train_sampler if x == 'train' else None,
                                                  num_workers=workers,
                                                  pin_memory=use_cuda)
                   for x in ['train', 'val']}

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    return dataloaders, dataset_sizes
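
Usage sketch (the `./data` root with `train/` and `val/` subfolders, and all argument values below, are illustrative assumptions rather than values taken from the script):

use_cuda = torch.cuda.is_available()
dataloaders, dataset_sizes = augmentation_pytorch(
    train_dir='./data',      # hypothetical root containing ./data/train and ./data/val
    batch_size=32,
    workers=4,
    is_distributed=False,    # True requires torch.distributed.init_process_group first
    use_cuda=use_cuda,
    aug_load_factor=2,       # each flip/rotation op appears twice in the train pipeline
)
print(dataset_sizes)         # e.g. {'train': 2000, 'val': 500} for an ImageFolder layout

Note that when is_distributed is True, the DistributedSampler is created inside the function and not returned, so per-epoch reshuffling via sampler.set_epoch(epoch) is not exposed by this helper.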