def patch_pytorch_ops()

in src/exporters/coreml/models.py [0:0]


    def patch_pytorch_ops(self):
        """Return replacement converters for selected PyTorch ops.

        Each converter takes ``(context, node)`` like the upstream coremltools torch
        frontend ops it replaces. All coremltools/torch imports are deferred into
        the converters, so building this dict does not import them.

        Returns:
            dict: op name (``"to"``, ``"numpy_t"``) -> converter function.
            Presumably registered by the caller as overrides of the default
            coremltools converters — confirm against the call site.
        """

        # Copied from https://github.com/apple/coremltools/blob/b2f719075dc5bc19280a3045c1762d7d32bd3fdc/coremltools/converters/mil/frontend/torch/ops.py#L4326
        # with fallback of `bfloat16` to `float32`.
        def to(context, node):
            """Convert a ``to`` node into a MIL ``const`` (foldable input) or ``cast``.

            Dtype codes that have no entry in ``NUM_TO_TORCH_DTYPE`` (e.g. ``bfloat16``)
            fall back to 32-bit float on both the const-folding and the cast path.

            Args:
                context: conversion context; the produced MIL value is added to it.
                node: the node being converted.

            Raises:
                ValueError: if the input count matches no known ``to`` variant, or the
                    target dtype maps to a numpy type that is neither integer nor floating.
            """
            from coremltools.converters.mil import Builder as mb
            from coremltools.converters.mil.mil import types
            from coremltools.converters.mil.frontend.torch.ops import (
                _get_inputs,
                NUMPY_DTYPE_TO_TORCH_NUM,
                NUM_TO_TORCH_DTYPE,
                NUM_TO_DTYPE_STRING,
                NUM_TO_NUMPY_DTYPE,
            )
            from coremltools.converters.mil.mil.types import nptype_from_builtin
            from coremltools.converters.mil.mil.var import Var
            import numpy as _np
            import torch

            inputs = _get_inputs(context, node)

            # There are a lot of variants of `to` op.
            # - When len(inputs) is 7 or 8, we only care about the first two params (input and dtype).
            # - When len(inputs) == 6, the parameter is (input, _, dtype, non_blocking, copy, memory_format)
            # - When len(inputs) == 5, the parameter is (input, dtype, non_blocking, copy, memory_format)
            # - When len(inputs) == 4, the parameter is (input, dtype, non_blocking, copy)
            # - When len(inputs) == 3, the parameter is (input, non_blocking, copy)
            # We only use `input` and `dtype`, and `non_blocking` and `copy` are unused.
            _input = inputs[0]

            inputs_len = len(inputs)
            if inputs_len in (4, 5, 7, 8):
                target_dtype = inputs[1]
            elif inputs_len == 6:
                target_dtype = inputs[2]
            elif inputs_len <= 3:
                target_dtype = None
            else:
                raise ValueError(
                    "Received invalid arguments for PyTorch conversion of op {}".format(node)
                )

            if target_dtype is None:
                # When target_dtype is None, it means the input's dtype is already the target dtype.
                context.add(_input, torch_name=node.name)
                return
            elif types.is_scalar(target_dtype.sym_type) and target_dtype.val is not None:
                dtype = target_dtype.val
            else:
                # When the val of dtype is not available, bridge from the np dtype.
                np_type = nptype_from_builtin(target_dtype.dtype)
                dtype = NUMPY_DTYPE_TO_TORCH_NUM[np_type]

            if dtype in NUM_TO_TORCH_DTYPE:
                torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
                fp32_fallback = False
            else:
                # Fallback `bfloat16` (and any other dtype code without a torch mapping)
                # to `fp32` for now.
                torch_dtype = torch.float32
                fp32_fallback = True

            if isinstance(_input, Var) and _input.can_be_folded_to_const():
                # numpy -> torch -> torch cast -> numpy
                # This path is needed to use the mapping of passed in dtypes to torch dtypes.
                casted_input = torch.tensor(_input.val).type(torch_dtype).cpu().numpy()
                res = mb.const(val=casted_input, name=node.name)
            elif dtype in NUM_TO_DTYPE_STRING:
                res = mb.cast(x=_input, dtype=NUM_TO_DTYPE_STRING[dtype], name=node.name)
            elif fp32_fallback:
                # Mirror the const path: cast to fp32 rather than indexing
                # NUM_TO_NUMPY_DTYPE, which presumably has no entry for these codes
                # (the lookup below would then raise KeyError).
                res = mb.cast(x=_input, dtype="fp32", name=node.name)
            else:
                # For dtype that is not supported by mb.cast, we do it in best-effort to cast it to int
                # or float based on the dtype.
                np_dtype = NUM_TO_NUMPY_DTYPE[dtype]
                if _np.issubdtype(np_dtype, _np.integer):
                    res = mb.cast(x=_input, dtype="int32", name=node.name)
                elif _np.issubdtype(np_dtype, _np.floating):
                    res = mb.cast(x=_input, dtype="fp32", name=node.name)
                else:
                    raise ValueError(f"Unsupported op {node} with target dtype {np_dtype}")
            context.add(res)

        # Workaround until https://github.com/apple/coremltools/pull/2046 is released
        def numpy_t(context, node):
            """Convert ``numpy_T`` on a rank-2 input into a MIL ``transpose`` with perm [1, 0].

            Args:
                context: conversion context; the transposed value is added to it.
                node: the node being converted; must have exactly one input and one output.
            """
            from coremltools.converters.mil import Builder as mb

            assert len(node.outputs) == 1
            assert len(node.inputs) == 1

            x = context[node.inputs[0]]
            # Only the 2-D case is handled: swapping the two axes is the full transpose.
            assert len(x.shape) == 2

            res = mb.transpose(x=x, perm=[1, 0], name=node.name)
            context.add(res)

        return {"to": to, "numpy_t": numpy_t}