def __init__()

in picotron/data.py


    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, subset_name=None, split="train", num_samples=None, pin_memory=True):
        self.micro_batch_size = micro_batch_size
        self.seq_length = seq_length
        self.grad_acc_steps = grad_acc_steps
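        # One optimizer step consumes micro_batch_size * grad_acc_steps samples per
        # data-parallel rank, hence the global batch size computed below.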
        self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
        
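        # Context parallelism splits each sequence across cp_world_size GPUs,
        # so every GPU only materializes a slice of the full sequence length.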
        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
        self.dataset = load_dataset(dataset_name, split=split, name=subset_name)

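        # Instantiate the tokenizer on rank 0 only, then broadcast it, so that
        # concurrent ranks do not race on downloading/writing the tokenizer files.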
        if pgm.process_group_manager.global_rank == 0:
            print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            objects = [self.tokenizer]
        else:
            objects = [None]

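        # Note: this `print` is assumed to be picotron's rank-aware wrapper (not the
        # builtin), which only emits output when is_print_rank is True.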
        print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
        dist.broadcast_object_list(objects, src=0, device=device)
        self.tokenizer = objects[0]
        
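        # Optionally cap the dataset size, e.g. for debugging or smoke tests.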
        if num_samples:
            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
        
        # Tokenize and chunk the dataset
        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
        
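        # Shard the tokenized dataset across data-parallel ranks so that each
        # rank iterates over a disjoint subset of the samples.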
        self.sampler = DistributedSampler(
            self.tokenized_dataset, 
            num_replicas=pgm.process_group_manager.dp_world_size, 
            rank=pgm.process_group_manager.dp_rank, 
            shuffle=False
        )
        
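        # Hand everything to the parent DataLoader; shuffle must remain False
        # because sample ordering is delegated to the DistributedSampler above.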
        super().__init__(
            self.tokenized_dataset,
            batch_size=micro_batch_size,
            collate_fn=self.collate_batch, 
            pin_memory=pin_memory, 
            num_workers=num_workers, 
            sampler=self.sampler, 
            shuffle=False
        )
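
A minimal usage sketch, assuming the enclosing class is named MicroBatchDataLoader and that picotron's process group manager (pgm) has already been initialized for the run; the dataset and tokenizer names are placeholders, not taken from this section:

    # Hypothetical usage; MicroBatchDataLoader, the pgm setup step, and the
    # dataset/tokenizer names below are assumptions, not shown in this section.
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    # ... initialize pgm.process_group_manager with the dp/cp/pp sizes here ...

    data_loader = MicroBatchDataLoader(
        micro_batch_size=4,
        seq_length=1024,
        dataset_name="roneneldan/TinyStories",       # placeholder dataset
        tokenizer_name="HuggingFaceTB/SmolLM-135M",  # placeholder tokenizer
        num_workers=2,
        num_proc=4,
        grad_acc_steps=8,
        device="cuda",
    )
    batch = next(iter(data_loader))  # one micro-batch, assembled by collate_batch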