in data_loader_terabyte.py [0:0]
def _test_bin():
    """Check that the preprocessed binary dataset matches the original loader.

    Command-line arguments:
        --output_directory: directory where ``<split>_data.bin`` and
            ``day_fea_count.npz`` are read from after ``_preprocess`` runs.
        --input_data_prefix: raw-data path handed to ``CriteoDataset``.
        --split: one of ``train``, ``test``, ``val``.

    Runs ``_preprocess(args)``, then iterates a ``CriteoBinDataset`` over the
    resulting binary file in lockstep with a ``DataLoader`` over the original
    ``CriteoDataset`` and compares each batch element with
    ``np.array_equal``. Prints ``PASSED`` when every batch matches.

    Raises:
        ValueError: if the two sources yield a different number of batches,
            or if any compared batch element differs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_directory', required=True)
    parser.add_argument('--input_data_prefix', required=True)
    parser.add_argument('--split', choices=['train', 'test', 'val'],
                        required=True)
    args = parser.parse_args()
    _preprocess(args)

    # Both sides must use the same batch size for a batch-by-batch comparison.
    batch_size = 2048

    binary_data_file = os.path.join(args.output_directory,
                                    '{}_data.bin'.format(args.split))
    counts_file = os.path.join(args.output_directory, 'day_fea_count.npz')
    dataset_binary = CriteoBinDataset(data_file=binary_data_file,
                                      counts_file=counts_file,
                                      batch_size=batch_size,)

    from dlrm_data_pytorch import CriteoDataset
    from dlrm_data_pytorch import collate_wrapper_criteo_offset as collate_wrapper_criteo

    # dataset_binary is built with batch_size and presumably returns a whole
    # batch per index, so batch_size=None disables the DataLoader's own
    # automatic batching and passes each item through unchanged.
    binary_loader = torch.utils.data.DataLoader(
        dataset_binary,
        batch_size=None,
        shuffle=False,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )

    original_dataset = CriteoDataset(
        dataset='terabyte',
        max_ind_range=10 * 1000 * 1000,
        sub_sample_rate=1,
        randomize=True,
        split=args.split,
        raw_path=args.input_data_prefix,
        pro_data='dummy_string',
        memory_map=True
    )
    original_loader = torch.utils.data.DataLoader(
        original_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_wrapper_criteo,
        pin_memory=False,
        drop_last=False,
    )

    # Explicit check instead of `assert`: asserts are stripped under -O, and
    # zip() below would otherwise silently stop at the shorter source.
    num_batches = len(dataset_binary)
    num_original_batches = len(original_loader)
    if num_batches != num_original_batches:
        raise ValueError(
            'FAILED: Datasets have different lengths ({} vs {})'.format(
                num_batches, num_original_batches))

    for old_batch, new_batch in tqdm(zip(original_loader, binary_loader),
                                     total=num_batches):
        # NOTE(review): only the elements present in new_batch are compared;
        # confirm both loaders yield tuples of the same length.
        for j, new_item in enumerate(new_batch):
            if not np.array_equal(old_batch[j], new_item):
                raise ValueError('FAILED: Datasets not equal')
    print('PASSED')