in crlapi/sl/architectures/mixture_model.py [0:0]
def grow(self, dataset_loader, **args):
    # No split requested: keep the current architecture unchanged.
    if self.n_experts_split == 0:
        return self

    with torch.no_grad():
        # Accumulate, for every mixture layer, how much each expert is used
        # (sum of its gate scores) over the whole dataset.
        usage = None
        n = 0
        for x, y in dataset_loader:
            x, y = x.to(self.device), y.to(self.device)
            out, gate_scores = self(x, with_gate_scores=True)
            # Per-sample loss and the sample count n are computed here but not used below.
            loss = F.cross_entropy(out, y, reduction='none')
            # Collapse the batch dimension: keep (expert_id, summed gate score) pairs.
            gate_scores = [[(gg[0], gg[1].sum(0)) for gg in g] for g in gate_scores]
            n += x.size()[0]
            if usage is None:
                usage = gate_scores
            else:
                for k, g in enumerate(gate_scores):
                    for kk, gg in enumerate(g):
                        # Expert identifiers must line up across batches.
                        assert gg[0] == usage[k][kk][0]
                        usage[k][kk] = (gg[0], gg[1] + usage[k][kk][1])

        self.zero_grad()

        # Rebuild the layer stack, splitting the most-used expert of each mixture layer.
        new_layers = []
        p = 0
        for k, l in enumerate(self.layers):
            if isinstance(l, MixtureLayer):
                u = usage[p]
                us = [uu[1].item() for uu in u]
                idx = np.argmax(us)
                print("Expert usage at layer ", k, " is ", {str(uu[0]): uu[1].item() for uu in u})
                max_expert = u[idx][0]
                print("\tSplitting expert ", max_expert)
                new_layers.append(self._grow_layer(l, max_expert))
                p += 1
            else:
                # Non-mixture layers are copied over unchanged.
                new_layers.append(copy.deepcopy(l))

        return MoE_UsageGrow(new_layers, self.n_experts_split)
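# Usage sketch (an assumption, not part of this file; `layers`, `train_loader`,
# and `fit_one_task` are hypothetical placeholders). grow() does not modify the
# model in place: it measures expert usage and returns a new MoE_UsageGrow with
# the busiest expert of each MixtureLayer split, so a continual-learning loop
# would rebind the model:
#
#     model = MoE_UsageGrow(layers, n_experts_split)   # same positional args as the return above
#     fit_one_task(model, train_loader)                # hypothetical training on the current task
#     model = model.grow(train_loader)                 # split the most-used expert per MixtureLayer
#     fit_one_task(model, train_loader)                # continue with the grown architecture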