# fetch_few_shot_train_examples()
#
# Extracted from 5-4o_fine_tuning/eval.py [0:0]
def fetch_few_shot_train_examples(prompt: str, num_examples: int = 0, use_similarity: bool = False):
    """
    Fetches few-shot training examples from the "patched-codes/synth-vuln-fixes" dataset based on the given prompt.

    Args:
        prompt (str): The input prompt for which few-shot examples are to be fetched.
        num_examples (int, optional): The number of few-shot examples to fetch. Defaults to 0.
        use_similarity (bool, optional): If True, uses a similarity-based approach to fetch examples. Defaults to False.

    Returns:
        list: A list of few-shot training examples in the form of dialogue messages
        (system messages are excluded). Empty when num_examples <= 0.

    The function operates in two modes:
    1. Random Selection: If use_similarity is False, it randomly selects num_examples from the dataset.
    2. Similarity-based Selection: If use_similarity is True, it uses a two-step process:
        a. Initial Retrieval: Uses a lightweight model to encode the prompt and user messages from the dataset,
           and retrieves the top_k most similar examples based on cosine similarity.
        b. Reranking: Uses a cross-encoder model to rerank the initially retrieved examples and selects the top num_examples.
    """
    # Guard clause: with the default num_examples=0 the original code still
    # downloaded the dataset and (in similarity mode) loaded two transformer
    # models and encoded the whole corpus, only to return []. It also makes a
    # negative num_examples behave consistently across both branches
    # (previously ValueError in the random path, silent mis-slicing in the
    # similarity path).
    if num_examples <= 0:
        return []

    dataset = load_dataset("patched-codes/synth-vuln-fixes", split="train")
    if use_similarity:
        # Lightweight bi-encoder for cheap initial retrieval.
        retrieval_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Cross-encoder for accurate (but expensive) reranking of the shortlist.
        rerank_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # Each dataset item is a chat transcript; retrieval compares against
        # the first user turn. NOTE(review): next() raises StopIteration if an
        # item has no user message — assumes every example contains one.
        user_messages = [
            next(msg['content']
                 for msg in item['messages'] if msg['role'] == 'user')
            for item in dataset
        ]

        # Encode query and corpus as plain numpy arrays for sklearn's
        # cosine_similarity.
        prompt_embedding = retrieval_model.encode(
            prompt, convert_to_tensor=False)
        corpus_embeddings = retrieval_model.encode(
            user_messages, convert_to_tensor=False, show_progress_bar=True)

        similarities = cosine_similarity(
            [prompt_embedding], corpus_embeddings)[0]

        # Shortlist at most 100 candidates for the expensive reranker.
        top_k = min(100, len(dataset))
        top_indices = similarities.argsort()[-top_k:][::-1]

        # Rerank (prompt, candidate) pairs with the cross-encoder and keep
        # the num_examples highest-scoring candidates.
        rerank_pairs = [[prompt, user_messages[idx]] for idx in top_indices]
        rerank_scores = rerank_model.predict(rerank_pairs)
        top_indices = [top_indices[i]
                       for i in np.argsort(rerank_scores)[::-1][:num_examples]]
    else:
        # Uniform sampling without replacement.
        top_indices = np.random.choice(
            len(dataset), num_examples, replace=False)

    # Flatten the selected transcripts into one message list, dropping
    # system messages (the caller supplies its own system prompt).
    few_shot_messages = []
    for index in top_indices:
        # numpy integer indices must be converted for the dataset accessor.
        messages = dataset[int(index)]["messages"]
        dialogue = [msg for msg in messages if msg['role'] != 'system']
        few_shot_messages.extend(dialogue)

    return few_shot_messages