in rbr_weight_fitter.py [0:0]
def fit_weights(self):
    """
    Optimize ``self.weights`` and record per-iteration metrics.

    The examples are validated, the ordering pairs are fetched, and the
    pairs are split positionally: the first ``train_data_frac`` share is
    used for training and the remainder for validation. The optimizer
    returned by ``prepare_optimization`` is then stepped ``n_iters`` times.
    After each step the training loss, the validation loss and the
    ``frac_clipped`` value reported by ``get_loss`` are appended to
    ``self.metrics``.

    The last dimension of the pairs tensor must have size 3; its columns
    are unpacked as (idxs1, idxs2, margins) and passed to ``get_loss``.

    Returns:
        weights (dict[str, float]): the optimized weights for the features
        metrics (dict[str, list]): some metrics from the optimization
    """
    # Function-local import: only needed for the no_grad context below.
    import torch

    self.validate_examples()
    pairs = self.get_orderings()
    opt, base_rewards, vals = self.prepare_optimization()

    def split_columns(rows):
        # One tensor per column of the last dimension, with that
        # size-1 dimension squeezed away.
        return (col.squeeze(-1) for col in rows.split(1, dim=-1))

    # The split is by position; no shuffling happens in this method.
    n_train = int(len(pairs) * self.train_data_frac)
    train_idxs1, train_idxs2, train_margins = split_columns(pairs[:n_train])
    valid_idxs1, valid_idxs2, valid_margins = split_columns(pairs[n_train:])

    for _ in tqdm(range(self.n_iters)):
        opt.zero_grad()
        loss, frac_clipped = self.get_loss(
            train_idxs1, train_idxs2, train_margins, base_rewards, vals
        )
        loss.backward()
        opt.step()

        # The validation loss is only logged, never back-propagated, so
        # skip autograd bookkeeping. It is measured after the step, while
        # ``loss`` above is the value from before the step.
        with torch.no_grad():
            valid_loss, _ = self.get_loss(
                valid_idxs1, valid_idxs2, valid_margins, base_rewards, vals
            )

        self.metrics["frac_clipped"].append(frac_clipped)
        self.metrics["loss"].append(loss.item())
        self.metrics["valid_loss"].append(valid_loss.item())

    # ``tolist`` yields plain Python floats, as the docstring promises, and
    # ``cpu()`` keeps the conversion valid when the weights are on another
    # device.
    weight_values = self.weights.detach().cpu().tolist()
    weights_dict = dict(zip(self.feature_names, weight_values))
    return weights_dict, self.metrics