def _get_test_data_loader()

in distributed_training/src_dir/main_trainer.py [0:0]


def _get_test_data_loader(args, **kwargs):
    """Build the evaluation ``DataLoader`` over ``<args.data_dir>/val``.

    Images are resized to ``(args.height, args.width)`` and normalized with
    the commonly used ImageNet per-channel mean/std before batching.

    Args:
        args: Parsed run configuration. Reads ``data_dir``, ``height``,
            ``width``, ``model_parallel``, ``multigpus_distributed``,
            ``world_size``, ``rank`` and ``test_batch_size``.
        **kwargs: Extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``,
            ``pin_memory``). They must not repeat ``batch_size``,
            ``shuffle``, ``sampler`` or ``drop_last``, which are set here.

    Returns:
        A ``DataLoader`` that iterates the validation images in a fixed
        order. When ``args.multigpus_distributed`` is truthy, each rank
        only sees its own shard, selected by a ``DistributedSampler``.
    """
    logger.info("Get test data loader")
    transform = Compose([
        Resize(args.height, args.width),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    image_path = os.path.join(args.data_dir, 'val')
    dataset = AlbumentationImageDataset(image_path=image_path,
                                        transform=transform,
                                        args=args)

    # The trailing partial batch is dropped only under model parallelism.
    # NOTE(review): presumably the model-parallel runtime needs uniformly
    # sized batches; confirm against the training loop.
    drop_last = args.model_parallel
    logger.info("test drop_last : %s", drop_last)

    test_sampler = None
    if args.multigpus_distributed:
        # DistributedSampler shuffles by default. Evaluation should visit
        # each rank's shard in a deterministic order, consistent with the
        # shuffle=False passed to the DataLoader below.
        test_sampler = data.distributed.DistributedSampler(
            dataset,
            num_replicas=int(args.world_size),
            rank=int(args.rank),
            shuffle=False,
        )

    return data.DataLoader(dataset,
                           batch_size=args.test_batch_size,
                           shuffle=False,
                           sampler=test_sampler,
                           drop_last=drop_last,
                           **kwargs)