in src/exporters/coreml/models.py
def patch_pytorch_ops(self):
    # Copied from https://github.com/apple/coremltools/blob/b2f719075dc5bc19280a3045c1762d7d32bd3fdc/coremltools/converters/mil/frontend/torch/ops.py#L4326
    # with fallback of `bfloat16` to `float32`.
    def to(context, node):
        from coremltools.converters.mil import Builder as mb
        from coremltools.converters.mil.mil import types
        from coremltools.converters.mil.frontend.torch.ops import (
            _get_inputs,
            NUMPY_DTYPE_TO_TORCH_NUM,
            NUM_TO_TORCH_DTYPE,
            NUM_TO_DTYPE_STRING,
            NUM_TO_NUMPY_DTYPE,
            TORCH_DTYPE_TO_NUM,
        )
        from coremltools.converters.mil.mil.types import nptype_from_builtin
        from coremltools.converters.mil.mil.var import Var
        import numpy as _np
        import torch
        inputs = _get_inputs(context, node)
        # There are several variants of the `to` op:
        # - When len(inputs) is 7 or 8, only the first two parameters matter (input and dtype).
        # - When len(inputs) == 6, the parameters are (input, _, dtype, non_blocking, copy, memory_format).
        # - When len(inputs) == 5, the parameters are (input, dtype, non_blocking, copy, memory_format).
        # - When len(inputs) == 4, the parameters are (input, dtype, non_blocking, copy).
        # - When len(inputs) == 3, the parameters are (input, non_blocking, copy).
        # Only `input` and `dtype` are used; `non_blocking`, `copy`, and `memory_format` are ignored.
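        # For illustration (an assumption about the PyTorch overload, not taken from the
        # original code): `x.to(torch.float32)` lowers to the 5-input `aten::to.dtype`
        # overload, i.e. (input, dtype, non_blocking, copy, memory_format).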
        _input = inputs[0]
        inputs_len = len(inputs)
        if inputs_len in (4, 5, 7, 8):
            target_dtype = inputs[1]
        elif inputs_len == 6:
            target_dtype = inputs[2]
        elif inputs_len <= 3:
            target_dtype = None
        else:
            raise ValueError(
                "Received invalid arguments for PyTorch conversion of op {}".format(node)
            )
        if target_dtype is None:
            # When target_dtype is None, the input's dtype is already the target dtype.
            context.add(_input, torch_name=node.name)
            return
        elif types.is_scalar(target_dtype.sym_type) and target_dtype.val is not None:
            dtype = target_dtype.val
        else:
            # When the value of `dtype` is not available, bridge from the numpy dtype.
            np_type = nptype_from_builtin(target_dtype.dtype)
            dtype = NUMPY_DTYPE_TO_TORCH_NUM[np_type]
        if dtype in NUM_TO_TORCH_DTYPE:
            torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
        else:
            # Fall back from `bfloat16` to `fp32` for now.
            torch_dtype = torch.float32
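            # Rationale (not stated in the original code): MIL does not expose a bfloat16
            # tensor dtype at this coremltools commit, so float32 is the closest safe widening.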
        if isinstance(_input, Var) and _input.can_be_folded_to_const():
            # numpy -> torch -> torch cast -> numpy
            # This path is needed to use the mapping of passed-in dtypes to torch dtypes.
            casted_input = torch.tensor(_input.val).type(torch_dtype).cpu().numpy()
            res = mb.const(val=casted_input, name=node.name)
        else:
            if dtype in NUM_TO_DTYPE_STRING:
                res = mb.cast(x=_input, dtype=NUM_TO_DTYPE_STRING[dtype], name=node.name)
            else:
                # For a dtype that mb.cast does not support, make a best-effort cast
                # to int or float based on the dtype.
                np_dtype = NUM_TO_NUMPY_DTYPE[dtype]
                if _np.issubdtype(np_dtype, _np.integer):
                    res = mb.cast(x=_input, dtype="int32", name=node.name)
                elif _np.issubdtype(np_dtype, _np.floating):
                    res = mb.cast(x=_input, dtype="fp32", name=node.name)
                else:
                    raise ValueError(f"Unsupported op {node} with target dtype {np_dtype}")
        context.add(res)
    # Workaround until https://github.com/apple/coremltools/pull/2046 is released
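    # For illustration (an assumption, not from the original code): tracing `x.T` on a
    # 2-D tensor, e.g. `torch.jit.trace(lambda x: x.T, torch.zeros(2, 3))`, emits an
    # `aten::numpy_T` node, which the pinned coremltools version cannot convert without
    # this override.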
    def numpy_t(context, node):
        from coremltools.converters.mil import Builder as mb

        assert len(node.outputs) == 1
        assert len(node.inputs) == 1
        x = context[node.inputs[0]]
        assert len(x.shape) == 2
        res = mb.transpose(x=x, perm=[1, 0], name=node.name)
        context.add(res)
    return {"to": to, "numpy_t": numpy_t}
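# A minimal sketch (an assumption, not part of this module; the exporters package may
# wire this up differently) of how a caller could apply the returned overrides through
# coremltools' public op registry. `config` stands in for an instance of this class:
#
#     from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
#
#     for fn in config.patch_pytorch_ops().values():
#         register_torch_op(fn, override=True)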