def get_default_transforms()

in src/lic/ppl/world/utils.py [0:0]


def get_default_transforms(distribution: Distribution) -> List:
    """
    Build the default chain of transforms that carries values of a
    distribution from its constrained support into unconstrained space.

    An empty list is returned for discrete distributions, for real-valued
    supports, and for any support that is not explicitly handled here.

    :param distribution: the distribution whose support is inspected
    :returns: the transforms, in application order, that map the
    distribution from constrained space into unconstrained space
    """
    # pyre-fixme
    support = distribution.support
    # Only the dtype of this draw is used, and only when a bound has to be
    # wrapped in a tensor.
    # NOTE(review): the draw is taken even on branches that never read it.
    # pyre-fixme
    sample = distribution.sample()

    def _bound_as_tensor(bound):
        # Tensor bounds pass through untouched; anything else is wrapped in
        # a tensor that shares the sample's dtype.
        if isinstance(bound, Tensor):
            return bound
        return tensor(bound, dtype=sample.dtype)

    # Nothing to do for discrete distributions or already-unconstrained ones.
    if is_discrete(distribution) or is_constraint_eq(support, constraints.real):
        return []

    if is_constraint_eq(support, constraints.interval):
        low = _bound_as_tensor(support.lower_bound)
        high = _bound_as_tensor(support.upper_bound)
        return [
            # x -> x - low
            dist.AffineTransform(-low, 1.0),
            # x -> x / (high - low)
            dist.AffineTransform(0, 1.0 / (high - low)),
            BetaDimensionTransform(),
            dist.StickBreakingTransform().inv,
        ]

    bounded_below = is_constraint_eq(support, constraints.greater_than) or isinstance(
        support, constraints.greater_than_eq
    )
    if bounded_below:
        low = _bound_as_tensor(support.lower_bound)
        # Shift the lower bound to zero, then take the log.
        return [dist.AffineTransform(-low, 1.0), dist.ExpTransform().inv]

    if is_constraint_eq(support, constraints.less_than):
        high = _bound_as_tensor(support.upper_bound)
        return [
            # x -> x - high (non-positive side of zero)
            dist.AffineTransform(-high, 1.0),
            # negate so the values sit above zero
            dist.AffineTransform(0, -1.0),
            dist.ExpTransform().inv,
        ]

    if is_constraint_eq(support, constraints.simplex):
        return [dist.StickBreakingTransform().inv]

    # Any other support: no default transform.
    return []