in fvcore/nn/jit_handles.py
from collections import OrderedDict
from numbers import Number
from typing import Any, List

import numpy as np


def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the einsum operation.
    """
    # einsum takes two inputs: inputs[0] holds the equation string, and
    # inputs[1] holds the list of operands, whose shapes are read below via
    # get_shape (defined earlier in jit_handles.py).
    assert len(inputs) == 2, len(inputs)
    equation = inputs[0].toIValue()
    # Get rid of whitespace in the equation string.
    equation = equation.replace(" ", "")
    # Recover the shape of each operand from the TorchScript graph.
    input_shapes_jit = inputs[1].node().inputs()
    input_shapes = [get_shape(v) for v in input_shapes_jit]
    # Re-map the equation so that the same contraction written with
    # different letters is normalized to one canonical form.
    letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
    mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
    equation = equation.translate(mapping)
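    # For example, "nct,ncp->ntp" re-maps (n->a, c->b, t->c, p->d) to
    # "abc,abd->acd" and therefore hits the first fast path below.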
    if equation == "abc,abd->acd":
        # Batched matrix multiplication: (n, c, t) x (n, c, p) -> (n, t, p).
        # Each of the n * t * p outputs is a dot product of length c,
        # so the MAC count is n * c * t * p.
        n, c, t = input_shapes[0]
        p = input_shapes[-1][-1]
        flop = n * c * t * p
        return flop
    elif equation == "abc,adc->adb":
        # Batched matrix multiplication: (n, t, g) x (n, c, g) -> (n, c, t).
        # Each of the n * c * t outputs is a dot product of length g,
        # so the MAC count is n * t * g * c.
        n, t, g = input_shapes[0]
        c = input_shapes[-1][1]
        flop = n * t * g * c
        return flop
    else:
        # Fall back to NumPy: build zero arrays with the input shapes and
        # let np.einsum_path report the optimal contraction cost.
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # Divide by 2 because we count MACs
                # (a multiply-add pair is counted as one flop).
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop
        raise NotImplementedError("Unsupported einsum operation.")
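
A minimal standalone sketch of the fallback path, assuming only NumPy; the
equation and shapes here are illustrative and not taken from the library:

import numpy as np

input_shapes = [(2, 3, 4), (2, 4, 5)]
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path("abc,acd->abd", *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
    if "optimized flop" in line.lower():
        # einsum_path counts multiplies and adds separately, hence / 2
        # to express the result in MACs, as einsum_flop_jit does.
        print(float(line.split(":")[-1]) / 2)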