in coremltools/converters/mil/mil/passes/gelu_tanh_approximation_fusion.py [0:0]
def _try_to_transform(pow_op, block):
all_ops = [pow_op]
root_var = pow_op.x
# check that root_var feeds into exactly 3 ops
if len(list(root_var.child_ops)) != 3:
return False
# check for 1st mul op
if not _check_child_op_type(pow_op, "mul"):
return False
mul_op1 = list(pow_op.outputs[0].child_ops)[0]
if not (
(
mul_op1.x == pow_op.outputs[0]
and _check_var_scalar_value(mul_op1.y, 0.044715)
)
or (
mul_op1.y == pow_op.outputs[0]
and _check_var_scalar_value(mul_op1.x, 0.044715)
)
):
return False
all_ops.append(mul_op1)
# check for 1st add op
if not _check_child_op_type(mul_op1, "add"):
return False
add_op1 = list(mul_op1.outputs[0].child_ops)[0]
if not (
(add_op1.x == mul_op1.outputs[0] and add_op1.y == root_var)
or (add_op1.y == mul_op1.outputs[0] and add_op1.x == root_var)
):
return False
all_ops.append(add_op1)
# check for 2nd mul op
if not _check_child_op_type(add_op1, "mul"):
return False
mul_op2 = list(add_op1.outputs[0].child_ops)[0]
if not (
(
mul_op2.x == add_op1.outputs[0]
and _check_var_scalar_value(mul_op2.y, 0.79788)
)
or (
mul_op2.y == add_op1.outputs[0]
and _check_var_scalar_value(mul_op2.x, 0.79788)
)
):
return False
all_ops.append(mul_op2)
# check for tanh op
if not _check_child_op_type(mul_op2, "tanh"):
return False
tanh_op = list(mul_op2.outputs[0].child_ops)[0]
all_ops.append(tanh_op)
# check for 2nd add op
if not _check_child_op_type(tanh_op, "add"):
return False
add_op2 = list(tanh_op.outputs[0].child_ops)[0]
if not (
(add_op2.x == tanh_op.outputs[0] and _check_var_scalar_value(add_op2.y, 1))
or (add_op2.y == tanh_op.outputs[0] and _check_var_scalar_value(add_op2.x, 1))
):
return False
all_ops.append(add_op2)
# check for 3rd mul op
if not _check_child_op_type(add_op2, "mul"):
return False
mul_op3 = list(add_op2.outputs[0].child_ops)[0]
if not (
(mul_op3.x == add_op2.outputs[0] and _check_var_scalar_value(mul_op3.y, 0.5))
or (mul_op3.y == add_op2.outputs[0] and _check_var_scalar_value(mul_op3.x, 0.5))
):
return False
all_ops.append(mul_op3)
# check for 4th mul op
if not _check_child_op_type(mul_op3, "mul"):
return False
mul_op4 = list(mul_op3.outputs[0].child_ops)[0]
if not (
(mul_op4.x == mul_op3.outputs[0] and mul_op4.y == root_var)
or (mul_op4.y == mul_op3.outputs[0] and mul_op4.x == root_var)
):
return False
all_ops.append(mul_op4)
# check that none of the op in this pattern is connected to the output
# (except the last mul op)
for i, op in enumerate(all_ops):
if i == len(all_ops) - 1:
continue
for out in op.outputs:
if out in block.outputs:
return False
# remove all the ops, and replace with a gelu op
out_name = mul_op4.outputs[0].name
x = mb.gelu(x=root_var, mode="TANH_APPROXIMATION", name=out_name, before_op=pow_op)
mul_op4.enclosing_block.replace_uses_of_var_after_op(
anchor_op=mul_op4, old_var=mul_op4.outputs[0], new_var=x
)
# Remove all the ops at once
block.remove_ops(all_ops)
return True