# rbr_weight_fitter.py
from dataclasses import dataclass, field
from itertools import combinations
from typing import Optional, Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm


@dataclass
class RBR_ExampleWithMetadata:
    convo_prompt: str
    convo_completion: str
    response_type: str
    completion_label: str
    base_reward: Optional[float] = None
    features: Optional[Dict[str, float]] = field(default_factory=dict)

    def add_base_reward(self, base_reward: float):
        self.base_reward = base_reward

    def add_features(self, features: Dict[str, float]):
        self.features.update(features)


class RBRWeightFitter:
    def __init__(self,
                 feature_names: List[str],
                 examples: List[RBR_ExampleWithMetadata],
                 orderings: Dict[str, List[str]],
                 ignore_features: List[str] = [],
                 # optimization hyperparameters
                 train_data_frac: float = 0.95,
                 margin: float = 1,
                 lr: float = 1e-2,
                 wd: float = 0.0,
                 n_iters: int = 1000,
                 completion_label_margin: Dict[str, float] = {}):
        """
        Optimize a linear function that combines the RBR features with the reward model
        score by minimizing a hinge loss over ordered pairs of completions.

        Args:
            feature_names (list[str]): the list of feature names
            examples (list[RBR_ExampleWithMetadata]): the list of examples to optimize on
            orderings (dict[str, list[str]]): maps each completion label to the list of
                labels it is considered better than; an example whose label is a key is
                preferred over any example (with the same prompt) whose label appears in
                the corresponding value list.
            ignore_features (list[str]): the list of features to ignore
            train_data_frac (float): the fraction of data to use for training
            margin (float): the margin for the hinge loss
            lr (float): the learning rate for the optimizer
            wd (float): the weight decay for the optimizer
            n_iters (int): the number of iterations to run the optimizer
            completion_label_margin (dict[str, float]): specific margins for completion
                labels that override the default margin (e.g. you may want to set a
                higher margin for "ideal" completions).
        """
        self.feature_names = feature_names
        self.examples = examples
        self.orderings = orderings
        self.ignore_features = ignore_features
        self.train_data_frac = train_data_frac
        self.margin = margin
        self.lr = lr
        self.wd = wd
        self.n_iters = n_iters
        self.completion_label_margin = completion_label_margin
        self.weights = torch.zeros(len(feature_names), requires_grad=True)
        self.metrics = {"frac_clipped": [], "loss": [], "valid_loss": []}
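
    # Illustrative `orderings` value (hypothetical labels, for illustration only):
    #   {"ideal": ["acceptable", "bad"], "acceptable": ["bad"], "bad": []}
    # i.e. an "ideal" completion is preferred over "acceptable" and "bad" completions
    # for the same prompt, and "acceptable" is preferred over "bad".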

    def validate_examples(self):
        for ex in self.examples:
            assert ex.completion_label in self.orderings, "All examples must have a valid completion label."
            assert set(ex.features.keys()) == set(self.feature_names), "Must have values for all features."
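
    # Builds the ordered pairs used by the hinge loss. Each entry is
    # [better_example_idx, worse_example_idx, margin_multiplier]; only examples that
    # share the same prompt and whose labels are related by `orderings` are paired.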
    def get_orderings(self):
        def apply_ordering(i1, i2):
            ex1, ex2 = self.examples[i1], self.examples[i2]
            if ex1.convo_prompt != ex2.convo_prompt:
                return None
            if ex2.completion_label in self.orderings[ex1.completion_label]:
                margin = self.completion_label_margin.get(ex1.completion_label, 1)
                return [i1, i2, margin]
            elif ex1.completion_label in self.orderings[ex2.completion_label]:
                margin = self.completion_label_margin.get(ex2.completion_label, 1)
                return [i2, i1, margin]
            return None

        pairs = list(combinations(range(len(self.examples)), 2))
        pairs = [apply_ordering(i1, i2) for i1, i2 in pairs]
        pairs = [p for p in pairs if p is not None]
        np.random.default_rng().shuffle(pairs)
        pairs = torch.tensor(pairs, dtype=torch.int64)
        return pairs
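
    # Gradient hook: zeroes the gradient for every feature in `ignore_features`,
    # so those weights stay at their zero initialization throughout optimization.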
    def zero_some_feature_weights(self, grad):
        ignore_feature_idxs = [ft in self.ignore_features for ft in self.feature_names]
        grad[ignore_feature_idxs] = 0
        return grad
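
    # Registers the hook, builds an AdamW optimizer over the weight vector, and stacks
    # the examples' base rewards (shape [n_examples]) and feature values
    # (shape [n_examples, n_features]) into tensors.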
    def prepare_optimization(self):
        self.weights.register_hook(self.zero_some_feature_weights)
        opt = torch.optim.AdamW([self.weights], lr=self.lr, weight_decay=self.wd)
        base_rewards = torch.tensor([ex.base_reward for ex in self.examples])
        vals = torch.tensor([[ex.features[ft] for ft in self.feature_names] for ex in self.examples])
        assert vals.shape[0] == len(self.examples), f"vals shape[0] should be {len(self.examples)} but got {vals.shape[0]}"
        assert vals.shape[1] == len(self.feature_names), f"vals shape[1] should be {len(self.feature_names)} but got {vals.shape[1]}"
        return opt, base_rewards, vals
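
    # Hinge loss over ordered pairs. Each example's total score is
    # base_reward + features @ weights; for a pair where idxs1[k] is preferred over
    # idxs2[k], the loss is relu(margin * margins[k] + score[idxs2[k]] - score[idxs1[k]]).
    # frac_clipped is the fraction of pairs whose loss is already zero (margin satisfied).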
    def get_loss(self, idxs1, idxs2, margins, base_rewards, vals):
        total_scores = base_rewards + vals @ self.weights
        per_example_loss = F.relu(((self.margin * margins) + total_scores[idxs2]) - total_scores[idxs1])
        frac_clipped = 1 - (per_example_loss > 0).float().mean().item()
        loss = per_example_loss.mean()
        return loss, frac_clipped
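
    # Full fitting loop: shuffle the ordered pairs, split them into train/validation
    # sets by train_data_frac, then run n_iters full-batch AdamW steps on the training
    # pairs while logging the validation hinge loss.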
    def fit_weights(self):
        """
        Returns:
            weights (dict[str, float]): the optimized weights for the features
            metrics (dict[str, list]): some metrics from the optimization
        """
        self.validate_examples()
        pairs = self.get_orderings()
        opt, base_rewards, vals = self.prepare_optimization()
        split_idxs = lambda P: (p.squeeze(-1) for p in P.split(1, dim=-1))
        train_points = int(len(pairs) * self.train_data_frac)
        train_idxs1, train_idxs2, train_margins = split_idxs(pairs[:train_points])
        valid_idxs1, valid_idxs2, valid_margins = split_idxs(pairs[train_points:])
        for _ in tqdm(range(self.n_iters)):
            opt.zero_grad()
            loss, frac_clipped = self.get_loss(train_idxs1, train_idxs2, train_margins, base_rewards, vals)
            loss.backward()
            opt.step()
            valid_loss, _ = self.get_loss(valid_idxs1, valid_idxs2, valid_margins, base_rewards, vals)
            self.metrics["frac_clipped"].append(frac_clipped)
            self.metrics["loss"].append(loss.item())
            self.metrics["valid_loss"].append(valid_loss.item())
        weights_dict = {ft: w for ft, w in zip(self.feature_names, self.weights.detach().numpy())}
        return weights_dict, self.metrics
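

# A minimal usage sketch: the feature names, labels, rewards, and feature values
# below are hypothetical and only illustrate the inputs RBRWeightFitter expects.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feature_names = ["apology", "refusal"]
    # "ideal" completions are preferred over "bad" ones for the same prompt.
    orderings = {"ideal": ["bad"], "bad": []}

    examples = []
    for i in range(10):
        for label in ("ideal", "bad"):
            ex = RBR_ExampleWithMetadata(
                convo_prompt=f"prompt {i}",
                convo_completion=f"{label} completion for prompt {i}",
                response_type="refusal",
                completion_label=label,
            )
            ex.add_base_reward(float(rng.normal()))  # stand-in reward-model score
            ex.add_features({ft: float(rng.random()) for ft in feature_names})
            examples.append(ex)

    fitter = RBRWeightFitter(feature_names, examples, orderings, n_iters=200)
    weights, metrics = fitter.fit_weights()
    print(weights)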