def prepare_sft_dataset()

in src/nanotron/data/sft_processing.py


def prepare_sft_dataset(raw_dataset, tokenizer, trainer_sequence_length, debug_max_samples=None, num_proc=1):
    """
    Prepare a dataset for supervised fine-tuning by processing the examples
    and filtering invalid samples.

    Args:
        raw_dataset: Dataset containing 'prompt' and 'completion' fields
        tokenizer: HuggingFace tokenizer
        trainer_sequence_length: Maximum sequence length for training
        debug_max_samples: If set, limit the dataset to this many samples
        num_proc: Number of processes for parallelization

    Returns:
        Processed dataset ready for training
    """
    # If in debug mode, limit the dataset size before processing
    if debug_max_samples is not None:
        print(f"DEBUG MODE: Limiting dataset to {debug_max_samples} samples")
        raw_dataset = raw_dataset.select(range(min(debug_max_samples, len(raw_dataset))))

    # Create a wrapper function that handles empty examples correctly
    def process_fn(examples):
        # Check if there are any examples to process
        if len(examples["prompt"]) == 0:
            return {k: [] for k in ["input_ids", "position_ids", "label_ids", "label_mask", "attention_mask"]}

        # Process the examples
        result = process_sft(examples, tokenizer, trainer_sequence_length)
        return result

    # Apply the map function to process the dataset
    train_dataset = raw_dataset.map(
        process_fn, batched=True, remove_columns=raw_dataset.column_names, num_proc=num_proc
    )

    # Filter out examples where:
    # 1. All position_ids are -1 (completely padding)
    # 2. All label_mask values are False (no tokens contribute to loss)
    def is_valid_sample(example):
        # Check if there's at least one non-padding token
        has_content = any(pos_id >= 0 for pos_id in example["position_ids"])

        # Check if there's at least one token that contributes to loss
        has_label = any(example["label_mask"])

        # Sample is valid if it has content AND at least one token for loss
        return has_content and has_label

    # Apply the filter
    original_size = len(train_dataset)
    train_dataset = train_dataset.filter(is_valid_sample)
    filtered_size = len(train_dataset)

    # Log how many samples were filtered out
    if original_size > filtered_size:
        print(
            f"Filtered out {original_size - filtered_size} samples ({(original_size - filtered_size) / original_size:.2%}) with no valid tokens for loss calculation"
        )

    return train_dataset
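

For reference, a minimal usage sketch. It assumes a Hugging Face `datasets.Dataset` with the expected `prompt` and `completion` columns and a `transformers` tokenizer; the tokenizer checkpoint and the toy data below are only illustrative, and the exact output columns depend on `process_sft` in the same module.

from datasets import Dataset
from transformers import AutoTokenizer

# Any HF tokenizer works here; "gpt2" is just an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Toy dataset with the 'prompt' and 'completion' columns this function expects.
raw_dataset = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?", "Name a prime number."],
        "completion": ["4", "7"],
    }
)

train_dataset = prepare_sft_dataset(
    raw_dataset,
    tokenizer,
    trainer_sequence_length=2048,
    debug_max_samples=None,
    num_proc=1,
)

# Expected to contain the fields produced by process_sft, e.g.
# input_ids, position_ids, label_ids, label_mask, attention_mask.
print(train_dataset.column_names)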