flowtorch/distributions/neals_funnel.py (38 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc
from typing import Any, Dict, Optional, Union
import torch
import torch.distributions as dist
from torch.distributions import constraints
from torch.distributions.utils import _standard_normal
class NealsFunnel(dist.Distribution):
"""
Neal's funnel.
p(x,y) = N(y|0,3) N(x|0,exp(y/2))
"""
support = constraints.real
arg_constraints: Dict[str, dist.constraints.Constraint] = {}
def __init__(self, validate_args: Any = None) -> None:
d = 2
batch_shape, event_shape = torch.Size([]), (d,)
super(NealsFunnel, self).__init__(
batch_shape, event_shape, validate_args=validate_args
)
def rsample(
self,
sample_shape: Union[torch.Tensor, torch.Size] = None,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not sample_shape:
sample_shape = torch.Size()
eps = _standard_normal(
(sample_shape[0], 2), dtype=torch.float, device=torch.device("cpu")
)
z = torch.zeros(eps.shape)
z[..., 1] = torch.tensor(3.0) * eps[..., 1]
z[..., 0] = torch.exp(z[..., 1] / 2.0) * eps[..., 0]
return z
def log_prob(
self, value: torch.Tensor, context: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self._validate_args:
self._validate_sample(value)
x = value[..., 0]
y = value[..., 1]
log_prob = dist.Normal(0, 3).log_prob(y)
log_prob += dist.Normal(0, torch.exp(y / 2)).log_prob(x)
return log_prob