in models/encoders/mvconv1.py [0:0]
def forward(self, x, losslist=None):
    """Encode the stacked inputs into a latent code (reparameterization trick).

    ``x`` is passed through ``self.pad`` and then split along dim 1 into
    ``self.ninputs`` consecutive 3-channel slices (assumes ``x`` is laid out
    as (batch, 3 * ninputs, H, W) — TODO confirm against callers). Each slice
    is encoded by ``self.down1[0]`` when ``self.tied`` is true (one module
    shared by all inputs) or by ``self.down1[i]`` otherwise. It is then
    flattened to ``256 * 3 * 4`` features. The per-input features are
    concatenated, fed through ``self.down2``, and projected by the
    ``self.mu`` / ``self.logstd`` heads, whose outputs are scaled by 0.1 and
    0.01 respectively.

    Args:
        x: Input tensor; channels ``i*3:(i+1)*3`` go to the i-th encoder.
        losslist: Names of auxiliary losses to compute. Only ``"kldiv"`` is
            recognized here. ``None`` (the default) computes no losses.

    Returns:
        dict with
          * ``"encoding"``: ``mu + exp(logstd) * eps`` with ``eps ~ N(0, 1)``
            when ``self.training`` is true, otherwise ``mu``;
          * ``"losses"``: a dict containing ``"kldiv"`` — the closed-form
            KL(N(mu, exp(logstd)^2) || N(0, 1)) averaged over the last
            dimension — when requested, else empty.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if losslist is None:
        losslist = ()

    x = self.pad(x)
    feats = [
        self.down1[0 if self.tied else i](x[:, i * 3:(i + 1) * 3, :, :])
        .view(-1, 256 * 3 * 4)
        for i in range(self.ninputs)
    ]
    x = torch.cat(feats, dim=1)
    x = self.down2(x)

    # Fixed output scales on the two heads (constants of this architecture).
    mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01

    if self.training:
        # randn_like matches logstd's dtype as well as its device, so a
        # half/bfloat16 model is not silently promoted to float32 here.
        z = mu + torch.exp(logstd) * torch.randn_like(logstd)
    else:
        z = mu

    losses = {}
    if "kldiv" in losslist:
        # KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(logstd):
        #   -log(sigma) + (sigma^2 + mu^2) / 2 - 1/2
        losses["kldiv"] = torch.mean(
            -0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd),
            dim=-1,
        )
    return {"encoding": z, "losses": losses}