def _get_train_data_loader()

in notebook/source/monai_dicom.py [0:0]


def _get_train_data_loader(batch_size, trainX, trainY, is_distributed, **kwargs):
    """Build the DataLoader used for training.

    Each sample is produced by ``DICOMDataset`` using a MONAI transform
    pipeline. The pipeline loads the image, scales its intensity, applies
    random rotation/flip/zoom augmentation, resizes to (108, 96), and
    converts the result to a tensor.

    Args:
        batch_size: Number of samples per batch.
        trainX: Training inputs handed to ``DICOMDataset``. These are presumably
            image file paths, since ``LoadImage`` is the first transform
            (TODO confirm against ``DICOMDataset``).
        trainY: Training labels handed to ``DICOMDataset`` alongside ``trainX``.
        is_distributed: When true, a ``DistributedSampler`` is attached and the
            loader itself does not shuffle.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the augmented training set.
    """
    import math  # local import keeps this block self-contained

    logger.info("Get train data loader")

    train_transforms = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        # MONAI's RandRotate takes angles in radians. A bare ``15`` would sample
        # from (-15, 15) rad (~ +/-860 degrees), i.e. effectively any
        # orientation, so convert the intended 15 degrees explicitly.
        RandRotate(range_x=math.radians(15), prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True),
        Resize(spatial_size=(108, 96)),
        ToTensor(),
    ])

    dataset = DICOMDataset(trainX, trainY, train_transforms)

    # When a sampler is supplied, shuffling is delegated to it: DataLoader
    # rejects shuffle=True together with an explicit sampler.
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset)
        if is_distributed
        else None
    )
    # NOTE(review): with a DistributedSampler, the training loop should call
    # ``loader.sampler.set_epoch(epoch)`` each epoch so the shuffle order
    # changes between epochs. Confirm the caller does this.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **kwargs,
    )