in model.py [0:0]
def forward(self, inputs):
    """Score each input embedding against every row of the embedding table.

    The embeddings for ``inputs`` are looked up via ``self.lt`` and compared
    with every row of ``self.lt.weight`` using ``self.dist``. The distances
    are turned into logits according to ``self.Qdist``, and the logits are
    passed through ``self.lsm`` or ``self.sm`` according to
    ``self.lossfnname``.

    Args:
        inputs: index tensor fed to ``self.lt``. Assumed to be 1-D of length
            ``batch`` so that the lookup broadcasts against the
            ``(batch, self.size, self.dim)`` table expansion (TODO confirm
            against callers).

    Returns:
        ``self.lsm(logits)`` when ``lossfnname == 'kl'``, otherwise
        ``self.sm(logits)``. Here ``logits`` is:

        * ``'laplace'``:  ``-gamma * d``
        * ``'gaussian'``: ``-gamma * d**2``
        * ``'student'``:  ``-log(1 + gamma * d)``

        For ``'mse'`` the laplace form is always used. The result presumably
        has shape ``(batch, self.size)`` once ``self.dist``'s trailing
        singleton dim is squeezed (TODO confirm).

    Raises:
        NotImplementedError: if ``self.lossfnname`` is not one of
            ``'kl'``, ``'klSym'`` or ``'mse'``, or if ``self.Qdist`` is not
            supported when ``lossfnname`` is ``'kl'`` or ``'klSym'``.
    """
    embs_all = self.lt.weight.unsqueeze(0)
    embs_all = embs_all.expand(len(inputs), self.size, self.dim)
    embs_inputs = self.lt(inputs).unsqueeze(1)
    embs_inputs = embs_inputs.expand_as(embs_all)

    # ``self.dist`` is presumably a torch.autograd.Function subclass. Its
    # ``apply`` is a classmethod, so call it on the class: instantiating an
    # autograd Function is deprecated in recent PyTorch. Keep the old
    # ``self.dist()`` call for anything that is not a class, so existing
    # factories still work.
    dist_fn = self.dist if isinstance(self.dist, type) else self.dist()
    dists = dist_fn.apply(embs_inputs, embs_all).squeeze(-1)

    if self.lossfnname == 'mse':
        # NOTE(review): 'mse' always uses the laplace kernel and ignores
        # self.Qdist (unchanged from the original) -- confirm intended.
        return self.sm(-self.gamma * dists)

    # 'kl' and 'klSym' share the kernel; only the normaliser differs.
    if self.lossfnname == 'kl':
        normalize = self.lsm
    elif self.lossfnname == 'klSym':
        normalize = self.sm
    else:
        raise NotImplementedError(
            'unsupported lossfnname: {!r}'.format(self.lossfnname))

    if self.Qdist == 'laplace':
        logits = -self.gamma * dists
    elif self.Qdist == 'gaussian':
        logits = -self.gamma * dists.pow(2)
    elif self.Qdist == 'student':
        logits = -torch.log(1 + self.gamma * dists)
    else:
        raise NotImplementedError(
            'unsupported Qdist: {!r}'.format(self.Qdist))
    return normalize(logits)