def compose()

in optimum/fx/optimization/transformations.py [0:0]


def compose(*args: Transformation, inplace: bool = True) -> Transformation:
    """
    Composes a list of transformations together.

    The transformations are applied in the order they are given: `compose(a, b)` runs `a` first, then `b`. When every
    transformation is a [`~optimum.fx.optimization.ReversibleTransformation`], the composition is reversible as well,
    and reversing it undoes the transformations in the opposite order (`b` first, then `a`). Calling `compose()` with
    no transformation returns an identity transformation.

    Args:
        args ([`~optimum.fx.optimization.Transformation`]):
            The transformations to compose together, in application order.
        inplace (`bool`, defaults to `True`):
            Whether the resulting transformation should be inplace, or create a new graph module. When `False`, the
            graph module is deep-copied before any transformation runs, both for the forward and the reverse pass.

    Returns:
        The composition transformation object: a `ReversibleTransformation` if all of `args` are reversible, a plain
        `Transformation` otherwise.

    Example:

    ```python
    >>> from transformers import BertModel
    >>> from transformers.utils.fx import symbolic_trace
    >>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose

    >>> model = BertModel.from_pretrained("bert-base-uncased")
    >>> traced = symbolic_trace(
    ...     model,
    ...     input_names=["input_ids", "attention_mask", "token_type_ids"],
    ... )
    >>> composition = compose(ChangeTrueDivToMulByInverse(), MergeLinears())
    >>> transformed_model = composition(traced)
    ```
    """
    transformations = tuple(args)

    composition_preserves_computation = all(t.preserves_computation for t in transformations)
    composition_is_reversible = all(isinstance(t, ReversibleTransformation) for t in transformations)

    # The copy has to happen before any step in *both* directions. If it were simply placed at one end of the
    # step chain, the reverse pass would mutate the caller's graph module and only copy it afterwards.
    copy_transformation = None if inplace else DeepCopy()

    def run_steps(steps, graph_module, lint_and_recompile, **step_kwargs):
        # Runs `steps` one after the other on `graph_module`, starting from a deep copy when `inplace=False`.
        if copy_transformation is not None:
            graph_module = copy_transformation(graph_module, lint_and_recompile=lint_and_recompile)
        for step in steps:
            # Every step receives the same `lint_and_recompile` flag (False by default). Presumably the composed
            # transformation's own `__call__` lints/recompiles once at the end — confirm against `Transformation`.
            graph_module = step(graph_module, lint_and_recompile=lint_and_recompile, **step_kwargs)
        return graph_module

    if not composition_is_reversible:

        def apply_forward(graph_module, lint_and_recompile=False):
            return run_steps(transformations, graph_module, lint_and_recompile)

        class ComposeTransformation(Transformation):
            preserves_computation = composition_preserves_computation

            # `staticmethod` keeps the callable unbound whether it is accessed through the class or an instance.
            _composition = staticmethod(apply_forward)

            def transform(self, graph_module):
                return ComposeTransformation._composition(graph_module)

    else:

        def apply_forward(graph_module, lint_and_recompile=False):
            return run_steps(transformations, graph_module, lint_and_recompile, reverse=False)

        def apply_reverse(graph_module, lint_and_recompile=False):
            # Walk the steps last-to-first, asking each one to run in reverse mode. Using an explicit loop (instead
            # of reducing the steps) also covers the single-transformation case, where `reverse=True` must still be
            # passed.
            return run_steps(transformations[::-1], graph_module, lint_and_recompile, reverse=True)

        class ComposeTransformation(ReversibleTransformation):
            preserves_computation = composition_preserves_computation

            _composition = staticmethod(apply_forward)
            _reverse_composition = staticmethod(apply_reverse)

            def transform(self, graph_module):
                return ComposeTransformation._composition(graph_module)

            def reverse(self, graph_module):
                return ComposeTransformation._reverse_composition(graph_module)

    return ComposeTransformation()