def forward()

in augerino_lib/aug_modules.py [0:0]


    def forward(self, x, y=0):
        """Run the wrapped model on ``x``, optionally through the augmentation.

        Three paths, chosen from instance flags:

        * ``self.disabled``: ``self.model(x)`` is returned untouched.
        * ``self.training and self.onecopy``: the model is applied to a single
          augmented sample ``self.aug(x, y)`` and its raw output is returned.
        * otherwise: ``self.ncopies`` augmented samples are pushed through the
          model in one batch and the per-copy log-softmax outputs (over the
          last dimension) are averaged.

        Note that the last path returns the mean of log-softmax values, while
        the first two return whatever ``self.model`` produces.

        Args:
            x: Input batch; its leading dimension is used as the batch size.
            y: Extra argument forwarded to ``self.aug`` (defaults to 0).

        Returns:
            The model output, or the averaged log-softmax output, as above.
        """
        # Augmentation switched off: plain pass-through to the wrapped model.
        if self.disabled:
            return self.model(x)

        # Single augmented copy while training: no averaging needed.
        if self.training and self.onecopy:
            return self.model(self.aug(x, y))

        # Batched averaging: rather than calling the model once per copy,
        # concatenate all augmented copies along dim 0 and call it once.
        batch_size = x.shape[0]
        augmented_copies = [self.aug(x, y) for _ in range(self.ncopies)]
        stacked_input = torch.cat(augmented_copies, dim=0)
        log_probs = F.log_softmax(self.model(stacked_input), dim=-1)

        # Split back into per-copy chunks of the original batch size
        # (assumes the model keeps the leading batch dimension -- TODO confirm).
        per_copy_log_probs = torch.split(log_probs, batch_size)
        return sum(per_copy_log_probs) / self.ncopies