in 5-4o_fine_tuning/eval.py [0:0]
def fetch_few_shot_train_examples(prompt: str, num_examples: int = 0, use_similarity: bool = False):
    """Fetch few-shot training examples from "patched-codes/synth-vuln-fixes".

    Args:
        prompt (str): The input prompt to select few-shot examples for.
        num_examples (int, optional): Number of examples to fetch. Defaults to 0.
        use_similarity (bool, optional): If True, use retrieve-then-rerank
            selection instead of uniform random sampling. Defaults to False.

    Returns:
        list: Flattened dialogue messages (system messages excluded) from the
        selected training examples.

    Selection modes:
        1. Random: uniformly samples ``num_examples`` items without replacement.
        2. Similarity: encodes the prompt and each example's user message with a
           lightweight bi-encoder, keeps the top_k by cosine similarity, then
           reranks those candidates with a cross-encoder and keeps the best
           ``num_examples``.
    """
    dataset = load_dataset("patched-codes/synth-vuln-fixes", split="train")
    # Nothing to select: return early so the similarity branch does not load
    # two transformer models and encode the whole corpus just to produce [].
    if num_examples <= 0 or len(dataset) == 0:
        return []
    # Clamp so np.random.choice(..., replace=False) cannot raise ValueError
    # when the caller asks for more examples than the dataset holds.
    num_examples = min(num_examples, len(dataset))
    if use_similarity:
        # Lightweight bi-encoder for the cheap initial retrieval pass.
        retrieval_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Cross-encoder for the more accurate (and more expensive) rerank pass.
        rerank_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        # One user message per dataset item; default to "" if an item has no
        # user turn (a bare next() here would otherwise raise).
        user_messages = [
            next((msg['content']
                  for msg in item['messages'] if msg['role'] == 'user'), "")
            for item in dataset
        ]
        # Encode prompt and corpus, then score by cosine similarity.
        prompt_embedding = retrieval_model.encode(
            prompt, convert_to_tensor=False)
        corpus_embeddings = retrieval_model.encode(
            user_messages, convert_to_tensor=False, show_progress_bar=True)
        similarities = cosine_similarity(
            [prompt_embedding], corpus_embeddings)[0]
        # Keep up to 100 candidates, most similar first.
        top_k = min(100, len(dataset))
        top_indices = similarities.argsort()[-top_k:][::-1]
        # Rerank the candidates with the cross-encoder and keep the best
        # num_examples by rerank score.
        rerank_pairs = [[prompt, user_messages[idx]] for idx in top_indices]
        rerank_scores = rerank_model.predict(rerank_pairs)
        top_indices = [top_indices[i]
                       for i in np.argsort(rerank_scores)[::-1][:num_examples]]
    else:
        top_indices = np.random.choice(
            len(dataset), num_examples, replace=False)
    few_shot_messages = []
    for index in top_indices:
        # numpy integer -> Python int for dataset indexing.
        messages = dataset[int(index)]["messages"]
        # Drop system messages; keep the user/assistant dialogue turns.
        few_shot_messages.extend(
            msg for msg in messages if msg['role'] != 'system')
    return few_shot_messages