def make_criteo_data_and_loaders(args)

in torchbenchmark/models/dlrm/dlrm_data_pytorch.py
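
Builds the Criteo train and test datasets together with their data loaders and returns all four as a tuple. For the MLPerf terabyte configuration with memory mapping enabled, it dispatches to the specialized loaders in data_loader_terabyte (either the pre-batched binary reader or the day-by-day streaming loader); every other configuration constructs CriteoDataset objects and wraps them in standard torch.utils.data.DataLoader instances.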


def make_criteo_data_and_loaders(args):

    if args.mlperf_logging and args.memory_map and args.data_set == "terabyte":
        # more efficient for larger batches
        data_directory = path.dirname(args.raw_data_file)

        if args.mlperf_bin_loader:
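            # derive the _train.bin/_test.bin file names by stripping the
            # extension from the processed-data path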
            lstr = args.processed_data_file.split("/")
            d_path = "/".join(lstr[0:-1]) + "/" + lstr[-1].split(".")[0]
            train_file = d_path + "_train.bin"
            test_file = d_path + "_test.bin"
            # val_file = d_path + "_val.bin"
            counts_file = args.raw_data_file + "_fea_count.npz"

            if any(not path.exists(p) for p in [train_file,
                                                test_file,
                                                counts_file]):
                ensure_dataset_preprocessed(args, d_path)

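            # CriteoBinDataset yields complete mini-batches itself, so the
            # DataLoaders below are created with batch_size=None to avoid
            # re-batching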
            train_data = data_loader_terabyte.CriteoBinDataset(
                data_file=train_file,
                counts_file=counts_file,
                batch_size=args.mini_batch_size,
                max_ind_range=args.max_ind_range
            )

            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                sampler=RandomSampler(train_data) if args.mlperf_bin_shuffle else None
            )

            test_data = data_loader_terabyte.CriteoBinDataset(
                data_file=test_file,
                counts_file=counts_file,
                batch_size=args.test_mini_batch_size,
                max_ind_range=args.max_ind_range
            )

            test_loader = torch.utils.data.DataLoader(
                test_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
            )
        else:
            data_filename = path.basename(args.raw_data_file)

            train_data = CriteoDataset(
                args.data_set,
                args.max_ind_range,
                args.data_sub_sample_rate,
                args.data_randomize,
                "train",
                args.raw_data_file,
                args.processed_data_file,
                args.memory_map
            )

            test_data = CriteoDataset(
                args.data_set,
                args.max_ind_range,
                args.data_sub_sample_rate,
                args.data_randomize,
                "test",
                args.raw_data_file,
                args.processed_data_file,
                args.memory_map
            )

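            # stream the memory-mapped per-day files: days 0-22 for training,
            # day 23 held out for testing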
            train_loader = data_loader_terabyte.DataLoader(
                data_directory=data_directory,
                data_filename=data_filename,
                days=list(range(23)),
                batch_size=args.mini_batch_size,
                max_ind_range=args.max_ind_range,
                split="train"
            )

            test_loader = data_loader_terabyte.DataLoader(
                data_directory=data_directory,
                data_filename=data_filename,
                days=[23],
                batch_size=args.test_mini_batch_size,
                max_ind_range=args.max_ind_range,
                split="test"
            )
    else:
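        # generic path (e.g. Kaggle, or terabyte without memory mapping):
        # build in-memory CriteoDataset objects and wrap them in standard
        # PyTorch DataLoaders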
        train_data = CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            "train",
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map
        )

        test_data = CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            "test",
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map
        )

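        # collate_wrapper_criteo packs each sample batch into the
        # dense-feature, sparse-index, and target tensors consumed by DLRM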
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.mini_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_wrapper_criteo,
            pin_memory=False,
            drop_last=False,  # set True to drop the final incomplete batch
        )

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.test_mini_batch_size,
            shuffle=False,
            num_workers=args.test_num_workers,
            collate_fn=collate_wrapper_criteo,
            pin_memory=False,
            drop_last=False,  # set True to drop the final incomplete batch
        )

    return train_data, train_loader, test_data, test_loader
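
A minimal sketch of how this function might be invoked outside the benchmark harness. The field values below are assumptions chosen to exercise the plain CriteoDataset path (no MLPerf/terabyte flags); in the benchmark itself, args comes from the DLRM argument parser, and the file paths must point at an actual Criteo Kaggle download.

from types import SimpleNamespace

# hypothetical configuration; the paths are placeholders, not real defaults
args = SimpleNamespace(
    mlperf_logging=False,        # skip the terabyte-specific branches
    memory_map=False,
    data_set="kaggle",
    max_ind_range=-1,
    data_sub_sample_rate=0.0,
    data_randomize="total",
    raw_data_file="./input/train.txt",                    # placeholder path
    processed_data_file="./input/kaggle_processed.npz",   # placeholder path
    mini_batch_size=128,
    test_mini_batch_size=128,
    num_workers=0,
    test_num_workers=0,
)

train_data, train_loader, test_data, test_loader = make_criteo_data_and_loaders(args)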