# src/nanotron/data/nemo_dataset/__init__.py
def _get_text(self, idx: int) -> np.ndarray:
    # Get the shuffled index.
    idx = self.shuffle_idx[idx]
    # Start and end documents and offsets.
    doc_index_f, offset_f = self.sample_idx[idx]
    doc_index_l, offset_l = self.sample_idx[idx + 1]
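    # Each entry of sample_idx is a (document index, token offset) pair, so a sample may start
    # partway through one document and end partway through a later one.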
    # If we are within the same document, just extract the chunk.
    if doc_index_f == doc_index_l:
        sample = self.indexed_dataset.get(
            self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + self.add_extra_token
        )
    else:
        # Otherwise, get the rest of the initial document.
        sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
        # Loop over all in-between documents and add each one in full.
        for i in range(doc_index_f + 1, doc_index_l):
            sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
        # And finally add the relevant portion of the last document.
        sample_list.append(
            self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + self.add_extra_token)
        )
        sample = np.concatenate(sample_list)
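    # In either case `sample` should now hold seq_length + add_extra_token token ids; shorter
    # samples are padded below (add_extra_token presumably leaves room for the shifted labels).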
    if len(sample) != (self.seq_length + self.add_extra_token):
        log_rank(
            f" > WARNING: Got sample of length: {len(sample)} for sequence length={self.seq_length+self.add_extra_token}, padding the sample to match sequence length",
            logger=logger,
            level=logging.WARNING,
            rank=0,
        )
        sample = np.array(sample, dtype=np.int64)
        sample = np.pad(
            sample, (0, self.seq_length + self.add_extra_token - len(sample)), mode="constant", constant_values=-1
        )
    if self.fim_rate == 0:
        return sample.astype(np.int64)
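
    # Everything below applies fill-in-the-middle (FIM) augmentation: each document (or file-level
    # fragment) may be rearranged into prefix/suffix/middle order, marked with the FIM sentinel tokens.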
    # Code from: https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L109
    # TODO(Hailey): can merge the code below this line with code above this line.
    # TODO(Hailey), cont: above already iterates through loop, so just add the permuting in there?
    sample = np.array(sample, dtype=np.int64)
    sample_len = sample.shape[0]
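    # Remember the original length: FIM can change it, and the sample is truncated or padded back
    # to this length at the end of the function.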
    # Do FIM here, if enabled.
    # TODO: Do we handle the following point from the FIM paper?
    #   "To transform data in the character space for context-level FIM, the tokenized documents
    #   have to be decoded back into strings before FIM augmentation. Depending on the vocabulary,
    #   some care has to be given to ensure decoding does not introduce any spurious characters
    #   into training. For example, utf-8 characters are encoded as multiple tokens with a BPE
    #   vocabulary; they can result in fragments from chunking and fail to decode. To prevent
    #   unforeseen errors midway through training, we encourage checking for these fragments at
    #   the beginning or end of a context and removing them."
    segment_breaks = np.argwhere(sample == self.eod_tok_id)  # split sample by document
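    # Note: np.argwhere on a 1-D array returns an (N, 1) array, so a shape of (0, 1) means "no matches".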

    def fim_permute_sequence(sequence, rate):
        return permute(
            sequence,
            self.np_rng,
            rate,
            self.fim_spm_rate,
            self.tokenizer,
            truncate_or_pad=False,
            suffix_tok_id=self.suffix_tok_id,
            prefix_tok_id=self.prefix_tok_id,
            middle_tok_id=self.middle_tok_id,
            pad_tok_id=self.pad_tok_id,
            no_fim_prefix=self.no_fim_prefix,
        )
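
    # Rough sketch (assumption; `permute` is defined elsewhere in this module): with probability
    # `rate` it splits `sequence` into prefix/middle/suffix pieces and rearranges them with the
    # sentinel tokens above, using SPM instead of PSM ordering with probability self.fim_spm_rate;
    # otherwise it returns the sequence unchanged.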

    def fim_split_and_permute_sequence(sequence):
        """
        If self.fim_split_sample is not None, split the sequence on that token.
        Then apply FIM to the fragments, or to the whole sequence if self.fim_split_sample is None.
        """
        if self.fim_split_sample is None:
            return fim_permute_sequence(sequence, self.fim_rate)
        # fim_split_sample is set: split the sample on this token and permute each fragment separately.
        # Typically, if each sample is a repository, we split again at the file level:
        # each fragment is then a file, and the files are permuted independently.
        fragment_breaks = np.argwhere(sequence == self.fim_split_sample)
        if fragment_breaks.shape == (0, 1):
            # No split token in this sample.
            return fim_permute_sequence(sequence, self.fim_rate)
        if not self.np_rng.binomial(1, self.fim_rate):
            # Don't do FIM preprocessing for this sample.
            return sequence
        # Do FIM on each fragment.
        curr_start_position = 0
        new_samples = []
        for loc in np.nditer(fragment_breaks):
            if loc - curr_start_position > 0:
                permuted = fim_permute_sequence(sequence[curr_start_position:loc], self.fragment_fim_rate)
                new_samples += [permuted, [self.fim_split_sample]]
            curr_start_position = loc + 1  # Jump over the split token.
        # Permute the fragment after the last split token.
        permuted = fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate)
        new_samples.append(permuted)
        return np.concatenate(new_samples)
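
    # Apply FIM per document: split the sample at EOD boundaries, permute each document separately,
    # and re-insert the EOD token between the permuted pieces.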
    if segment_breaks.shape != (0, 1):  # then there is an EOD token in this example
        curr_start_position = 0
        new_samples = []
        for loc in np.nditer(segment_breaks):
            # Only permute non-empty segments.
            if loc - curr_start_position > 0:
                # Permute {prefix, suffix, middle} or {suffix, prefix, middle}.
                permuted = fim_split_and_permute_sequence(sample[curr_start_position:loc])
                new_samples += [permuted, [self.eod_tok_id]]
            curr_start_position = loc + 1  # Jump over the EOD token.
        # Permute the segment after the last EOD.
        permuted = fim_split_and_permute_sequence(sample[curr_start_position:])
        new_samples.append(permuted)
        sample = np.concatenate(new_samples)
    else:
        sample = fim_split_and_permute_sequence(sample)
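    # `permute` is called with truncate_or_pad=False, so the sentinel tokens it inserts can change
    # the sample length; restore the original sample_len below.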
    # Truncate or pad sequence to max-length.
    diff = sample.shape[0] - sample_len
    if diff > 0:  # too long
        sample = sample[:sample_len]
    elif diff < 0:  # too short
        sample = np.concatenate([sample, np.full((-1 * diff), self.pad_tok_id)])
    assert sample.shape[0] == sample_len
    # end FIM-specific code
    return sample.astype(np.int64)