in src/datatrove/utils/dataset.py [0:0]
def _get_pos_from_index_file(self, item):
    """
    Reads document ends from the .index file and returns positions for the requested window.

    Positions are the index of each token within its document, counted from the start of the
    window: a document that started in an earlier window restarts at 0 on this window's first token.
    For example, if the documents in the window end at local token positions [3, 5, 8], for
    seq_len+1=10, the positions will be [0, 1, 2, 0, 1, 0, 1, 2, 0, 1].

    The .index file is read as a flat array of np.uint64 exclusive document end offsets, which are
    expected to be non-decreasing. A repeated value (a zero-length document) is treated as a single
    boundary. Reads are buffered and optimized for windows requested in non-decreasing order;
    requesting an earlier window than ``self._last_item`` rewinds the .index file and re-reads it
    from the start.

    Args:
        item (int): The index of the window to retrieve positions for.
    Returns:
        torch.Tensor: A 1-D integer tensor of length seq_len + 1 holding the position of each token
        in the window.
    """
    window_size = self.seq_len + 1
    # Global token offsets of the first and the last (inclusive) token of this window.
    window_start = item * window_size
    window_end = window_start + self.seq_len
    # .index entries are exclusive document ends, i.e. the offset at which the next document starts,
    # so an entry e with window_start < e <= window_end is a boundary at local index e - window_start.
    # (e == window_start is a boundary at local index 0, which is always included below.)

    # (Re)start from the beginning of the .index file on first access, when a window before the last
    # one is requested, or when there is no open handle to continue reading from.
    # NOTE(review): self._last_item is not updated in this method; presumably the caller records it.
    if self._f_pos is None or self._last_item is None or item < self._last_item:
        if self._f_pos is None:
            self._f_pos = self.fs.open(self.file_path + ".index", "rb")
        self._idx_buffer = deque()
        # we could binary search but we are assuming sequential reads (which is what we optimized for
        # by pre-shuffling the data), so we always read from the start
        self._f_pos.seek(0)
    buffered_ends = self._idx_buffer

    # Slide the buffer forward: drop ends before the window (also between reads, so a rewind followed
    # by a far jump does not keep the whole file prefix in memory) and read until at least one end
    # lies beyond the window, or EOF.
    while True:
        while buffered_ends and buffered_ends[0] < window_start:
            buffered_ends.popleft()
        if buffered_ends and buffered_ends[-1] > window_end:
            break
        chunk = self._f_pos.read(1024 * 8)  # 1024 uint64 values, 8 bytes each
        if not chunk:
            break  # End of file
        # Store plain Python ints: exact arithmetic independent of NumPy's uint64/int promotion
        # rules, and much cheaper comparisons than NumPy scalars.
        buffered_ends.extend(np.frombuffer(chunk, np.uint64).tolist())

    # Local document start indices within the window: strictly increasing, starting at 0.
    starts = [0]
    for end in buffered_ends:
        if end > window_end:
            break  # ends are non-decreasing, so nothing later can fall inside the window
        local = end - window_start
        # Skips end == window_start (local 0 is already present) and repeated ends from zero-length
        # documents: duplicate indices would make the scatter below undefined.
        if local > starts[-1]:
            starts.append(local)
    doc_ends = torch.tensor(starts, dtype=torch.int)

    # get actual positions
    # example: doc_ends = [0, 3, 5, 8]. seq_len+1=10
    # pos = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    # prev_ends = [-1, 0, 3, 5]
    # offsets = [0, -2, -1, -2]
    # pos = [0, 1, 1, -2, 1, -1, 1, 1, -2, 1]
    # cumsum = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1]
    pos = torch.ones(window_size, dtype=torch.int)
    prev_ends = torch.cat([torch.tensor([-1], dtype=torch.int), doc_ends[:-1]])
    # At each document start, cancel the running count so the cumulative sum resets to 0 there.
    pos[doc_ends] = prev_ends - doc_ends + 1
    return torch.cumsum(pos, dim=0)