in flow_layers/planar.py [0:0]
def forward(self, x, logpx=None, reverse=False, z0=None, log_alpha=None, beta=None, **kwargs):
    """Apply the flow ``f(x) = x + beta * h(r) * (x - z0)`` to ``x``.

    Here ``r = ||x - z0||`` (norm over the last dimension),
    ``alpha = exp(log_alpha)`` and ``h(r) = 1 / (alpha + r)``.  This is the
    radial-flow form of a normalizing flow.  The raw ``beta`` parameter is
    mapped to ``-alpha + softplus(beta)`` before use.

    Parameters come from one of two places:

    * ``self.hypernet`` truthy: ``z0``, ``log_alpha`` and ``beta`` must all
      be passed by the caller.
    * otherwise: ``self.z0``, ``self.log_alpha`` and ``self._beta`` are used
      and any ``z0`` / ``log_alpha`` / ``beta`` arguments are ignored.

    Args:
        x: Input tensor; the norm is taken over its last dimension.
        logpx: Optional log-density of ``x``.  When given, the flattened
            log-determinant is subtracted from it.
            NOTE(review): presumably 1-D with one entry per sample, since
            the log-determinant is flattened with ``reshape(-1)`` — confirm.
        reverse: Must be falsy; the inverse transform is not implemented.
        z0: Centre of the transform (hypernet mode only).  Must be
            expandable to the shape of ``x``.
        log_alpha: Log of ``alpha`` (hypernet mode only).
        beta: Unconstrained ``beta`` (hypernet mode only).
        **kwargs: Accepted and ignored.

    Returns:
        ``f`` when ``logpx`` is None, otherwise the tuple ``(f, logpy)``
        with ``logpy = logpx - log|det df/dx|``.

    Raises:
        ValueError: If ``reverse`` is truthy, or if hypernet mode is active
            and any of ``z0``, ``log_alpha``, ``beta`` is missing.
    """
    if reverse:
        raise ValueError(f"{self.__class__.__name__} does not support reverse.")

    if self.hypernet:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would surface later as an opaque torch error.
        missing = [
            name
            for name, value in (("z0", z0), ("log_alpha", log_alpha), ("beta", beta))
            if value is None
        ]
        if missing:
            raise ValueError(
                f"{self.__class__.__name__} in hypernet mode requires z0, log_alpha "
                f"and beta; missing: {', '.join(missing)}."
            )
        raw_beta = beta
    else:
        z0 = self.z0
        log_alpha = self.log_alpha
        raw_beta = self._beta

    # alpha is needed for beta, h and the log-determinant: compute it once.
    alpha = torch.exp(log_alpha)
    # softplus(.) > 0, so beta > -alpha, which keeps 1 + beta * h positive.
    beta = -alpha + F.softplus(raw_beta)

    z0 = z0.expand_as(x)
    r = torch.norm(x - z0, dim=-1, keepdim=True)
    denom = alpha + r
    h = 1 / denom
    f = x + beta * h * (x - z0)

    if logpx is None:
        return f

    # det(df/dx) = (1 + beta*h)^(nd - 1) * (1 + beta*h + beta*h'(r)*r),
    # with h'(r) = -1 / (alpha + r)^2.
    # NOTE(review): self.nd is presumably x.shape[-1] — confirm.
    logdetgrad = (self.nd - 1) * torch.log(1 + beta * h) + \
        torch.log(1 + beta * h - beta * r / denom ** 2)
    logpy = logpx - logdetgrad.reshape(-1)
    return f, logpy