def data_generator()

in src/nanotron/data/dataloader.py [0:0]


    def data_generator() -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]:
        # Random generator
        generator = torch.Generator(device="cuda")
        # Make sure all TP ranks stay in sync: the seed varies only with the DP and PP rank
        generator.manual_seed(
            seed * (1 + dist.get_rank(parallel_context.dp_pg)) * (1 + dist.get_rank(parallel_context.pp_pg))
        )

        if use_position_ids:
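            # Packed-document layout: the first sample holds three documents of lengths
            # 4, 6, and sequence_length - 10 (summing to sequence_length); every other
            # sample is a single full-length document. The loop below rebuilds position
            # ids that restart at 0 at each document boundary.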
            document_lengths = [[4, 6, sequence_length - 10]] + [[sequence_length]] * (micro_batch_size - 1)
            position_ids = torch.full(
                (micro_batch_size, sequence_length), fill_value=-1, dtype=torch.long, device="cuda"
            )
            for i in range(micro_batch_size):
                prev_idx = 0
                for doc_idx, doc_len in enumerate(document_lengths[i]):
                    position_ids[i, prev_idx : prev_idx + doc_len] = torch.arange(
                        0, doc_len, dtype=torch.long, device="cuda"
                    )
                    prev_idx += doc_len
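            # Infinite stream of dummy batches: each entry is a real tensor only on the
            # PP rank that owns it (input_pp_rank for inputs, output_pp_rank for labels),
            # sliced to this rank's local_slice; every other rank gets a TensorPointer.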
            while True:
                yield {
                    "input_ids": torch.randint(
                        0,
                        vocab_size,
                        (micro_batch_size, sequence_length),
                        dtype=torch.long,
                        device="cuda",
                        generator=generator,
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                    else TensorPointer(group_rank=input_pp_rank),
                    "position_ids": position_ids[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                    else TensorPointer(group_rank=input_pp_rank),
                    "label_ids": torch.randint(
                        0,
                        vocab_size,
                        (micro_batch_size, sequence_length),
                        dtype=torch.long,
                        device="cuda",
                        generator=generator,
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                    else TensorPointer(group_rank=output_pp_rank),
                    "label_mask": torch.ones(
                        micro_batch_size,
                        sequence_length,
                        dtype=torch.bool,
                        device="cuda",
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                    else TensorPointer(group_rank=output_pp_rank),
                }
        else:
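            # No packed position ids: emit a dense attention mask (input_mask of ones)
            # alongside the random tokens and labels instead.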
            while True:
                yield {
                    "input_ids": torch.randint(
                        0,
                        vocab_size,
                        (micro_batch_size, sequence_length),
                        dtype=torch.long,
                        device="cuda",
                        generator=generator,
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                    else TensorPointer(group_rank=input_pp_rank),
                    "input_mask": torch.ones(
                        micro_batch_size,
                        sequence_length,
                        dtype=torch.bool,
                        device="cuda",
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
                    else TensorPointer(group_rank=input_pp_rank),
                    "label_ids": torch.randint(
                        0,
                        vocab_size,
                        (micro_batch_size, sequence_length),
                        dtype=torch.long,
                        device="cuda",
                        generator=generator,
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                    else TensorPointer(group_rank=output_pp_rank),
                    "label_mask": torch.ones(
                        micro_batch_size,
                        sequence_length,
                        dtype=torch.bool,
                        device="cuda",
                    )[:, local_slice]
                    if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
                    else TensorPointer(group_rank=output_pp_rank),
                }
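
A minimal usage sketch, not taken from the source: it assumes the enclosing factory closed over micro_batch_size=2, sequence_length=64, use_position_ids=True, and local_slice = slice(None), and that the current process is both input_pp_rank and output_pp_rank, so every entry comes back as a real tensor rather than a TensorPointer. It also needs a CUDA device, since the generator allocates directly on "cuda".

    # Hypothetical example: drive the generator like any infinite iterator.
    batch_iter = data_generator()
    batch = next(batch_iter)
    assert batch["input_ids"].shape == (2, 64)
    assert batch["position_ids"].shape == (2, 64)
    assert batch["label_ids"].shape == (2, 64)
    assert batch["label_mask"].all()  # mask is all ones for dummy data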