def process_sft()

in src/nanotron/data/sft_processing.py


import torch


def process_sft(examples, tokenizer, trainer_sequence_length):
    """
    Process examples for supervised fine-tuning by:
    1. Tokenizing the prompts on their own to measure each prompt's token length
    2. Concatenating each prompt with its completion and an EOS token, then tokenizing the result
    3. Creating sequential position_ids for every non-padding token
    4. Creating a label_mask that enables the loss only on completion tokens

    Args:
        examples: Dictionary with 'prompt' and 'completion' fields
        tokenizer: HuggingFace tokenizer
        trainer_sequence_length: Maximum sequence length for training

    Returns:
        Dictionary with processed tokens including:
        - input_ids: Combined tokenized sequence
        - position_ids: Sequential position IDs for each token
        - label_mask: Boolean mask with True only for completion tokens
        - attention_mask: Boolean attention mask (False on padding tokens)
        - label_ids: Same as input_ids, used for loss calculation
    """
    # Tokenize the prompts separately (no padding/truncation) to get each prompt's token length
    prompt_tokens = tokenizer(examples["prompt"], padding=False, truncation=False, return_tensors=None)

    # Combine prompt and completion with EOS token
    texts = [
        f"{prompt}{completion}{tokenizer.eos_token}"
        for prompt, completion in zip(examples["prompt"], examples["completion"])
    ]

    # Use trainer_sequence_length + 1 to match collator's expectation
    max_length = trainer_sequence_length + 1
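    # (the +1 presumably leaves room for the one-token shift between inputs and labels downstream)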

    # Tokenize combined text
    tokenized = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

    # Get attention mask and convert to bool
    attention_mask = tokenized["attention_mask"].bool()
    batch_size, seq_length = attention_mask.shape

    # Create sequential position_ids initialized to -1 (for padding)
    position_ids = torch.full((batch_size, seq_length), fill_value=-1, dtype=torch.long)

    # Create label_mask (initialize to False)
    label_mask = torch.zeros((batch_size, seq_length), dtype=torch.bool)

    # For each sequence in the batch
    for i in range(batch_size):
        # Get the actual prompt length, but ensure we don't exceed sequence length
        prompt_length = min(len(prompt_tokens["input_ids"][i]), seq_length)

        # Set position ids for all tokens (prompt and completion) to sequential values
        # But only where attention_mask is True (non-padding tokens)
        valid_length = attention_mask[i].sum().item()
        position_ids[i, :valid_length] = torch.arange(valid_length)

        # Set label_mask to True only for completion tokens
        # If prompt consumes the entire sequence, no tokens are used for loss
        if prompt_length < seq_length:
            # Set completion tokens label mask to True (rest remains False)
            label_mask[i, prompt_length:valid_length] = True
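            # e.g. prompt_length=3 and valid_length=7 mark tokens 3..6 (completion plus EOS) for the loss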

    # Create label_ids (same as input_ids)
    tokenized["label_ids"] = tokenized["input_ids"].clone()

    # Add the created tensors
    tokenized["position_ids"] = position_ids
    tokenized["label_mask"] = label_mask

    # Keep attention_mask for model's use
    tokenized["attention_mask"] = attention_mask

    # Log examples where prompt consumes all tokens
    too_long_prompts = sum(1 for i in range(batch_size) if not label_mask[i].any())
    if too_long_prompts > 0:
        print(
            f"Warning: {too_long_prompts}/{batch_size} examples have prompts that fill the entire sequence; "
            "no completion tokens remain for the loss"
        )

    return tokenized
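

A minimal usage sketch, assuming a HuggingFace datasets dataset with "prompt" and "completion" columns; the checkpoint name, data file, and sequence length below are placeholder assumptions rather than values taken from this repository, and a pad token is set because process_sft pads the combined texts to a common length:

from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token  # process_sft pads, so a pad token must exist

raw_dataset = load_dataset("json", data_files="sft_data.json", split="train")  # hypothetical file

processed = raw_dataset.map(
    partial(process_sft, tokenizer=tokenizer, trainer_sequence_length=2048),
    batched=True,
    remove_columns=raw_dataset.column_names,
)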