def fit_weights()

in rbr_weight_fitter.py [0:0]


    def fit_weights(self):
        """
        Fit the feature weights by gradient descent on pairwise orderings.

        Validates the examples, builds the ordering pairs via ``get_orderings``,
        splits them into a training prefix (``self.train_data_frac`` of the
        pairs) and a validation suffix, then runs ``self.n_iters`` optimizer
        steps on the training loss. After every step the training loss, the
        validation loss of the updated weights, and the ``frac_clipped`` value
        returned by ``get_loss`` are appended to ``self.metrics``.

        Returns:
            weights (dict[str, float]): the optimized weights for the features,
                keyed by ``self.feature_names``
            metrics (dict[str, list]): some metrics from the optimization; this is
                ``self.metrics`` itself, so its lists keep growing across calls
        """
        import torch  # local import: only needed for torch.no_grad below

        self.validate_examples()
        pairs = self.get_orderings()
        opt, base_rewards, vals = self.prepare_optimization()

        def split_columns(block):
            # Split the last dimension into one 1-D tensor per column.
            # NOTE(review): assumes `pairs` is a 2-D tensor whose columns are
            # (index 1, index 2, margin) — TODO confirm against get_orderings.
            return (col.squeeze(-1) for col in block.split(1, dim=-1))

        # No shuffling happens here: the first fraction of pairs is used for
        # training and the remainder for validation (whether get_orderings
        # shuffles is not visible from this method).
        train_points = int(len(pairs) * self.train_data_frac)
        train_idxs1, train_idxs2, train_margins = split_columns(pairs[:train_points])
        valid_idxs1, valid_idxs2, valid_margins = split_columns(pairs[train_points:])

        for _ in tqdm(range(self.n_iters)):
            opt.zero_grad()
            loss, frac_clipped = self.get_loss(train_idxs1, train_idxs2, train_margins, base_rewards, vals)
            loss.backward()
            opt.step()

            # Validation only needs the scalar value; skip building an autograd
            # graph that would never be backpropagated.
            with torch.no_grad():
                valid_loss, _ = self.get_loss(valid_idxs1, valid_idxs2, valid_margins, base_rewards, vals)
            self.metrics["frac_clipped"].append(frac_clipped)
            self.metrics["loss"].append(loss.item())
            self.metrics["valid_loss"].append(valid_loss.item())

        # .cpu() so export also works when the weights live on an accelerator;
        # float() so values match the documented dict[str, float] (plain
        # Python floats rather than numpy scalars).
        weights = self.weights.detach().cpu().numpy()
        weights_dict = {ft: float(w) for ft, w in zip(self.feature_names, weights)}
        return weights_dict, self.metrics