def __init__()

in step3_dataloader/dataloader.py [0:0]


    def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, split="train"):
        
        self.micro_batch_size = micro_batch_size
        self.grad_acc_steps = grad_acc_steps
        self.seq_len = seq_len

        self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
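        # Worked example (illustrative values, not repository defaults):
        #   micro_batch_size=4, grad_acc_steps=8, dp_world_size=2
        #   -> global_batch_size = 4 * 8 * 2 = 64 samples per optimizer step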

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)

        # Tokenize and chunk the dataset
        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc)
        
        # Each tokenized row holds seq_len + 1 token IDs, so the number of
        # available tokens is num_rows * (seq_len + 1); it must cover the
        # requested training budget `max_tokens`.
        total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1)
        assert total_tokens >= max_tokens, f"Not enough tokens. Have {total_tokens} tokens but need {max_tokens} tokens"
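        # Worked example of the check above (illustrative numbers): 20_000 rows
        # of (2048 + 1) token IDs give 40_980_000 available tokens, so a run
        # with max_tokens = 40_000_000 passes the assertion.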

        # Hand the tokenized dataset to the underlying torch DataLoader;
        # micro-batches are assembled by `collate_batch`.
        super().__init__(
            self.tokenized_dataset,
            batch_size=micro_batch_size,
            collate_fn=self.collate_batch, 
            pin_memory=True, 
            num_workers=num_workers, 
            shuffle=False,
        )
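
The constructor relies on two helpers that this excerpt does not show: `tokenize_dataset` and `collate_batch`. A minimal sketch of what they could look like, assuming each tokenized row stores seq_len + 1 token IDs that the collate function splits into shifted input/target tensors (the method bodies and the `target_ids` key below are assumptions for illustration, not the repository's exact code):

    import torch

    def tokenize_dataset(self, dataset, text_column, seq_len, num_proc):
        """Tokenize `text_column` and regroup the token stream into rows of seq_len + 1 IDs."""
        def tokenize_group(examples):
            tokenized = self.tokenizer(examples[text_column])["input_ids"]
            # Concatenate all documents into one stream, then cut it into
            # fixed-size rows, dropping the incomplete remainder.
            stream = [tok for ids in tokenized for tok in ids]
            n_rows = len(stream) // (seq_len + 1)
            return {"input_ids": [stream[i * (seq_len + 1):(i + 1) * (seq_len + 1)] for i in range(n_rows)]}

        return dataset.map(
            tokenize_group,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=num_proc,
        )

    def collate_batch(self, batch):
        """Stack rows and shift by one token to build next-token-prediction targets."""
        ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
        return {
            "input_ids": ids[:, :-1].contiguous(),   # (micro_batch_size, seq_len)
            "target_ids": ids[:, 1:].contiguous(),   # (micro_batch_size, seq_len)
        }

With helpers of this shape, iterating over the loader yields one micro-batch per step; accumulating grad_acc_steps micro-batches on each of the dp_world_size data-parallel ranks reproduces the global_batch_size computed above.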