in src/exporters/coreml/models.py [0:0]
def patch_pytorch_ops(self):
    """Return replacement PyTorch-op converters, keyed by op name.

    The returned dict maps ``"einsum"`` to a converter function taking
    ``(context, node)``. Presumably the caller hands this mapping to
    coremltools so it overrides the stock converter (TODO confirm against
    the caller).

    The converter works around
    https://github.com/apple/coremltools/issues/1852. For the outer-product
    equation ``"i,j->ij"``, each integer operand is cast to fp32 before the
    MIL einsum is built.
    """

    # https://github.com/apple/coremltools/issues/1852
    def einsum(context, node):
        # Imported lazily so coremltools is only needed when conversion runs.
        from coremltools.converters.mil import Builder as mb
        from coremltools.converters.mil.frontend._utils import build_einsum_mil
        from coremltools.converters.mil.mil import types

        # Input 0 holds the equation string; input 1 holds the operand list.
        operands = context[node.inputs[1]]
        a = operands[0]
        b = operands[1]
        equation = context[node.inputs[0]].val.replace(" ", "")

        if equation == "i,j->ij":
            # Cast each integer operand on its own. Casting only the first
            # operand misses an integer second operand, and it leaves the
            # operands with mismatched dtypes when both are integer.
            if types.is_int(a.dtype):
                a = mb.cast(x=a, dtype="fp32")
            if types.is_int(b.dtype):
                b = mb.cast(x=b, dtype="fp32")

        x = build_einsum_mil(a, b, equation, node.name)
        context.add(x)

    return {"einsum": einsum}