flowtorch/bijectors/softplus.py
# Copyright (c) Meta Platforms, Inc

from typing import Optional, Sequence, Tuple

import flowtorch.ops
import torch
import torch.distributions.constraints as constraints
import torch.nn.functional as F
from flowtorch.bijectors.fixed import Fixed


class Softplus(Fixed):
r"""
Elementwise bijector via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
"""
codomain = constraints.positive

    def _forward(
        self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # y = log(1 + exp(x)), with the log|det J| computed in closed form
        y = F.softplus(x)
        ladj = self._log_abs_det_jacobian(x, y, params)
        return y, ladj

    def _inverse(
        self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # x = softplus^{-1}(y) = log(exp(y) - 1), computed stably by flowtorch.ops
        x = flowtorch.ops.softplus_inv(y)
        ladj = self._log_abs_det_jacobian(x, y, params)
        return x, ladj

    def _log_abs_det_jacobian(
        self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
    ) -> torch.Tensor:
        # d softplus(x)/dx = sigmoid(x), and log sigmoid(x) = -softplus(-x)
        return -F.softplus(-x)
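

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (an addition, not part of the upstream module):
# it verifies the softplus round trip and compares the closed-form term
# log|det J| = -softplus(-x) against autograd. It assumes only the imports
# above, with flowtorch.ops.softplus_inv as the stable inverse
# x = log(exp(y) - 1) used by _inverse.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(5, dtype=torch.float64, requires_grad=True)

    # Round trip: softplus_inv(softplus(x)) should recover x up to float error.
    y = F.softplus(x)
    assert torch.allclose(flowtorch.ops.softplus_inv(y), x)

    # d softplus(x)/dx = sigmoid(x), so log|dy/dx| = log sigmoid(x) = -softplus(-x);
    # the closed-form expression should match autograd's elementwise derivative.
    (dy_dx,) = torch.autograd.grad(y.sum(), x)
    assert torch.allclose(-F.softplus(-x), dy_dx.log())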