in torchrec/datasets/criteo.py [0:0]
def _load_data_for_rank(self) -> None:
    """Load this rank's share of rows from every dense/sparse/label file.

    Steps:
      1. Take the first dimension of each dense file's shape (as reported by
         ``BinaryCriteoUtils.get_shape_from_npy``) as that file's row count.
      2. Ask ``BinaryCriteoUtils.get_file_idx_to_row_range`` which rows of
         which files belong to ``self.rank`` out of ``self.world_size``.
      3. For the dense, sparse and label path lists (in that order), load the
         selected row range of each file and append it to
         ``self.dense_arrs`` / ``self.sparse_arrs`` / ``self.labels_arrs``.
      4. If ``self.hashes`` is set, reduce every sparse array in place modulo
         the per-feature hash sizes.

    The same file index is used to pick a path from all three path lists, so
    they are assumed to be index-aligned (same file order) -- confirm against
    the constructor.
    """
    # Only the dense files are inspected for row counts; the sparse and label
    # files are presumably expected to have matching row counts per file.
    row_counts = [
        BinaryCriteoUtils.get_shape_from_npy(
            dense_path, path_manager_key=self.path_manager_key
        )[0]
        for dense_path in self.dense_paths
    ]
    ranges_by_file = BinaryCriteoUtils.get_file_idx_to_row_range(
        lengths=row_counts,
        rank=self.rank,
        world_size=self.world_size,
    )

    # Create the output lists up front; they are filled in place below.
    self.dense_arrs = []
    self.sparse_arrs = []
    self.labels_arrs = []
    destinations_and_sources = (
        (self.dense_arrs, self.dense_paths),
        (self.sparse_arrs, self.sparse_paths),
        (self.labels_arrs, self.labels_paths),
    )
    for destination, file_paths in destinations_and_sources:
        for file_idx, (first_row, last_row) in ranges_by_file.items():
            # The right bound is inclusive, hence the +1 when turning the
            # range into a row count.
            row_slice = BinaryCriteoUtils.load_npy_range(
                file_paths[file_idx],
                first_row,
                last_row - first_row + 1,
                path_manager_key=self.path_manager_key,
            )
            destination.append(row_slice)

    if self.hashes is None:
        return
    # Shape (1, CAT_FEATURE_COUNT) so the modulus broadcasts across rows;
    # presumably each sparse array is (rows, CAT_FEATURE_COUNT) -- confirm.
    hash_moduli = np.array(self.hashes).reshape((1, CAT_FEATURE_COUNT))
    for sparse_block in self.sparse_arrs:
        sparse_block %= hash_moduli