in crlapi/sl/clmodels/firefly.py [0:0]
def spffn(self, net, loader, n_batches) -> None:
    """Prepare the ``sp.SpModule`` layers of ``net`` and run the SP-FFN passes.

    The method works in three stages:

    1. For every ``sp.SpModule`` in ``net``, call ``spffn_add_new`` and
       ``spffn_reset``. Then collect that layer's parameters into an RMSprop
       optimiser:
       - ``v``, only when ``layer.can_split``;
       - ``vni``, when ``enlarge_in``;
       - ``vno``, when ``enlarge_out``.
    2. Make one pass over ``loader`` that steps those parameters. Each batch
       uses ``self.spffn_loss_fn(net, x, y)`` and then calls every layer's
       ``spffn_penalty``.
    3. Make one more pass over ``loader``. No optimiser step is taken in this
       pass. After each backward pass on
       ``self.spffn_loss_fn(net, x, y, alpha=1.0)``, call ``spffn_update_w``
       on every layer index listed in ``net.layers_to_split``.

    Args:
        net: Module container that supports ``len``, indexing, iteration,
            ``.parameters()`` and ``.layers_to_split``. It is presumably an
            ``nn.Sequential``-like object; confirm against callers.
        loader: Iterable yielding ``(x, y)`` batches. It is iterated twice.
        n_batches: Ignored. The value is overwritten by the number of batches
            counted in stage 2. NOTE(review): presumably kept only for
            interface compatibility; confirm with callers.

    Side effects:
        Sets ``self.device``. Overwrites ``self.config.model.granularity``
        with ``1``. Mutates the layers of ``net`` through the ``spffn_*`` calls.
    """
    v_params = []
    for i, layer in enumerate(net):
        if isinstance(layer, sp.SpModule):
            # `enlarge_in` is False only for net[0]; `enlarge_out` is False
            # only for the last element of `net`. These are positions in
            # `net`, not positions among the SpModule layers.
            enlarge_in = (i > 0)
            enlarge_out = (i < len(net)-1)
            net[i].spffn_add_new(enlarge_in=enlarge_in, enlarge_out=enlarge_out)
            net[i].spffn_reset()
            # Only these per-layer parameters are updated by opt_v below.
            if layer.can_split:
                v_params += [net[i].v]
            if enlarge_in:
                v_params += [net[i].vni]
            if enlarge_out:
                v_params += [net[i].vno]
    # NOTE(review): torch optimisers raise ValueError on an empty parameter
    # list. That happens here if `net` contains no SpModule layer.
    opt_v = torch.optim.RMSprop(nn.ParameterList(v_params), lr=1e-3, momentum=0.1, alpha=0.9)
    # Batches are moved to the device of net's first parameter.
    self.device = next(iter(net.parameters())).device
    torch.cuda.empty_cache()
    # The incoming `n_batches` argument is discarded. The count is rebuilt
    # from this pass and reused in the spffn_update_w calls below.
    n_batches = 0
    for i, (x, y) in enumerate(loader):
        n_batches += 1
        x, y = x.to(self.device), y.to(self.device)
        loss = self.spffn_loss_fn(net, x, y)
        # zero_grad() clears only opt_v's parameters. Other parameters of
        # `net` that require grad keep accumulating .grad from these
        # backward passes.
        opt_v.zero_grad()
        loss.backward()
        # The penalty hooks run between backward() and step(). Presumably
        # they add their terms to the gradients used by this update; confirm
        # in sp.SpModule.
        for layer in net:
            if isinstance(layer, sp.SpModule):
                layer.spffn_penalty()
        opt_v.step()
    # NOTE(review): this permanently overwrites the configured granularity
    # on the shared config object.
    self.config.model.granularity = 1
    # With granularity == 1 this is [0.0, 0.5, 1.0]. Its odd-indexed entries
    # are [0.5], so the outer loop below runs exactly once.
    alphas = np.linspace(0, 1, self.config.model.granularity*2+1)
    for alpha in alphas[1::2]:
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            # NOTE(review): the loop variable `alpha` is unused. The loss is
            # always evaluated with alpha=1.0; confirm this is intended.
            loss = self.spffn_loss_fn(net, x, y, alpha=1.0)
            opt_v.zero_grad()
            loss.backward()
            # No opt_v.step() in this pass. Only gradients are produced
            # before the update_w calls. Split indices come from `net`; an
            # earlier version read them from self.layers_to_split.
            for i in net.layers_to_split:
                # granularity * n_batches is presumably a normaliser for
                # per-batch accumulation; confirm in spffn_update_w.
                net[i].spffn_update_w(self.config.model.granularity * n_batches, output = False)