def split()

in crlapi/sl/architectures/firefly_vgg/sp/net.py [0:0]


    @staticmethod
    def _random_partition(n_grow, n_layers):
        """Randomly divide a budget of ``n_grow`` new neurons into ``n_layers`` shares.

        For two or more layers, ``n_layers`` distinct cut points are drawn from
        ``range(n_grow)`` and sorted. The first share is the first cut point and the
        second share is the gap between the first two cut points. Every later share
        is the gap to the previous cut point plus one.

        A single layer receives the whole budget, and zero layers yield ``[]``.

        Raises:
            ValueError: (from ``np.random.choice``) if ``n_grow < n_layers`` when
                ``n_layers >= 2``, since cut points are sampled without replacement.
        """
        if n_layers <= 0:
            return []
        if n_layers == 1:
            # The pairwise scheme below yields nothing for a single layer; give it
            # the full budget instead of silently growing no neurons.
            return [n_grow]
        cuts = np.sort(np.random.choice(n_grow, n_layers, replace=False))
        shares = [cuts[0], cuts[1] - cuts[0]]
        shares.extend(cuts[k + 1] - cuts[k] + 1 for k in range(1, n_layers - 1))
        return shares

    def split(self, split_method, dataset, n_batches=-1):
        """Grow the network by splitting neurons of the layers in ``self.layers_to_split``.

        Args:
            split_method: One of 'random', 'exact', 'fast', 'firefly', 'fireflyn'.
                Every method except 'random' first runs its scoring pass on
                ``dataset``: ``self.spe``, ``self.spf``, ``self.spff`` or
                ``self.spffn`` respectively.
            dataset: Forwarded to the scoring pass; unused by 'random'.
            n_batches: Forwarded to the scoring pass (``-1`` presumably means
                "use all batches" — TODO confirm).

        Returns:
            dict mapping layer index -> number of neurons added to that layer.
            For 'fireflyn', each value is instead a ``(n_split, n_new)`` tuple:
            the number of split neurons and the number of freshly added neurons.

        Raises:
            NotImplementedError: if ``split_method`` is not supported.

        Side effects: layers are modified in place, the network is put in eval
        mode while splitting and switched to train mode afterwards, auxiliaries
        are cleared (``self.clear()``), and the optimizer is re-created.
        """
        self.num_group = 1 if self.config.model != 'mobile' else 2
        if split_method not in ['random', 'exact', 'fast', 'firefly', 'fireflyn']:
            raise NotImplementedError('unknown split method %r' % (split_method,))

        if self.verbose:
            print('[INFO] start splitting ...')

        start_time = time.time()
        self.net.eval()

        if self.num_group == 1:
            self.get_num_elites()
        else:
            self.get_num_elites_group(self.num_group)
        split_fn = {
            'exact': self.spe,
            'fast': self.spf,
            'firefly': self.spff,
            'fireflyn': self.spffn,
        }

        if split_method != 'random':
            split_fn[split_method](dataset, n_batches)

        n_neurons_added = {}
        # Initialised up front so the verbose report never hits an unbound name
        # (e.g. grouped 'fireflyn' with no splittable layer).
        threshold = 0.

        if split_method == 'random':
            n_total_neurons = sum(self.net[l].get_n_neurons() for l in self.layers_to_split)
            n_grow = int(n_total_neurons * self.grow_ratio)
            n_news = self._random_partition(n_grow, len(self.layers_to_split))
            # Shares are handed out starting from the last layer to split.
            for i, n_new_ in zip(reversed(self.layers_to_split), n_news):
                layer = self.net[i]
                if isinstance(layer, sp.SpModule) and layer.can_split:
                    n_new, idx = layer.random_split(n_new_)
                    n_neurons_added[i] = n_new
                    if n_new > 0:  # this layer was indeed split
                        for j in self.next_layers[i]:
                            self.net[j].passive_split(idx)
        elif split_method == 'fireflyn':
            if self.num_group == 1:
                threshold = self.sp_threshold()
            for i in reversed(self.layers_to_split):
                layer = self.net[i]
                if not (isinstance(layer, sp.SpModule) and layer.can_split):
                    continue
                if self.num_group != 1:
                    # Grouped models use one threshold per group; the verbose
                    # report below shows the last one computed.
                    threshold = self.sp_threshold_group(self.total_group[i])
                n_new, split_idx, new_idx = layer.spffn_active_grow(threshold)
                sp_new = split_idx.shape[0] if split_idx is not None else 0
                n_neurons_added[i] = (sp_new, n_new - sp_new)
                for j in self.next_layers[i]:
                    if self.verbose:
                        print('passive', self.net[j].module.weight.shape)
                    self.net[j].spffn_passive_grow(split_idx, new_idx)
        else:
            threshold = self.sp_threshold()
            # actual splitting
            for i in reversed(self.layers_to_split):
                layer = self.net[i]
                if isinstance(layer, sp.SpModule) and layer.can_split:
                    n_new, idx = layer.active_split(threshold)
                    n_neurons_added[i] = n_new
                    if n_new > 0:  # this layer was indeed split
                        for j in self.next_layers[i]:
                            self.net[j].passive_split(idx)

        self.net.train()
        self.clear()  # cleanup auxiliaries
        self.create_optimizer()  # re-initialize optimizer

        end_time = time.time()
        if self.verbose:
            print('[INFO] splitting takes %10.4f sec. Threshold value is %10.9f' % (
                end_time - start_time, threshold))
            if split_method == 'fireflyn':
                print('[INFO] number of added neurons: \n%s\n' % \
                        '\n'.join(['-- %d grows (sp %d | new %d)' % (x, y1, y2) for x, (y1, y2) in n_neurons_added.items()]))
            else:
                print('[INFO] number of added neurons: \n%s\n' % \
                        '\n'.join(['-- %d grows %d neurons' % (x, y) for x, y in n_neurons_added.items()]))
        return n_neurons_added