def __init__()

in step5_data_parallel_naive/dataloader.py


    def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, seed, split="train"):
        
        self.micro_batch_size = micro_batch_size
        self.grad_acc_steps = grad_acc_steps
        self.seq_len = seq_len

        # Number of sequences consumed per optimizer step, summed over all data-parallel ranks
        self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)

        # Tokenize the raw text and chunk it into fixed-length rows of seq_len + 1 token ids
        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc)
        
        # Make sure the tokenized dataset covers the requested training-token budget
        total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1)
        assert total_tokens >= max_tokens, f"Not enough tokens. Have {total_tokens} tokens but need {max_tokens} tokens"

        # Shard the dataset across data-parallel ranks; shuffle=False keeps each rank's
        # shard deterministic and in dataset order
        self.sampler = DistributedSampler(
            self.tokenized_dataset, 
            num_replicas=pgm.process_group_manager.dp_world_size, 
            rank=pgm.process_group_manager.dp_rank, 
            seed=seed,
            shuffle=False
        )

        # Delegate batching to the parent DataLoader, one micro-batch at a time per rank
        super().__init__(
            self.tokenized_dataset,
            batch_size=micro_batch_size,
            collate_fn=self.collate_batch, 
            pin_memory=True, 
            num_workers=num_workers, 
            sampler=self.sampler,
            shuffle=False,  # must stay False: shuffle=True is mutually exclusive with a sampler
        )
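
For reference, here is a minimal construction sketch. It is not part of the repository: the class name MicroBatchDataLoader, the dataset and tokenizer ids, and all numeric values are assumptions. It presumes the usual file-level imports (AutoTokenizer from transformers, load_dataset from datasets, DistributedSampler from torch.utils.data) and that torch.distributed plus pgm.process_group_manager have already been initialized on every rank.

    # Illustrative sketch only; names and values below are assumptions, not taken from the repo.
    data_loader = MicroBatchDataLoader(
        seq_len=1024,
        micro_batch_size=4,
        grad_acc_steps=8,
        dataset_name="roneneldan/TinyStories",       # placeholder dataset id
        tokenizer_name="HuggingFaceTB/SmolLM-135M",  # placeholder tokenizer id
        max_tokens=1_000_000,
        num_workers=2,
        num_proc=4,
        seed=42,
        split="train",
    )

    # With dp_world_size == 2, one optimizer step consumes
    # micro_batch_size * grad_acc_steps * dp_world_size = 4 * 8 * 2 = 64 sequences,
    # delivered to each rank as 8 micro-batches of 4 sequences each.
    first_micro_batch = next(iter(data_loader))

Because the class subclasses DataLoader with collate_fn=self.collate_batch, iterating over it yields collated micro-batches directly; the training loop is then expected to accumulate gradients over grad_acc_steps micro-batches before each optimizer step on every data-parallel rank.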