in src/nanotron/data/nemo_dataset/__init__.py [0:0]
def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch, drop_last=True, add_extra_token=1):
    """Build the table of sample start positions over a stream of documents.

    The documents listed in `doc_idx` are treated as one long concatenated
    token stream. Each sample reads `seq_length + add_extra_token` tokens,
    and the next sample starts `add_extra_token` tokens before the point
    where the previous one stopped reading.

    Args:
        sizes: Indexable by document id; gives the token count of a document.
        doc_idx: Ordered document ids making up the stream.
        seq_length: Number of tokens a sample advances through the stream.
        num_epochs: Multiplied with `tokens_per_epoch` to get the total
            token budget.
        tokens_per_epoch: Token count of a single epoch.
        drop_last: When true the sample count is rounded down (a trailing
            partial sample is discarded); otherwise it is rounded up.
        add_extra_token: Extra tokens read per sample beyond `seq_length`;
            also subtracted once from the total token budget.

    Returns:
        An int32 array of shape [num_samples + 1, 2]. Column 0 is a position
        in `doc_idx`, column 1 is the token offset inside that document.
        Row 0 is always (0, 0); row i is where sample i stops and sample
        i + 1 begins.

    Raises:
        AssertionError: if `doc_idx` is exhausted while filling any sample
            other than the final one.
    """
    # Tokens available for whole samples once the extra token(s) are
    # reserved; floor or ceil division decides the fate of the remainder.
    usable_tokens = num_epochs * tokens_per_epoch - add_extra_token
    if drop_last:
        num_samples = usable_tokens // seq_length
    else:
        num_samples = -(-usable_tokens // seq_length)

    sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int32)

    # Cursor into `doc_idx` and token offset inside the current document.
    doc_cursor = 0
    offset = 0

    # The first sample begins at the very start of the first document.
    sample_idx[0, 0] = doc_cursor
    sample_idx[0, 1] = offset

    for sample_index in range(1, num_samples + 1):
        # Tokens still to be consumed for this sample.
        tokens_needed = seq_length + add_extra_token
        while tokens_needed != 0:
            # Consume whatever is left of the current document.
            available = sizes[doc_idx[doc_cursor]] - offset
            tokens_needed -= available

            if tokens_needed <= 0:
                # The sample ends inside this document: advance the offset
                # to the end of the sample, then step back by the extra
                # token(s) so the next sample re-reads them.
                offset += tokens_needed + available - add_extra_token
                break

            if doc_cursor + 1 == len(doc_idx):
                # Ran out of documents; only the last (partial) sample is
                # allowed to end this way.
                assert (
                    sample_index == num_samples
                ), f"sample_index={sample_index} and num_samples={num_samples} should be the same"
                offset = sizes[doc_idx[doc_cursor]] - add_extra_token
                break

            # Document fully consumed; continue with the next one.
            doc_cursor += 1
            offset = 0

        # Store the boundary reached by this sample.
        sample_idx[sample_index, 0] = doc_cursor
        sample_idx[sample_index, 1] = offset

    return sample_idx