def prepare_optimization()

in rbr_weight_fitter.py


    def prepare_optimization(self):
        # Tensor hook: zero_some_feature_weights receives the gradient on every
        # backward pass, presumably zeroing entries for selected features.
        self.weights.register_hook(self.zero_some_feature_weights)
        opt = torch.optim.AdamW([self.weights], lr=self.lr, weight_decay=self.wd)
        # Stack per-example data: base_rewards is (n_examples,) and vals is
        # (n_examples, n_features), with columns ordered by self.feature_names.
        base_rewards = torch.tensor([ex.base_reward for ex in self.examples])
        vals = torch.tensor([[ex.features[ft] for ft in self.feature_names] for ex in self.examples])
        assert vals.shape[0] == len(self.examples), f"vals shape[0] should be {len(self.examples)} but got {vals.shape[0]}"
        assert vals.shape[1] == len(self.feature_names), f"vals shape[1] should be {len(self.feature_names)} but got {vals.shape[1]}"
        return opt, base_rewards, vals
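
For context, a minimal sketch of how the returned triple might drive a fitting loop. This is an assumption, not code from rbr_weight_fitter.py: fitter (an instance of the weight-fitter class), the zero targets, the MSE objective, and the step count are all illustrative placeholders.

    import torch

    # Hypothetical usage of prepare_optimization(); targets, the objective,
    # and the number of steps are illustrative assumptions.
    opt, base_rewards, vals = fitter.prepare_optimization()
    targets = torch.zeros_like(base_rewards)  # placeholder target rewards

    for _ in range(100):
        opt.zero_grad()
        # Model each example's total reward as base reward plus weighted features.
        preds = base_rewards + vals @ fitter.weights
        loss = torch.nn.functional.mse_loss(preds, targets)
        loss.backward()  # the registered hook fires here on the weight gradient
        opt.step()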