def _try_to_transform()

in coremltools/converters/mil/mil/passes/gelu_tanh_approximation_fusion.py [0:0]


def _try_to_transform(pow_op, block):
    """Try to fuse the tanh-approximated GELU chain rooted at ``pow_op``.

    With ``x = pow_op.x``, the chain of ops that is matched is::

        x -> pow -> mul(0.044715) -> add(x) -> mul(0.79788) -> tanh
          -> add(1) -> mul(0.5) -> mul(x)

    For every binary op the two operands may appear in either order.
    On a full match, a single ``mb.gelu`` op in ``TANH_APPROXIMATION``
    mode is inserted before ``pow_op``, uses of the final ``mul`` output
    are redirected to it, and all matched ops are removed from ``block``.

    NOTE(review): the exponent of ``pow_op`` is not inspected here --
    presumably the caller validates it before calling; confirm.

    Args:
        pow_op: the ``pow`` op at which the candidate pattern starts.
        block: the block whose outputs are consulted and from which the
            matched ops are removed.

    Returns:
        True if the pattern matched and was replaced, False otherwise
        (in which case nothing is modified).
    """
    matched = [pow_op]
    gelu_input = pow_op.x

    # The input must have exactly three consumers; in a full match these
    # are the pow, the first add and the final mul.
    if len(list(gelu_input.child_ops)) != 3:
        return False

    def with_scalar(value):
        # Build a matcher: one operand is the incoming var, the other is
        # accepted by _check_var_scalar_value for ``value``.
        def matches(op, fed_var):
            return (
                op.x == fed_var and _check_var_scalar_value(op.y, value)
            ) or (
                op.y == fed_var and _check_var_scalar_value(op.x, value)
            )

        return matches

    def with_input(op, fed_var):
        # Matcher: one operand is the incoming var, the other is the
        # original input of the whole pattern.
        return (op.x == fed_var and op.y == gelu_input) or (
            op.y == fed_var and op.x == gelu_input
        )

    # (expected op type, operand matcher or None) for each successive op.
    chain = (
        ("mul", with_scalar(0.044715)),
        ("add", with_input),
        ("mul", with_scalar(0.79788)),
        ("tanh", None),
        ("add", with_scalar(1)),
        ("mul", with_scalar(0.5)),
        ("mul", with_input),
    )

    current = pow_op
    for expected_type, operand_matcher in chain:
        # _check_child_op_type gates the step; the first consumer of the
        # previous output is then taken as the next op in the chain.
        if not _check_child_op_type(current, expected_type):
            return False
        fed_var = current.outputs[0]
        candidate = list(fed_var.child_ops)[0]
        if operand_matcher is not None and not operand_matcher(candidate, fed_var):
            return False
        matched.append(candidate)
        current = candidate

    # Only the last op's output may be a block output: every intermediate
    # op is removed below.
    if any(
        out in block.outputs
        for op in matched[:-1]
        for out in op.outputs
    ):
        return False

    # Insert the fused op under the final output's name, rewire its
    # consumers, then drop the whole matched chain in one call.
    final_mul = matched[-1]
    fused_name = final_mul.outputs[0].name
    fused = mb.gelu(
        x=gelu_input, mode="TANH_APPROXIMATION", name=fused_name, before_op=pow_op
    )

    final_mul.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=final_mul, old_var=final_mul.outputs[0], new_var=fused
    )
    block.remove_ops(matched)
    return True