flowtorch/bijectors/spline.py (24 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc

from typing import Optional

import flowtorch
import torch
from flowtorch.bijectors.elementwise import Elementwise
from flowtorch.bijectors.ops.spline import Spline as SplineOp


class Spline(SplineOp, Elementwise):
    """Spline bijector assembled from ``SplineOp`` and ``Elementwise``.

    This class adds no transform logic of its own. It only mixes the spline
    operation (imported as ``SplineOp``) into the ``Elementwise`` bijector
    base and passes its constructor arguments straight through to the next
    ``__init__`` in the method resolution order.
    """

    def __init__(
        self,
        params_fn: Optional[flowtorch.Lazy] = None,
        *,
        shape: torch.Size,
        context_shape: Optional[torch.Size] = None,
        count_bins: int = 8,
        bound: float = 3.0,
        order: str = "linear",
    ) -> None:
        """Forward all configuration to the cooperative base initialisers.

        Args:
            params_fn: Optional lazily-constructed parameter module; passed
                through positionally unchanged.
            shape: Event shape handed to the base classes.
            context_shape: Optional shape of a conditioning context.
            count_bins: Number of spline bins (default 8). Presumably the
                number of piecewise segments; confirm against ``SplineOp``.
            bound: Default 3.0. Presumably the half-width of the interval on
                which the spline is defined; confirm against ``SplineOp``.
            order: Spline order identifier (default ``"linear"``). The set of
                accepted values is defined by ``SplineOp`` and is not visible here.
        """
        passthrough = dict(
            shape=shape,
            context_shape=context_shape,
            count_bins=count_bins,
            bound=bound,
            order=order,
        )
        # With bases (SplineOp, Elementwise), super() resolves to
        # SplineOp.__init__ first. Whether Elementwise.__init__ is also reached
        # depends on SplineOp chaining via super(); that code is not visible here.
        super().__init__(params_fn, **passthrough)