in src/lic/ppl/world/utils.py [0:0]
def _constraint_bound_as_tensor(bound, reference: Tensor) -> Tensor:
    """
    Convert a constraint bound into a tensor compatible with ``reference``.
    :param bound: a bound read from a torch constraint (a Tensor or a scalar)
    :param reference: a tensor whose dtype and device the bound should share
    :returns: ``bound`` unchanged if it is already a Tensor, otherwise a new
    tensor with ``reference``'s dtype and device
    """
    if isinstance(bound, Tensor):
        return bound
    # Put scalar bounds on the same device as the distribution's samples so the
    # affine transforms built from them do not mix CPU and accelerator tensors.
    return tensor(bound, dtype=reference.dtype, device=reference.device)


def get_default_transforms(distribution: Distribution) -> List:
    """
    Get transforms of a distribution to transform it from constrained space
    into unconstrained space.
    :param distribution: the distribution to check
    :returns: the list of transforms that need to be applied to the distribution
    to transform it from constrained space into unconstrained space; the list
    is empty for discrete distributions, real supports, and any support not
    handled below
    """
    # pyre-fixme
    support = distribution.support
    if is_discrete(distribution):
        return []
    elif is_constraint_eq(support, constraints.real):
        return []
    elif is_constraint_eq(support, constraints.interval):
        # A sample is drawn only to learn the dtype/device for scalar bounds,
        # so it is taken lazily here rather than for every distribution.
        # pyre-fixme
        reference = distribution.sample()
        lower_bound = _constraint_bound_as_tensor(support.lower_bound, reference)
        upper_bound = _constraint_bound_as_tensor(support.upper_bound, reference)
        # (x - lower) / (upper - lower) rescales the interval onto (0, 1).
        lower_bound_zero = dist.AffineTransform(-lower_bound, 1.0)
        upper_bound_one = dist.AffineTransform(0, 1.0 / (upper_bound - lower_bound))
        # NOTE(review): presumably BetaDimensionTransform lifts the (0, 1)
        # value into a 2-element simplex so the inverse stick-breaking
        # transform can map it to R — confirm against its definition.
        beta_dimension = BetaDimensionTransform()
        stick_breaking = dist.StickBreakingTransform().inv
        return [lower_bound_zero, upper_bound_one, beta_dimension, stick_breaking]
    elif is_constraint_eq(support, constraints.greater_than) or isinstance(
        support, constraints.greater_than_eq
    ):
        # pyre-fixme
        reference = distribution.sample()
        lower_bound = _constraint_bound_as_tensor(support.lower_bound, reference)
        # x - lower is positive; log maps the positive half-line onto R.
        lower_bound_zero = dist.AffineTransform(-lower_bound, 1.0)
        log_transform = dist.ExpTransform().inv
        return [lower_bound_zero, log_transform]
    elif is_constraint_eq(support, constraints.less_than):
        # pyre-fixme
        reference = distribution.sample()
        upper_bound = _constraint_bound_as_tensor(support.upper_bound, reference)
        # x - upper is negative; negate it, then log maps it onto R.
        upper_bound_zero = dist.AffineTransform(-upper_bound, 1.0)
        flip_to_greater = dist.AffineTransform(0, -1.0)
        log_transform = dist.ExpTransform().inv
        return [upper_bound_zero, flip_to_greater, log_transform]
    elif is_constraint_eq(support, constraints.simplex):
        return [dist.StickBreakingTransform().inv]
    return []