def spffn()

in crlapi/sl/clmodels/firefly.py [0:0]


    def spffn(self, net, loader, n_batches):
        """Optimize the auxiliary `v` parameters of `net`'s SpModule layers, then update them.

        What the visible code does, in order:

        1. For every layer of `net` that is an `sp.SpModule`, call
           `spffn_add_new(...)` and `spffn_reset()` on it and collect its
           `v` (only if `layer.can_split`), `vni` (only if `enlarge_in`) and
           `vno` (only if `enlarge_out`) attributes.
        2. Run one pass over `loader`, minimizing `self.spffn_loss_fn(net, x, y)`
           with RMSprop over the collected parameters; every SpModule's
           `spffn_penalty()` is called between `backward()` and `step()`.
        3. Run further passes over `loader` (one per entry of `alphas[1::2]`)
           that only compute gradients and then call `spffn_update_w(...)` on
           each layer listed in `net.layers_to_split`.

        Args:
            net: Container of layers supporting iteration, `len()`, integer
                indexing, `parameters()` and a `layers_to_split` attribute.
            loader: Iterable yielding `(x, y)` batches; it is iterated more
                than once, so it must be re-iterable.
            n_batches: Ignored as an input -- the name is rebound to 0 and
                used to count the batches actually yielded by `loader`.

        Side effects visible here: sets `self.device`, overwrites
        `self.config.model.granularity` with 1, and calls
        `torch.cuda.empty_cache()`.

        NOTE(review): no `return` is visible in this excerpt; the method may
        continue beyond the last line shown -- confirm against the full file.
        """
        # Parameters handed to the optimizer below.
        v_params = []
        for i, layer in enumerate(net):
            if isinstance(layer, sp.SpModule):
                # Enlarge the input side for every layer except index 0 and the
                # output side for every layer except the last index.
                # NOTE(review): presumably this keeps the network's outer
                # input/output dimensions fixed -- confirm in SpModule.
                enlarge_in = (i > 0)
                enlarge_out = (i < len(net)-1)
                net[i].spffn_add_new(enlarge_in=enlarge_in, enlarge_out=enlarge_out)
                net[i].spffn_reset()
                # `net[i]` and `layer` refer to the same element here.
                if layer.can_split:
                    v_params += [net[i].v]
                if enlarge_in:
                    v_params += [net[i].vni]
                if enlarge_out:
                    v_params += [net[i].vno]

        # Only the collected v/vni/vno parameters are optimized; `alpha` here is
        # RMSprop's smoothing constant, unrelated to the `alpha` loop further down.
        opt_v = torch.optim.RMSprop(nn.ParameterList(v_params), lr=1e-3, momentum=0.1, alpha=0.9)

        # Use the device of the first parameter of `net` for all batches.
        self.device = next(iter(net.parameters())).device

        torch.cuda.empty_cache()

        # The caller-supplied `n_batches` is discarded; the loop below recounts
        # it as the number of batches `loader` yields in one pass.
        n_batches = 0
        for i, (x, y) in enumerate(loader):
            n_batches += 1

            x, y = x.to(self.device), y.to(self.device)
            loss = self.spffn_loss_fn(net, x, y)
            opt_v.zero_grad()
            loss.backward()
            # Called after backward() and before step().
            # NOTE(review): presumably adjusts the gradients in place with a
            # penalty term -- confirm in SpModule.spffn_penalty.
            for layer in net:
                if isinstance(layer, sp.SpModule):
                    layer.spffn_penalty()
            opt_v.step()

        # The configured granularity is overwritten (on the shared config
        # object) with 1, so `alphas` is linspace(0, 1, 3) and `alphas[1::2]`
        # has a single entry: the loop below makes one pass over `loader`.
        self.config.model.granularity = 1
        alphas = np.linspace(0, 1, self.config.model.granularity*2+1)
        for alpha in alphas[1::2]:
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                # The loop variable `alpha` is not used; the loss is always
                # evaluated with alpha=1.0.
                loss = self.spffn_loss_fn(net, x, y, alpha=1.0)
                # Gradients are computed but `opt_v.step()` is not called in
                # this loop. NOTE(review): presumably `spffn_update_w` reads
                # the fresh gradients -- confirm in SpModule.
                opt_v.zero_grad()
                loss.backward()
                # The commented-out line below iterated `self.layers_to_split`;
                # the active code reads the list from `net` instead.
                # for i in self.layers_to_split:
                for i in net.layers_to_split:
                    # First argument is granularity * batches-per-pass.
                    # NOTE(review): looks like a normalization count for
                    # averaging over all updates -- confirm.
                    net[i].spffn_update_w(self.config.model.granularity * n_batches, output = False)