cached_classes.py (42 lines of code) (raw):
# Implements the RewardModelScores and FeaturesValues base classes using data cached in JSONL files.
from base_classes import RewardModelScores, FeaturesValues
from typing import List, Dict
from utils import read_jsonl, read_yaml, to_string
# Prompt configuration that generated the cached proposition data. The cached
# features class still loads it, but nothing in this example depends on its
# contents because everything is already cached.
CONFIG_FILE_PATH = 'config/proposition_prompts.yaml'

# Path templates for the cached JSONL files; fill in {split} (and, for the
# reward-model scores, {rm_size}) with str.format.
CACHED_FEATURES = 'data/weight_fitting_data/prop_probs/{split}.jsonl'
CACHED_REWARD_MODEL_SCORES = 'data/weight_fitting_data/rewards/{rm_size}/{split}.jsonl'

# Response category labels. NOTE(review): not referenced elsewhere in this
# file -- presumably consumed by other modules; confirm before removing.
RESPONSE_TYPES = [
    'Hard Refuse',
    'Comply',
    'Safe Refuse 1',
    'Safe Refuse 2',
]
class CachedRewardModelScores(RewardModelScores):
    """Reward-model scores served from pre-computed JSONL cache files.

    Both the train and test cache files for the requested reward-model size
    are loaded when the object is built. Lookups consult the train cache
    first and fall back to the test cache.
    """

    def __init__(self, rm_size="large"):
        """Load the cached scores for both splits.

        Args:
            rm_size: Substituted into the ``{rm_size}`` slot of
                CACHED_REWARD_MODEL_SCORES to pick the cache directory.
        """
        train_records = read_jsonl(CACHED_REWARD_MODEL_SCORES.format(rm_size=rm_size, split='train'))
        test_records = read_jsonl(CACHED_REWARD_MODEL_SCORES.format(rm_size=rm_size, split='test'))
        # Each record's 'prompt_completion' list becomes a tuple so it can be
        # used as a dictionary key; the value is the record's 'rm_score'.
        self.train_data = {
            tuple(record['prompt_completion']): record['rm_score']
            for record in train_records
        }
        self.test_data = {
            tuple(record['prompt_completion']): record['rm_score']
            for record in test_records
        }

    def get_reward_model_scores(self, prompt: List[Dict[str, str]], completions: List[List[Dict[str, str]]]) -> List[float]:
        """Return the cached reward-model score for each completion of ``prompt``.

        Args:
            prompt: The prompt messages shared by all completions.
            completions: Candidate completions, each a list of messages.

        Returns:
            One cached score per completion, in the same order as ``completions``.

        Raises:
            ValueError: If a (prompt, completion) pair is in neither cache.
        """
        scores = []
        for position, completion in enumerate(completions):
            key = (to_string(prompt), to_string(completion))
            # The train cache takes precedence over the test cache.
            for split_cache in (self.train_data, self.test_data):
                if key in split_cache:
                    scores.append(split_cache[key])
                    break
            else:
                raise ValueError(f"No cached reward model score found for prompt: {prompt}, completion: {position}, {completion}")
        return scores
class CachedFeaturesValues(FeaturesValues):
    """Feature values served from pre-computed JSONL cache files.

    Records from the train and test cache files are keyed by their
    ``prompt_completion`` field (converted to a tuple). Lookups try the train
    cache first and fall back to the test cache.
    """

    def __init__(self):
        """Load the cached features for both splits and the prompt config."""
        train_data_raw = read_jsonl(CACHED_FEATURES.format(split='train'))
        test_data_raw = read_jsonl(CACHED_FEATURES.format(split='test'))
        # Items are keyed by (prompt, completion) pairs; the stored list is
        # converted to a tuple so it is hashable.
        self.train_data = {tuple(d['prompt_completion']): d['features'] for d in train_data_raw}
        self.test_data = {tuple(d['prompt_completion']): d['features'] for d in test_data_raw}
        # Not used in this example since everything is already cached, but it
        # is the config that generated the cached data.
        self.features_config = read_yaml(CONFIG_FILE_PATH)

    def get_features(self, prompt: List[Dict[str, str]], completions: List[List[Dict[str, str]]]) -> List[object]:
        """Return the cached features for each completion of ``prompt``.

        Args:
            prompt: The prompt messages shared by all completions.
            completions: Candidate completions, each a list of messages.

        Returns:
            One cached ``features`` entry per completion, in the same order as
            ``completions``. Each entry is whatever the cache record stored
            under 'features' (NOTE(review): schema not visible here -- confirm
            against the cache writer); it is not a single float.

        Raises:
            ValueError: If a (prompt, completion) pair is in neither cache.
        """
        # The prompt's string form is the same for every completion, so it is
        # computed once instead of once per loop iteration.
        prompt_key = to_string(prompt)
        completion_features = []
        for idx_c, completion in enumerate(completions):
            lookup_hash = (prompt_key, to_string(completion))
            # The train cache takes precedence over the test cache.
            if lookup_hash in self.train_data:
                completion_features.append(self.train_data[lookup_hash])
            elif lookup_hash in self.test_data:
                completion_features.append(self.test_data[lookup_hash])
            else:
                raise ValueError(f"No cached features found for prompt: {prompt}, completion: {idx_c}, {completion}")
        return completion_features