# src/nanotron/data/sft_processing.py
import torch


def process_sft(examples, tokenizer, trainer_sequence_length):
"""
Process examples for supervised fine-tuning by:
1. Tokenizing prompts and completions separately
2. Combining them into full examples with EOS token
3. Creating position_ids for each token in the sequence
4. Creating label_mask that only enables loss on completion tokens
Args:
examples: Dictionary with 'prompt' and 'completion' fields
tokenizer: HuggingFace tokenizer
trainer_sequence_length: Maximum sequence length for training
Returns:
Dictionary with processed tokens including:
- input_ids: Combined tokenized sequence
- position_ids: Sequential position IDs for each token
- label_mask: Boolean mask with True only for completion tokens
- attention_mask: Attention mask for padding
- label_ids: Same as input_ids, used for loss calculation
"""
    # Tokenize prompts on their own (no padding or truncation) so we know how many
    # tokens of each combined sequence belong to the prompt
    prompt_tokens = tokenizer(examples["prompt"], padding=False, truncation=False, return_tensors=None)
    # Combine prompt and completion with EOS token
    texts = [
        f"{prompt}{completion}{tokenizer.eos_token}"
        for prompt, completion in zip(examples["prompt"], examples["completion"])
    ]

    # Use trainer_sequence_length + 1 to match the collator's expectation
    max_length = trainer_sequence_length + 1

    # Tokenize the combined texts
    tokenized = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    # Get attention mask and convert to bool
    attention_mask = tokenized["attention_mask"].bool()
    batch_size, seq_length = attention_mask.shape

    # Create sequential position_ids initialized to -1 (for padding)
    position_ids = torch.full((batch_size, seq_length), fill_value=-1, dtype=torch.long)

    # Create label_mask (initialized to False)
    label_mask = torch.zeros((batch_size, seq_length), dtype=torch.bool)
    # For each sequence in the batch
    for i in range(batch_size):
        # Get the actual prompt length, but ensure we don't exceed the sequence length
        prompt_length = min(len(prompt_tokens["input_ids"][i]), seq_length)

        # Set position ids to sequential values for all non-padding tokens
        # (assumes right padding, so valid tokens sit at the start of the row)
        valid_length = attention_mask[i].sum().item()
        position_ids[i, :valid_length] = torch.arange(valid_length)

        # Set label_mask to True only for completion tokens.
        # If the prompt consumes the entire sequence, no tokens are used for the loss
        if prompt_length < seq_length:
            label_mask[i, prompt_length:valid_length] = True
    # Create label_ids (same as input_ids)
    tokenized["label_ids"] = tokenized["input_ids"].clone()

    # Add the created tensors
    tokenized["position_ids"] = position_ids
    tokenized["label_mask"] = label_mask

    # Keep the boolean attention_mask for the model's use
    tokenized["attention_mask"] = attention_mask
    # Log examples where the prompt consumes all tokens
    too_long_prompts = sum(1 for i in range(batch_size) if not label_mask[i].any())
    if too_long_prompts > 0:
        print(
            f"Warning: {too_long_prompts}/{batch_size} examples have prompts that fill the whole sequence, "
            f"leaving no completion tokens for the loss"
        )

    return tokenized
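

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of nanotron's training pipeline):
    # run process_sft on a tiny in-memory batch. The "gpt2" tokenizer checkpoint and the
    # sequence length of 16 are assumptions made purely for this demo.
    from transformers import AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    demo_tokenizer.pad_token = demo_tokenizer.eos_token  # gpt2 ships without a pad token

    demo_examples = {
        "prompt": ["Question: What is 2 + 2?\nAnswer: ", "Translate 'bonjour' to English: "],
        "completion": ["4", "hello"],
    }
    demo_batch = process_sft(demo_examples, demo_tokenizer, trainer_sequence_length=16)

    print(demo_batch["input_ids"].shape)  # (2, padded_length) with padded_length <= 17
    print(demo_batch["label_mask"][0])    # True only on the completion + EOS positions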