in optimum/fx/optimization/transformations.py [0:0]
def compose(*args: Transformation, inplace: bool = True) -> Transformation:
    """
    Composes a list of transformations together.

    The transformations are applied in the order they are given (the first argument is applied first). When every
    transformation is a [`~optimum.fx.optimization.ReversibleTransformation`], the composition is reversible as well:
    reversing it calls each transformation with `reverse=True`, last transformation first.

    Args:
        args ([`~optimum.fx.optimization.Transformation`]):
            The transformations to compose together. When none are given (and `inplace=True`), the composition
            returns the graph module untouched.
        inplace (`bool`, defaults to `True`):
            Whether the resulting transformation should be inplace, or create a new graph module.

    Returns:
        The composition transformation object.

    Example:

    ```python
    >>> from transformers import BertModel
    >>> from transformers.utils.fx import symbolic_trace
    >>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose

    >>> model = BertModel.from_pretrained("bert-base-uncased")
    >>> traced = symbolic_trace(
    ...     model,
    ...     input_names=["input_ids", "attention_mask", "token_type_ids"],
    ... )
    >>> composition = compose(ChangeTrueDivToMulByInverse(), MergeLinears())
    >>> transformed_model = composition(traced)
    ```
    """
    # Only the user-provided transformations decide these flags; the optional DeepCopy step added below does not.
    composition_preserves_computation = all(t.preserves_computation for t in args)
    composition_is_reversible = all(isinstance(t, ReversibleTransformation) for t in args)

    # Steps in the order they run in the forward direction. When not inplace, the copy runs first so that the
    # user transformations operate on a new graph module rather than on the input.
    forward_steps = args if inplace else (DeepCopy(), *args)
    # Undoing the composition runs the same steps in the opposite order.
    reverse_steps = forward_steps[::-1]

    def _apply(steps, graph_module, **call_kwargs):
        """Runs `steps` sequentially on `graph_module`, forwarding `call_kwargs` to every step's call."""
        # An explicit loop (rather than functools.reduce) keeps every step — including a lone one — going through
        # the same call path, and makes an empty composition a plain identity.
        for step in steps:
            # Individual steps skip lint/recompile; presumably the composed transformation's own `__call__`
            # handles it once on the final result.
            graph_module = step(graph_module, lint_and_recompile=False, **call_kwargs)
        return graph_module

    if composition_is_reversible:

        class ComposeTransformation(ReversibleTransformation):
            preserves_computation = composition_preserves_computation

            def transform(self, graph_module):
                return _apply(forward_steps, graph_module)

            def reverse(self, graph_module):
                return _apply(reverse_steps, graph_module, reverse=True)

    else:

        class ComposeTransformation(Transformation):
            preserves_computation = composition_preserves_computation

            def transform(self, graph_module):
                return _apply(forward_steps, graph_module)

    return ComposeTransformation()