in train.py
def get_dataloaders(train_cfg, vlm_cfg):
    # Create datasets
    image_processor = get_image_processor(vlm_cfg.vit_img_size)
    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)

    # Load and combine all training datasets
    combined_train_data = []
    for dataset_name in train_cfg.train_dataset_name:
        train_ds = load_dataset(train_cfg.train_dataset_path, dataset_name)
        combined_train_data.append(train_ds['train'])
    train_ds = concatenate_datasets(combined_train_data)
    test_ds = load_dataset(train_cfg.test_dataset_path)  # Note: not used further in this function

    train_ds = train_ds.shuffle(seed=0)  # Shuffle the training dataset so that train and val both draw from all concatenated datasets

    if is_dist():  # Shard the dataset manually under DDP: the training data is consumed as an iterable dataset, so a DistributedSampler cannot be used (a sketch of these DDP helpers follows the function)
        train_ds = train_ds.shard(num_shards=get_world_size(), index=get_rank())
    # Apply cutoff if specified
    if train_cfg.data_cutoff_idx is None:
        total_samples = len(train_ds)  # Use the entire dataset
    else:
        total_samples = min(len(train_ds), train_cfg.data_cutoff_idx)

    val_size = int(total_samples * train_cfg.val_ratio)
    train_size = total_samples - val_size

    train_dataset = VQADataset(train_ds.select(range(train_size)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)
    train_dataset = ConstantLengthDataset(
        train_dataset,
        infinite=False,
        max_sample_length=train_cfg.max_sample_length,
        seq_length=vlm_cfg.lm_max_length,
        num_of_sequences=train_cfg.batch_size * 64,
        queue_size=train_cfg.batch_size * 64 * 2,
        max_images_per_example=train_cfg.max_images_per_example,
        max_images_per_knapsack=train_cfg.max_images_per_knapsack,
    )
    val_dataset = VQADataset(train_ds.select(range(train_size, total_samples)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)
    # Create collators
    vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length)

    # Seeded generator + seed_worker (sketched after the function) keep data loading reproducible
    g = torch.Generator()
    g.manual_seed(0)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,  # no sampler/shuffle: the packed dataset is iterable and already sharded per rank above
        batch_size=train_cfg.batch_size,  # = per-device batch size in DDP
        collate_fn=vqa_collator,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    # The validation split is a regular map-style dataset, so a DistributedSampler can divide it across ranks
    val_sampler = DistributedSampler(
        val_dataset,
        rank=get_rank(),
        num_replicas=get_world_size(),
        shuffle=False,  # Usually False for validation
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=train_cfg.batch_size,
        sampler=val_sampler,
        collate_fn=vqa_collator,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=seed_worker,
        generator=g,
    )

    return train_loader, val_loader
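The distributed helpers used above (is_dist, get_rank, get_world_size) are defined elsewhere in the repo and not shown in this excerpt. A minimal sketch of what they are assumed to look like, as thin wrappers around torch.distributed:

import torch.distributed as dist

def is_dist():
    # True only when a distributed process group has been initialized (i.e. running under DDP)
    return dist.is_available() and dist.is_initialized()

def get_rank():
    # Rank of the current process; 0 when not running distributed
    return dist.get_rank() if is_dist() else 0

def get_world_size():
    # Total number of processes; 1 when not running distributed
    return dist.get_world_size() if is_dist() else 1

With these semantics, the shard() and DistributedSampler calls above degrade gracefully to a single-process run (one shard, rank 0, world size 1).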
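The train/val split is carved off the tail of the seeded shuffle rather than via train_test_split. A small self-contained illustration of that pattern (dataset contents and val_ratio here are illustrative only):

from datasets import Dataset

ds = Dataset.from_dict({"idx": list(range(10))})

ds = ds.shuffle(seed=0)        # same seeded shuffle as above
val_ratio = 0.2                # stand-in for train_cfg.val_ratio
total_samples = len(ds)
val_size = int(total_samples * val_ratio)
train_size = total_samples - val_size

train_split = ds.select(range(train_size))               # first train_size shuffled rows
val_split = ds.select(range(train_size, total_samples))  # remaining rows -> disjoint val split

print(len(train_split), len(val_split))  # 8 2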
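seed_worker is passed as worker_init_fn but is likewise defined outside this excerpt. A sketch following the standard PyTorch reproducibility recipe, which is assumed here:

import random
import numpy as np
import torch

def seed_worker(worker_id):
    # Each DataLoader worker derives its seed from the loader's base seed (set via the torch.Generator above),
    # so random/NumPy-based augmentations are reproducible across runs
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)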