in torchrec/datasets/criteo.py [0:0]
def __iter__(self) -> Iterator[Batch]:
    # Invariant: buffer never contains more than batch_size rows.
    buffer: Optional[List[np.ndarray]] = None

    def append_to_buffer(
        dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray
    ) -> None:
        nonlocal buffer
        if buffer is None:
            buffer = [dense, sparse, labels]
        else:
            for idx, arr in enumerate([dense, sparse, labels]):
                buffer[idx] = np.concatenate((buffer[idx], arr))

    # Maintain a buffer that can contain up to batch_size rows. Fill buffer as
    # much as possible on each iteration. Only return a new batch when batch_size
    # rows are filled.
    file_idx = 0
    row_idx = 0
    batch_idx = 0
    while batch_idx < self.num_batches:
        buffer_row_count = 0 if buffer is None else none_throws(buffer)[0].shape[0]
        if buffer_row_count == self.batch_size:
            yield self._np_arrays_to_batch(*none_throws(buffer))
            batch_idx += 1
            buffer = None
        else:
            rows_to_get = min(
                self.batch_size - buffer_row_count,
                self.num_rows_per_file[file_idx] - row_idx,
            )
            slice_ = slice(row_idx, row_idx + rows_to_get)
            append_to_buffer(
                self.dense_arrs[file_idx][slice_, :],
                self.sparse_arrs[file_idx][slice_, :],
                self.labels_arrs[file_idx][slice_, :],
            )
            row_idx += rows_to_get
            if row_idx >= self.num_rows_per_file[file_idx]:
                file_idx += 1
                row_idx = 0
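
To see the buffering pattern in isolation, here is a minimal, self-contained sketch of the same control flow. The `iter_batches` function and the toy 1-D shards are hypothetical stand-ins for the method above and its memory-mapped Criteo arrays; only the fill-then-yield logic is mirrored, and trailing rows that do not fill a final batch are dropped, as in the original.

from typing import Iterator, List, Optional

import numpy as np


def iter_batches(
    shards: List[np.ndarray], batch_size: int
) -> Iterator[np.ndarray]:
    """Yield fixed-size batches of rows drawn across uneven shards."""
    buffer: Optional[np.ndarray] = None
    file_idx = 0
    row_idx = 0
    total_rows = sum(shard.shape[0] for shard in shards)
    num_batches = total_rows // batch_size  # trailing partial batch is dropped
    batch_idx = 0
    while batch_idx < num_batches:
        buffered = 0 if buffer is None else buffer.shape[0]
        if buffered == batch_size:
            # Buffer is exactly full: flush it as one batch and start over.
            yield buffer
            batch_idx += 1
            buffer = None
        else:
            # Take whatever the current shard can still contribute, capped at
            # the space remaining in the buffer.
            rows_to_get = min(
                batch_size - buffered, shards[file_idx].shape[0] - row_idx
            )
            chunk = shards[file_idx][row_idx : row_idx + rows_to_get]
            buffer = chunk if buffer is None else np.concatenate((buffer, chunk))
            row_idx += rows_to_get
            if row_idx >= shards[file_idx].shape[0]:
                # Current shard is exhausted; advance to the next one.
                file_idx += 1
                row_idx = 0


# Shards of 5, 3, and 4 rows yield three batches of 4 rows each,
# with the second batch spanning all three shards.
shards = [np.arange(5), np.arange(5, 8), np.arange(8, 12)]
for batch in iter_batches(shards, batch_size=4):
    print(batch)

The key property of this design is that at most one partially filled batch is held in the buffer at any time, which is what allows batches to span file boundaries without ever concatenating entire files.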