in src/nanotron/data/sft_processing.py [0:0]
def prepare_sft_dataset(raw_dataset, tokenizer, trainer_sequence_length, debug_max_samples=None, num_proc=1):
    """
    Prepare a dataset for supervised fine-tuning by processing the examples
    and filtering out invalid samples.

    Args:
        raw_dataset: Dataset containing 'prompt' and 'completion' fields
        tokenizer: HuggingFace tokenizer
        trainer_sequence_length: Maximum sequence length for training
        debug_max_samples: If set, limit the dataset to this many samples
        num_proc: Number of processes for parallelization

    Returns:
        Processed dataset ready for training
    """
    # In debug mode, limit the dataset size before processing
    if debug_max_samples is not None:
        print(f"DEBUG MODE: Limiting dataset to {debug_max_samples} samples")
        raw_dataset = raw_dataset.select(range(min(debug_max_samples, len(raw_dataset))))

    # Wrapper around process_sft that handles empty batches correctly
    def process_fn(examples):
        # Nothing to process: return empty columns with the expected schema
        if len(examples["prompt"]) == 0:
            return {k: [] for k in ["input_ids", "position_ids", "label_ids", "label_mask", "attention_mask"]}
        # Tokenize and pack the prompt/completion pairs
        return process_sft(examples, tokenizer, trainer_sequence_length)

    # Process the dataset, replacing the raw columns with the tokenized ones
    train_dataset = raw_dataset.map(
        process_fn, batched=True, remove_columns=raw_dataset.column_names, num_proc=num_proc
    )
    # Filter out examples where:
    # 1. All position_ids are -1 (the sample is entirely padding)
    # 2. All label_mask values are False (no tokens contribute to the loss)
    def is_valid_sample(example):
        # At least one non-padding token
        has_content = any(pos_id >= 0 for pos_id in example["position_ids"])
        # At least one token that contributes to the loss
        has_label = any(example["label_mask"])
        # A sample is valid if it has content AND at least one token for the loss
        return has_content and has_label

    # Apply the filter and record how many samples were dropped
    original_size = len(train_dataset)
    train_dataset = train_dataset.filter(is_valid_sample)
    filtered_size = len(train_dataset)
    # Log how many samples were filtered out
    if original_size > filtered_size:
        print(
            f"Filtered out {original_size - filtered_size} samples "
            f"({(original_size - filtered_size) / original_size:.2%}) with no valid tokens for loss calculation"
        )

    return train_dataset
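
Below is a minimal usage sketch, not part of the module itself: the toy prompt/completion pairs, the "gpt2" tokenizer, and the sequence length of 2048 are illustrative assumptions, and process_sft is expected to be defined elsewhere in this file.

# Hypothetical usage sketch: assumes an in-memory Hugging Face dataset with
# "prompt"/"completion" columns and any standard HF tokenizer.
from datasets import Dataset
from transformers import AutoTokenizer

raw = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?", "Name a prime number."],
        "completion": ["2 + 2 = 4.", "7 is a prime number."],
    }
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer choice

train_dataset = prepare_sft_dataset(
    raw_dataset=raw,
    tokenizer=tokenizer,
    trainer_sequence_length=2048,  # illustrative value
    debug_max_samples=None,
    num_proc=1,
)

# After processing, each sample exposes the columns produced by process_fn:
# input_ids, position_ids, label_ids, label_mask, attention_mask
print(train_dataset.column_names)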