def split()

in crlapi/sl/clmodels/firefly.py [0:0]


    def split(self, net, loader, grow_ratio, n_batches=-1, split_method='fireflyn'):
        """Grow ``net`` by splitting / adding neurons and report what was added.

        The network is put in eval mode while candidates are scored and grown,
        then switched back to train mode and the auxiliaries are cleared via
        ``self.clear(net)``.

        Args:
            net: Network container. It must be indexable by layer id and expose
                ``layers_to_split``, ``next_layers``, ``eval()`` and ``train()``.
            loader: Forwarded unchanged to the scoring routine.
            grow_ratio: Forwarded to ``get_num_elites``; in ``'random'`` mode it
                also scales the total neuron count into the growth budget.
            n_batches: Forwarded unchanged to the scoring routine.
            split_method: One of ``'random'``, ``'exact'``, ``'fast'``,
                ``'firefly'``, ``'fireflyn'``. Only ``'random'`` and
                ``'fireflyn'`` are currently enabled.

        Returns:
            dict: layer id -> ``(n_split, n_brand_new)`` for ``'fireflyn'``,
            or layer id -> number of added neurons for the other methods.

        Raises:
            NotImplementedError: if ``split_method`` is unknown, or names a
                scorer that is currently disabled.
        """
        # Grouping is hard-wired off for now (single group).
        self.num_group = 1

        known_methods = ('random', 'exact', 'fast', 'firefly', 'fireflyn')
        if split_method not in known_methods:
            raise NotImplementedError(
                'unknown split_method: %r' % (split_method,))

        # Scoring routines that are actually wired up. The commented entries
        # are disabled; requesting them used to fail with a bare KeyError
        # *after* the network had already been switched to eval mode.
        score_fn = {
            # 'exact': self.spe,
            # 'fast': self.spf,
            # 'firefly': self.spff,
            'fireflyn': self.spffn,
        }
        if split_method != 'random' and split_method not in score_fn:
            raise NotImplementedError(
                'split_method %r is currently disabled' % (split_method,))

        if self.verbose:
            print('[INFO] start splitting ...')

        start_time = time.time()
        net.eval()

        if self.num_group == 1:
            self.get_num_elites(net, grow_ratio)
        else:
            self.get_num_elites_group(net, grow_ratio, self.num_group)

        if split_method != 'random':
            score_fn[split_method](net, loader, n_batches)

        n_neurons_added = {}
        # Default so the verbose report below never sees an unbound name.
        threshold = 0.

        if split_method == 'random':
            n_layers = len(net.layers_to_split)
            n_total_neurons = 0
            for layer_id in net.layers_to_split:
                n_total_neurons += net[layer_id].get_n_neurons()
            n_grow = int(n_total_neurons * grow_ratio)

            # Draw distinct cut points in [0, n_grow) and turn the sorted cuts
            # into one budget per layer. np.random.choice raises ValueError
            # when n_grow < n_layers because replace=False.
            cuts = np.sort(np.random.choice(n_grow, n_layers, replace=False))

            # NOTE(review): the arithmetic below is preserved exactly from the
            # original. With a single layer no budget is produced at all, and
            # every gap after the first is incremented by one -- confirm that
            # this is the intended distribution.
            n_news = []
            for k in range(len(cuts) - 1):
                gap = cuts[k + 1] - cuts[k]
                if k == 0:
                    n_news.append(cuts[k])
                    n_news.append(gap)
                else:
                    n_news.append(gap + 1)

            for i, budget in zip(reversed(net.layers_to_split), n_news):
                layer = net[i]
                if isinstance(layer, sp.SpModule) and layer.can_split:
                    n_new, idx = layer.random_split(budget)
                    n_neurons_added[i] = n_new
                    if n_new > 0:  # this layer was actually split
                        for j in net.next_layers[i]:
                            net[j].passive_split(idx)

        elif split_method == 'fireflyn':
            if self.num_group == 1:
                # One global threshold shared by every layer.
                threshold = self.sp_threshold(net)
            for i in reversed(net.layers_to_split):
                layer = net[i]
                if not (isinstance(layer, sp.SpModule) and layer.can_split):
                    continue
                if self.num_group != 1:
                    # Grouped mode: threshold is recomputed per layer group.
                    threshold = self.sp_threshold_group(net, self.total_group[i])
                n_new, split_idx, new_idx = layer.spffn_active_grow(threshold)
                sp_new = split_idx.shape[0] if split_idx is not None else 0
                n_neurons_added[i] = (sp_new, n_new - sp_new)
                # Downstream layers must grow their inputs to match.
                for j in net.next_layers[i]:
                    if self.verbose:
                        print('passive', net[j].module.weight.shape)
                    net[j].spffn_passive_grow(split_idx, new_idx)

        else:
            # Only reachable once one of the disabled scorers is re-enabled in
            # ``score_fn`` above.
            # NOTE(review): the original called ``self.sp_threshold()`` with no
            # argument here; ``net`` is passed to match the 'fireflyn' path --
            # confirm against the signature of ``sp_threshold``.
            threshold = self.sp_threshold(net)
            for i in reversed(net.layers_to_split):
                layer = net[i]
                if isinstance(layer, sp.SpModule) and layer.can_split:
                    n_new, idx = layer.active_split(threshold)
                    n_neurons_added[i] = n_new
                    if n_new > 0:  # this layer was actually split
                        for j in net.next_layers[i]:
                            net[j].passive_split(idx)

        net.train()
        self.clear(net)  # cleanup auxiliaries

        end_time = time.time()
        if self.verbose:
            print('[INFO] splitting takes %10.4f sec. Threshold value is %10.9f' % (
                end_time - start_time, threshold))
            if split_method == 'fireflyn':
                report = '\n'.join(
                    '-- %d grows (sp %d | new %d)' % (x, y1, y2)
                    for x, (y1, y2) in n_neurons_added.items())
            else:
                report = '\n'.join(
                    '-- %d grows %d neurons' % (x, y)
                    for x, y in n_neurons_added.items())
            print('[INFO] number of added neurons: \n%s\n' % report)
        return n_neurons_added