def _try_match_and_transform_pattern_4()

in coremltools/converters/mil/mil/passes/layernorm_instancenorm_pattern_fusion.py [0:0]


def _try_match_and_transform_pattern_4(reduce_op: Operation, block: Block) -> bool:
    """
    Identify the pattern:
    y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])

    where mean = reduce_sum(x) * c and variance = reduce_sum(x * x) * c' - mean * mean,
    with c and c' scalars.

    This pattern corresponds to, and should be fused as, instance_norm.
    All of the following must be satisfied:
    1) Input is rank 4 tensor
    2) Reduce operates on spatial dimensions axes=[-2, -1], or axes=[-3, -2] (a
       channel first to channel last transpose would be inserted in such case)
    3) Gamma and beta are both shape (C,) after squeeze, where C is number of channels

    |-----------|
    |           V
    |------> mul_square1 -----> sum1 -----> mul_mean1
    |                                           |
    |                                           V
    x --> sum --> mul_mean ==> mul_square --> sub_variance --> add_eps --> rsqrt
    |                |                                                      |
    |                |                                                      V
    |                |                                                  mul_gamma
    |                |                                                      |
    |                |                                            |----------------|
    |                |                                            |                V
    |                |--------------------------------------------+-------------> mul2
    |                                                             V                |
    |----------------------------------------------------------> mul1              |
                                                                  |                V
                                                                  |             sub_beta --> add --> [...]
                                                                  |                           ^
                                                                  |---------------------------|

    This function performs only the structural match. Conditions 1)-3) are not
    checked here; presumably ``_try_apply_transform`` enforces them (TODO confirm).

    :param reduce_op: candidate first ``reduce_sum`` op (the sum used for the mean).
    :param block: block containing the ops; forwarded to ``_try_apply_transform``.
    :return: False if the pattern does not match; otherwise the result of
        ``_try_apply_transform``.
    """
    ops_to_remove = []
    root_var = reduce_op.x

    if root_var.shape is None:
        return False

    # x must feed exactly [mul, mul, reduce_sum, mul]. Index 0 (and 1) is the x * x
    # op and is picked by position below, so the order matters.
    # Check root_var's own children rather than going through root_var.op, because:
    # - root_var.op is None when x is a block input, which would skip the check;
    # - root_var.op.outputs[0] is not necessarily root_var for multi-output producers.
    root_child_ops = list(root_var.child_ops)
    if [op.op_type for op in root_child_ops] != ["mul", "mul", "reduce_sum", "mul"]:
        return False

    # check 1st reduce_sum op
    if not _check_reduce_op(reduce_op, mode="reduce_sum"):
        return False
    ops_to_remove.append(reduce_op)

    # check mul (mean) op: mean = reduce_sum(x) * scalar
    mul_mean_op = _try_get_child_op_type(reduce_op, "mul")
    if mul_mean_op is None:
        return False
    if mul_mean_op.y.shape != ():
        return False
    ops_to_remove.append(mul_mean_op)
    mean_var = mul_mean_op.outputs[0]

    # check 1st mul (square) op
    if not _check_child_op_types(mul_mean_op, child_op_types=["mul", "mul", "mul"]):
        return False
    # both 0 and 1 should be the mean * mean op
    mul_square_op = _try_get_child_op_type(mul_mean_op, "mul")
    if mul_square_op is None:
        return False
    if _try_get_child_op_type(mul_mean_op, "mul", index=1) is None:
        return False
    # It must really be mean squared, otherwise the fused op would be numerically wrong.
    if mul_square_op.x != mean_var or mul_square_op.y != mean_var:
        return False
    ops_to_remove.append(mul_square_op)

    # Check another branch

    # check 2nd mul (square) op: must be x * x (children 0 and 1 of root_var)
    mul_square_op2 = root_child_ops[0]
    if mul_square_op2.x != root_var or mul_square_op2.y != root_var:
        return False
    ops_to_remove.append(mul_square_op2)

    # check 2nd reduce sum
    reduce_op2 = _try_get_child_op_type(mul_square_op2, child_op_type="reduce_sum")
    if not _check_reduce_op(reduce_op2, "reduce_sum"):
        return False
    ops_to_remove.append(reduce_op2)

    # check mul after 2nd reduce op
    mul_mean_op2 = _try_get_child_op_type(reduce_op2, "mul")
    if mul_mean_op2 is None:
        return False
    if mul_mean_op2.y.shape != ():
        return False
    ops_to_remove.append(mul_mean_op2)

    # check sub (variance): E[x^2] - mean^2
    sub_variance_op = _try_get_child_op_type(mul_mean_op2, "sub")
    if sub_variance_op is None:
        return False
    if sub_variance_op.y != mul_square_op.outputs[0]:
        return False
    ops_to_remove.append(sub_variance_op)

    # check add op (epsilon); epsilon may be either operand
    add_eps_op = _try_get_child_op_type(sub_variance_op, "add")
    if add_eps_op is None:
        return False
    epsilon_var = (
        add_eps_op.y if add_eps_op.x == sub_variance_op.outputs[0] else add_eps_op.x
    )
    if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
        return False  # must be scalar
    ops_to_remove.append(add_eps_op)

    # check rsqrt
    rsqrt_op = _try_get_child_op_type(add_eps_op, "rsqrt")
    if rsqrt_op is None:
        return False
    ops_to_remove.append(rsqrt_op)

    # check mul (gamma); gamma may be either operand and must be constant
    mul_gamma_op = _try_get_child_op_type(rsqrt_op, "mul")
    if mul_gamma_op is None:
        return False
    gamma_var = (
        mul_gamma_op.y if mul_gamma_op.x == rsqrt_op.outputs[0] else mul_gamma_op.x
    )
    if gamma_var.val is None:
        return False
    ops_to_remove.append(mul_gamma_op)

    # check 2 muls after the gamma mul: one scales x, the other scales the mean
    if not _check_child_op_types(mul_gamma_op, ["mul", "mul"]):
        return False
    scale_var = mul_gamma_op.outputs[0]
    mul_gamma_child_ops = list(scale_var.child_ops)
    mul_op1 = mul_gamma_child_ops[0]
    mul_op2 = mul_gamma_child_ops[1]
    mul_op1_other_var = mul_op1.x if mul_op1.y == scale_var else mul_op1.y
    mul_op2_other_var = mul_op2.x if mul_op2.y == scale_var else mul_op2.y
    if not (
        (mul_op1_other_var == root_var and mul_op2_other_var == mean_var)
        or (mul_op1_other_var == mean_var and mul_op2_other_var == root_var)
    ):
        return False
    # Normalize naming so that mul_op1 = x * scale and mul_op2 = mean * scale.
    if mul_op1_other_var != root_var:
        mul_op1, mul_op2 = mul_op2, mul_op1
    ops_to_remove.append(mul_op1)
    ops_to_remove.append(mul_op2)

    # check sub with beta: beta - mean * scale, beta must be constant
    sub_beta_op = _try_get_child_op_type(mul_op2, "sub")
    if sub_beta_op is None:
        return False
    if sub_beta_op.y != mul_op2.outputs[0]:
        return False
    beta_var = sub_beta_op.x
    if beta_var.val is None:
        return False
    ops_to_remove.append(sub_beta_op)

    # check last add op: (x * scale) + (beta - mean * scale), in either operand order
    add_op = _try_get_child_op_type(sub_beta_op, "add")
    if add_op is None:
        return False
    if not (
        (add_op.x == mul_op1.outputs[0] and add_op.y == sub_beta_op.outputs[0])
        or (add_op.y == mul_op1.outputs[0] and add_op.x == sub_beta_op.outputs[0])
    ):
        return False
    ops_to_remove.append(add_op)

    return _try_apply_transform(
        reduce_op, block, gamma_var, beta_var, epsilon_var, add_op, ops_to_remove
    )