# _test_bin() — from data_loader_terabyte.py

def _test_bin():
    """Validate the preprocessed binary dataset against the original loader.

    Command-line entry point. Parses ``--output_directory``,
    ``--input_data_prefix`` and ``--split``, runs ``_preprocess``, then
    iterates the binary-backed ``CriteoBinDataset`` and the original
    ``CriteoDataset`` in lockstep, comparing every tensor of every batch.

    Raises:
        ValueError: if the two datasets yield a different number of
            batches, or if any batch element differs.

    Prints 'PASSED' when all batches match.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_directory', required=True)
    parser.add_argument('--input_data_prefix', required=True)
    parser.add_argument('--split', choices=['train', 'test', 'val'],
                        required=True)
    args = parser.parse_args()

    _preprocess(args)

    binary_data_file = os.path.join(args.output_directory,
                                    '{}_data.bin'.format(args.split))
    counts_file = os.path.join(args.output_directory, 'day_fea_count.npz')

    dataset_binary = CriteoBinDataset(data_file=binary_data_file,
                                      counts_file=counts_file,
                                      batch_size=2048)

    # Imported locally — presumably to avoid a circular import at module
    # load time (TODO confirm against the module layout).
    from dlrm_data_pytorch import CriteoDataset
    from dlrm_data_pytorch import collate_wrapper_criteo_offset as collate_wrapper_criteo

    # batch_size=None / collate_fn=None: the binary dataset already yields
    # fully-formed batches, so the DataLoader must not re-batch them.
    binary_loader = torch.utils.data.DataLoader(
        dataset_binary,
        batch_size=None,
        shuffle=False,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )

    original_dataset = CriteoDataset(
        dataset='terabyte',
        max_ind_range=10 * 1000 * 1000,
        sub_sample_rate=1,
        randomize=True,
        split=args.split,
        raw_path=args.input_data_prefix,
        pro_data='dummy_string',
        memory_map=True,
    )

    original_loader = torch.utils.data.DataLoader(
        original_dataset,
        batch_size=2048,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_wrapper_criteo,
        pin_memory=False,
        drop_last=False,
    )

    # Explicit raise instead of `assert`: asserts are stripped under -O,
    # and a length mismatch must always abort the comparison.
    num_batches = len(dataset_binary)
    if num_batches != len(original_loader):
        raise ValueError(
            'FAILED: batch counts differ ({} binary vs {} original)'.format(
                num_batches, len(original_loader)))

    # zip() stops at the shorter iterable, and the lengths were checked
    # above, so no extra bounds guard is needed inside the loop.
    for i, (old_batch, new_batch) in tqdm(
            enumerate(zip(original_loader, binary_loader)),
            total=num_batches):
        for j in range(len(new_batch)):
            if not np.array_equal(old_batch[j], new_batch[j]):
                raise ValueError(
                    'FAILED: Datasets not equal '
                    '(batch {}, element {})'.format(i, j))
    print('PASSED')