in src/nanotron/data/dataloader.py
def data_generator() -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]:
    # Random generator
    generator = torch.Generator(device="cuda")
    # Keep TP ranks in sync: the seed depends only on the DP and PP rank, so all TP ranks draw identical data
    generator.manual_seed(
        seed * (1 + dist.get_rank(parallel_context.dp_pg)) * (1 + dist.get_rank(parallel_context.pp_pg))
    )
    if use_position_ids:
        # Pack several documents into the first sample (lengths 4, 6 and sequence_length - 10);
        # every other sample in the micro batch is a single full-length document.
        document_lengths = [[4, 6, sequence_length - 10]] + [[sequence_length]] * (micro_batch_size - 1)
        position_ids = torch.full(
            (micro_batch_size, sequence_length), fill_value=-1, dtype=torch.long, device="cuda"
        )
        # Position ids restart from 0 at every document boundary
        for i in range(micro_batch_size):
            prev_idx = 0
            for doc_idx, doc_len in enumerate(document_lengths[i]):
                position_ids[i, prev_idx : prev_idx + doc_len] = torch.arange(
                    0, doc_len, dtype=torch.long, device="cuda"
                )
                prev_idx += doc_len
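        # Illustrative sketch of the packing above (assuming, say, sequence_length = 16 and
        # micro_batch_size = 2 in the enclosing scope): document_lengths is [[4, 6, 6], [16]], so
        #   position_ids[0] == [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]
        #   position_ids[1] == [0, 1, 2, ..., 15]
        # i.e. positions restart at 0 wherever a new packed document begins.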
        while True:
            yield {
                "input_ids": torch.randint(
                    0,
                    vocab_size,
                    (micro_batch_size, sequence_length),
                    dtype=torch.long,
                    device="cuda",
                    generator=generator,
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                else TensorPointer(group_rank=input_pp_rank),
                "position_ids": position_ids[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                else TensorPointer(group_rank=input_pp_rank),
                "label_ids": torch.randint(
                    0,
                    vocab_size,
                    (micro_batch_size, sequence_length),
                    dtype=torch.long,
                    device="cuda",
                    generator=generator,
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                else TensorPointer(group_rank=output_pp_rank),
                "label_mask": torch.ones(
                    micro_batch_size,
                    sequence_length,
                    dtype=torch.bool,
                    device="cuda",
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                else TensorPointer(group_rank=output_pp_rank),
            }
    else:
        while True:
            yield {
                "input_ids": torch.randint(
                    0,
                    vocab_size,
                    (micro_batch_size, sequence_length),
                    dtype=torch.long,
                    device="cuda",
                    generator=generator,
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                else TensorPointer(group_rank=input_pp_rank),
                "input_mask": torch.ones(
                    micro_batch_size,
                    sequence_length,
                    dtype=torch.bool,
                    device="cuda",
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                else TensorPointer(group_rank=input_pp_rank),
                "label_ids": torch.randint(
                    0,
                    vocab_size,
                    (micro_batch_size, sequence_length),
                    dtype=torch.long,
                    device="cuda",
                    generator=generator,
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                else TensorPointer(group_rank=output_pp_rank),
                "label_mask": torch.ones(
                    micro_batch_size,
                    sequence_length,
                    dtype=torch.bool,
                    device="cuda",
                )[:, local_slice]
                if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                else TensorPointer(group_rank=output_pp_rank),
            }
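# Minimal consumption sketch (illustrative only; the surrounding dataloader in this file is what
# actually wraps data_generator, and names such as input_pp_rank and local_slice come from that
# enclosing scope):
#
#   batch = next(data_generator())
#   # On the pipeline rank equal to input_pp_rank, batch["input_ids"] is a CUDA LongTensor of shape
#   # (micro_batch_size, sequence_length) sliced along the sequence dimension by local_slice; on
#   # every other pipeline rank it is just a TensorPointer(group_rank=input_pp_rank) telling the
#   # pipeline engine which rank holds the real tensor. "label_ids" and "label_mask" behave the
#   # same way relative to output_pp_rank.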