in downstream/semseg/lib/dataset.py [0:0]
# Dependencies: DataLoader comes from torch; the transform module `t`,
# str2datasetphase_type, InfSampler, DistributedInfSampler and get_world_size
# are helpers defined elsewhere in this repo's lib/ package.
from torch.utils.data import DataLoader


def initialize_data_loader(DatasetClass,
                           config,
                           phase,
                           num_workers,
                           shuffle,
                           repeat,
                           augment_data,
                           batch_size,
                           limit_numpoints,
                           input_transform=None,
                           target_transform=None):
  # Accept the phase either as a string ('train', 'val', ...) or as a
  # DatasetPhase enum value.
  if isinstance(phase, str):
    phase = str2datasetphase_type(phase)

  # Pick a collate function: the `cflt` variant additionally returns the
  # voxelization transformation, while the `cfl` variant returns only
  # coordinates / features / labels.  Both skip batches whose total number
  # of points exceeds `limit_numpoints`.
  if config.data.return_transformation:
    collate_fn = t.cflt_collate_fn_factory(limit_numpoints)
  else:
    collate_fn = t.cfl_collate_fn_factory(limit_numpoints)

  # Pre-voxelization augmentation: elastic distortion on the raw point coordinates.
  prevoxel_transform_train = []
  if augment_data:
    prevoxel_transform_train.append(t.ElasticDistortion(DatasetClass.ELASTIC_DISTORT_PARAMS))

  if len(prevoxel_transform_train) > 0:
    prevoxel_transforms = t.Compose(prevoxel_transform_train)
  else:
    prevoxel_transforms = None

  # Input (post-voxelization) transforms: caller-supplied transforms first,
  # then the standard geometric and chromatic augmentations.
  input_transforms = []
  if input_transform is not None:
    input_transforms += input_transform

  if augment_data:
    input_transforms += [
        t.RandomDropout(0.2),
        t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS, DatasetClass.IS_TEMPORAL),
        t.ChromaticAutoContrast(),
        t.ChromaticTranslation(config.augmentation.data_aug_color_trans_ratio),
        t.ChromaticJitter(config.augmentation.data_aug_color_jitter_std),
        # t.HueSaturationTranslation(config.data_aug_hue_max, config.data_aug_saturation_max),
    ]

  if len(input_transforms) > 0:
    input_transforms = t.Compose(input_transforms)
  else:
    input_transforms = None

  dataset = DatasetClass(
      config,
      prevoxel_transform=prevoxel_transforms,
      input_transform=input_transforms,
      target_transform=target_transform,
      cache=config.data.cache_data,
      augment_data=augment_data,
      phase=phase)

  data_args = {
      'dataset': dataset,
      'num_workers': num_workers,
      'batch_size': batch_size,
      'collate_fn': collate_fn,
  }

  # With repeat=True the loader never runs out of batches: use an infinite
  # sampler (distributed when running with more than one process).  Otherwise
  # fall back to standard epoch-based shuffling.
  if repeat:
    if get_world_size() > 1:
      data_args['sampler'] = DistributedInfSampler(dataset, shuffle=shuffle)  # torch.utils.data.distributed.DistributedSampler(dataset)
    else:
      data_args['sampler'] = InfSampler(dataset, shuffle)
  else:
    data_args['shuffle'] = shuffle

  data_loader = DataLoader(**data_args)

  return data_loader
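
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file).  The dataset
# class name and the config fields referenced below are assumptions; any
# DatasetClass in this repo that exposes ELASTIC_DISTORT_PARAMS, ROTATION_AXIS
# and IS_TEMPORAL, together with a config providing the config.data.* and
# config.augmentation.* entries read above, is used the same way from a
# training script.
#
#   train_loader = initialize_data_loader(
#       ScannetVoxelizationDataset,             # hypothetical dataset class
#       config,
#       phase='train',
#       num_workers=config.data.num_workers,    # hypothetical config field
#       shuffle=True,
#       repeat=True,                            # infinite sampler for iteration-based training
#       augment_data=True,
#       batch_size=config.data.batch_size,      # hypothetical config field
#       limit_numpoints=config.data.train_limit_numpoints)
#   for coords, feats, labels in train_loader:  # cfl_collate_fn yields (coords, feats, labels)
#       ...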