in augerino_lib/aug_modules.py [0:0]
def forward(self, x, y=0):
    """Run the wrapped model on (optionally augmented) inputs.

    Three paths are taken depending on the module's flags:

    * ``self.disabled`` -> ``self.model(x)`` on the un-augmented input
      (``y`` is ignored).
    * training with ``self.onecopy`` -> ``self.model`` applied to a single
      draw of ``self.aug(x, y)``.
    * otherwise -> ``self.ncopies`` draws of ``self.aug(x, y)`` are
      concatenated along dim 0 and pushed through ``self.model`` in one
      batched call.
      ``log_softmax`` over the last dim is taken, the result is split back
      into per-copy chunks of ``x.shape[0]`` rows, and the chunks are
      averaged.

    Note that the first two paths return the model's raw output, while the
    third returns averaged log-probabilities.

    Args:
        x: input batch; dim 0 is treated as the batch dimension.
        y: extra argument forwarded unchanged to ``self.aug`` (its meaning
            is defined by the augmentation module — TODO confirm).

    Returns:
        A tensor with the same leading batch size as ``x``.
    """
    if self.disabled:
        return self.model(x)
    if self.training and self.onecopy:
        return self.model(self.aug(x, y))

    # Batch every augmented copy into a single forward pass instead of
    # calling the model ncopies times.
    batch_size = x.shape[0]
    augmented = [self.aug(x, y) for _ in range(self.ncopies)]
    stacked_input = torch.cat(augmented, dim=0)
    log_probs = F.log_softmax(self.model(stacked_input), dim=-1)
    # torch.cat put the copies back-to-back along dim 0, so splitting by
    # batch_size recovers one chunk per augmented copy.
    per_copy = torch.split(log_probs, batch_size)
    return sum(per_copy) / self.ncopies