in picotron/data.py [0:0]
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, subset_name=None, split="train", num_samples=None, pin_memory=True):
    self.micro_batch_size = micro_batch_size
    self.seq_length = seq_length
    self.grad_acc_steps = grad_acc_steps
    # Global batch size spans all data-parallel ranks and gradient-accumulation steps.
    self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
    self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
    # With context parallelism, each GPU only holds a slice of the full sequence.
    self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size

    self.dataset = load_dataset(dataset_name, split=split, name=subset_name)

    # Only rank 0 instantiates the tokenizer; it is then broadcast to every other rank.
    if pgm.process_group_manager.global_rank == 0:
        print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        objects = [self.tokenizer]
    else:
        objects = [None]
    # Rank-aware print helper (accepts is_print_rank), not the built-in print.
    print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
    dist.broadcast_object_list(objects, src=0, device=device)
    self.tokenizer = objects[0]

    # Optionally restrict the dataset to the first num_samples examples.
    if num_samples:
        self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))

    # Tokenize and chunk the dataset into fixed-length sequences.
    self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)

    # Shard the tokenized dataset across data-parallel ranks.
    self.sampler = DistributedSampler(
        self.tokenized_dataset,
        num_replicas=pgm.process_group_manager.dp_world_size,
        rank=pgm.process_group_manager.dp_rank,
        shuffle=False
    )

    # Initialize the parent DataLoader with the distributed sampler and custom collate_fn.
    super().__init__(
        self.tokenized_dataset,
        batch_size=micro_batch_size,
        collate_fn=self.collate_batch,
        pin_memory=pin_memory,
        num_workers=num_workers,
        sampler=self.sampler,
        shuffle=False
    )