def grow()

in crlapi/sl/architectures/mixture_model.py [0:0]


    def grow(self,dataset_loader,**args):
        """Return a new model in which the most-used expert of every MixtureLayer is split.

        Expert usage is measured by running the current model over
        ``dataset_loader`` without gradient tracking. For each layer and each
        expert, the gate scores returned by ``self(x, with_gate_scores=True)`` are
        summed over the batch dimension and accumulated across batches. For every
        ``MixtureLayer`` in ``self.layers``, the expert with the largest accumulated
        score is handed to ``self._grow_layer``. Every other layer is deep-copied
        unchanged.

        Args:
            dataset_loader: iterable of ``(x, y)`` batches. Only ``x`` is used to
                measure usage; the targets are ignored.
            **args: accepted for interface compatibility; unused.

        Returns:
            ``self`` unchanged when ``self.n_experts_split == 0``. Otherwise, a new
            ``MoE_UsageGrow`` built from the grown layers and
            ``self.n_experts_split``.

        Raises:
            ValueError: if ``dataset_loader`` yields no batches, so usage cannot
                be measured.
            RuntimeError: if a batch reports gate scores whose expert identifiers
                do not match, position by position, those of the first batch.
        """
        if self.n_experts_split==0:
            return self

        # usage[k][kk] == (expert_id, accumulated_score) for the kk-th expert of the
        # k-th mixture layer. The structure is copied from the gate scores of the
        # first batch.
        usage = None
        with torch.no_grad():
            for x, _ in dataset_loader:
                x = x.to(self.device)
                _, gate_scores = self(x, with_gate_scores=True)
                # NOTE(review): assumes gate_scores is a per-layer list of
                # (expert_id, per-sample scores) pairs; sum(0) reduces over the batch.
                batch_usage = [
                    [(expert, scores.sum(0)) for expert, scores in layer]
                    for layer in gate_scores
                ]
                if usage is None:
                    usage = batch_usage
                    continue
                for k, layer in enumerate(batch_usage):
                    for kk, (expert, scores) in enumerate(layer):
                        prev_expert, prev_scores = usage[k][kk]
                        # Explicit check (not `assert`) so it survives `python -O`.
                        if expert != prev_expert:
                            raise RuntimeError(
                                f"Inconsistent gate scores at layer {k}, position {kk}: "
                                f"expected expert {prev_expert!r}, got {expert!r}"
                            )
                        usage[k][kk] = (expert, scores + prev_scores)

        if usage is None:
            # Without this check, the failure would be an opaque TypeError at usage[p].
            raise ValueError(
                "grow() needs at least one batch from dataset_loader to measure expert usage"
            )

        self.zero_grad()
        new_layers=[]
        p=0  # index into `usage`; advances only on MixtureLayers
        for k,l in enumerate(self.layers):
            if isinstance(l,MixtureLayer):
                u=usage[p]
                us=[uu[1].item() for uu in u]
                idx=np.argmax(us)
                print("Expert usage at layer ",k," is ",{str(uu[0]):uu[1].item() for uu in u})
                max_expert=u[idx][0]
                print("\tSplitting expert ",max_expert)
                new_layers.append(self._grow_layer(l,max_expert))
                p+=1
            else:
                new_layers.append(copy.deepcopy(l))
        return MoE_UsageGrow(new_layers,self.n_experts_split)