# distributed_training/src_dir/main_trainer.py
def _get_test_data_loader(args, **kwargs):
    """Build the DataLoader for the validation ('val') split.

    Args:
        args: Parsed argument namespace. Reads ``height``, ``width``,
            ``data_dir``, ``model_parallel``, ``multigpus_distributed``,
            ``world_size``, ``rank``, and ``test_batch_size``.
        **kwargs: Accepted for call-site symmetry with the training-loader
            factory; currently unused here.

    Returns:
        A ``DataLoader`` over ``<data_dir>/val``, unshuffled, sharded with a
        ``DistributedSampler`` when ``args.multigpus_distributed`` is set.
    """
    logger.info("Get test data loader")
    transform = Compose([
        Resize(args.height, args.width),
        # ImageNet mean/std normalization constants.
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    image_path = os.path.join(args.data_dir, 'val')
    dataset = AlbumentationImageDataset(image_path=image_path,
                                        transform=transform,
                                        args=args)
    # Model-parallel runs need identically sized batches on every rank, so
    # drop the trailing partial batch in that mode.
    drop_last = args.model_parallel
    # Was a stray print(); use the module logger with lazy %-formatting.
    logger.info("test drop_last : %s", drop_last)
    test_sampler = data.distributed.DistributedSampler(
        dataset,
        num_replicas=int(args.world_size),
        rank=int(args.rank)) if args.multigpus_distributed else None
    return data.DataLoader(dataset,
                           batch_size=args.test_batch_size,
                           shuffle=False,
                           sampler=test_sampler,
                           drop_last=drop_last)