# augmentation_dali()
#
# Defined in src/sm_augmentation_train-script.py [0:0]


def _build_dali_loader(data_dir, is_training, batch_size, workers, host_rank,
                       world_size, seed, aug_load_factor, dali_cpu):
    """Create one DALI pipeline over *data_dir* and wrap it in a classification iterator."""
    pipe = create_dali_pipeline(batch_size=batch_size,
                                num_threads=workers,
                                # NOTE(review): global rank is reused as the device id; on
                                # multi-GPU hosts this usually should be the *local* rank —
                                # confirm against the launcher.
                                device_id=host_rank,
                                seed=seed,
                                data_dir=data_dir,
                                crop=224,   # final crop fed to the model
                                size=256,   # resize applied before cropping
                                dali_cpu=dali_cpu,
                                shard_id=host_rank,
                                num_shards=world_size,
                                is_training=is_training,
                                aug_load_factor=aug_load_factor)
    pipe.build()
    # PARTIAL keeps the last (smaller) batch instead of dropping or padding it.
    return DALIClassificationIterator(pipe,
                                      reader_name="Reader",
                                      last_batch_policy=LastBatchPolicy.PARTIAL)


def augmentation_dali(train_dir, batch_size, workers, host_rank, world_size, seed, aug_load_factor, dali_cpu):
    """Build DALI-backed train/val dataloaders and count the files in each split.

    Augmentation on GPU with DALI is not implemented at the moment for
    distributed training. Refer to:
    https://github.com/NVIDIA/DALI/blob/c4e86b55dccba083ae944cf00a478678b7e906cc/docs/examples/use_cases/pytorch/resnet50/main.py

    Args:
        train_dir: Root directory containing ``train/`` and ``val/`` subfolders.
        batch_size: Per-shard batch size for both pipelines.
        workers: Number of CPU threads per pipeline.
        host_rank: Rank of this process; used as shard id and device id.
        world_size: Total number of shards (processes).
        seed: Seed forwarded to the DALI pipeline for reproducible augmentation.
        aug_load_factor: Augmentation load factor forwarded to the pipeline.
        dali_cpu: If True, run the augmentation pipeline on CPU instead of GPU.

    Returns:
        Tuple ``(dataloaders, dataset_sizes)`` where each is a dict keyed by
        ``'train'`` and ``'val'``; ``dataloaders`` holds
        ``DALIClassificationIterator`` objects and ``dataset_sizes`` holds the
        recursive file counts of the corresponding directories.
    """
    if dali_cpu:
        print ("Image augmentation using DALI pipelines on CPUs")
    else:
        print ("Image augmentation using DALI pipelines on GPUs")

    dataloaders = {}
    dataset_sizes = {}
    for split, is_training in (('train', True), ('val', False)):
        split_path = train_dir + '/' + split + '/'
        # Recursive file count; assumes every file under the split dir is a sample.
        dataset_sizes[split] = sum(len(files) for _, _, files in os.walk(split_path))
        dataloaders[split] = _build_dali_loader(split_path, is_training,
                                                batch_size, workers, host_rank,
                                                world_size, seed,
                                                aug_load_factor, dali_cpu)
    return dataloaders, dataset_sizes