in coremltools/converters/mil/mil/passes/layernorm_instancenorm_pattern_fusion.py [0:0]
def _try_match_and_transform_pattern_4(reduce_op: Operation, block: Block) -> bool:
    """
    Identify the pattern:

        y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])

    where ``mean = sum(x) * c1`` and ``variance = sum(x * x) * c2 - mean * mean`` (c1 and c2
    are scalars, presumably 1/N; their values are not checked here). This pattern corresponds
    to instance_norm and should be fused as such.

    All of the following must be satisfied:
    1) Input is rank 4 tensor
    2) Reduce operates on spatial dimensions axes=[-2, -1], or axes=[-3, -2] (a
       channel first to channel last transpose would be inserted in such case)
    3) Gamma and beta are both shape (C,) after squeeze, where C is number of channels

    NOTE(review): conditions 1-3 are not checked in this function; they are presumably
    enforced by ``_try_apply_transform`` -- confirm there.

    In the diagram, ``==>`` means the producer feeds *both* inputs of the consumer (a square).

     |----------------------|
     |                      V
     |---------------> mul_square1 --> sum1 --> mul_mean1
     |                                              |
     |                                              V
     x --> sum --> mul_mean ==> mul_square --> sub_variance --> add_eps --> rsqrt
     |                |                                                       |
     |                |                                                       V
     |                |                                                   mul_gamma
     |                |                                                       |
     |                |                                              |-----------------|
     |                |                                              |                 V
     |                |----------------------------------------------+-------------> mul2
     |                                                               V                 |
     |------------------------------------------------------------> mul1               |
                                                                     |                 V
                                                                     |             sub_beta --> add --> [...]
                                                                     |                           ^
                                                                     |---------------------------|

    :param reduce_op: candidate start of the pattern (the ``reduce_sum`` computing ``sum(x)``).
    :param block: block containing the ops; forwarded to ``_try_apply_transform``.
    :return: ``False`` if the pattern does not match; otherwise the result of
        ``_try_apply_transform``.
    """
    ops_to_remove = []
    root_var = reduce_op.x

    if root_var.shape is None:
        return False
    rank = len(root_var.shape)

    # x must feed exactly 4 op inputs: both inputs of mul_square1, the first
    # reduce_sum, and mul1 (the count of 4 assumes an op consuming x as both of
    # its inputs is listed twice in ``child_ops``).
    if len(root_var.child_ops) != 4:
        return False
    if root_var.op is not None and not _check_child_op_types(
        root_var.op, child_op_types=["mul", "mul", "reduce_sum", "mul"]
    ):
        return False

    # check 1st reduce_sum op: sum(x)
    if not _check_reduce_op(reduce_op, mode="reduce_sum"):
        return False
    ops_to_remove.append(reduce_op)

    # check mul (mean) op: sum(x) * scalar
    mul_mean_op = _try_get_child_op_type(reduce_op, "mul")
    if mul_mean_op is None:
        return False
    if mul_mean_op.y.shape != ():
        return False
    ops_to_remove.append(mul_mean_op)
    mean_var = mul_mean_op.outputs[0]

    # check 1st mul (square) op: mean feeds mul_square (twice) and mul2
    if not _check_child_op_types(mul_mean_op, child_op_types=["mul", "mul", "mul"]):
        return False
    # both children 0 and 1 should be the mean-square op
    mul_square_op = _try_get_child_op_type(mul_mean_op, "mul")
    if mul_square_op is None:
        return False
    if _try_get_child_op_type(mul_mean_op, "mul", index=1) is None:
        return False
    # It must really be mean * mean; otherwise sub_variance is not a variance.
    if mul_square_op.x != mean_var or mul_square_op.y != mean_var:
        return False
    ops_to_remove.append(mul_square_op)

    # Check the other branch.

    # check 2nd mul (square) op: x * x. This is taken by position, and the child
    # type check above is skipped when x has no producer op, so verify explicitly
    # (op type first, so that reading .x/.y is safe).
    mul_square_op2 = list(root_var.child_ops)[0]
    if mul_square_op2.op_type != "mul":
        return False
    if mul_square_op2.x != root_var or mul_square_op2.y != root_var:
        return False
    ops_to_remove.append(mul_square_op2)

    # check 2nd reduce_sum: sum(x * x)
    reduce_op2 = _try_get_child_op_type(mul_square_op2, child_op_type="reduce_sum")
    if not _check_reduce_op(reduce_op2, "reduce_sum"):
        return False
    # Both sums must reduce over the same axes, otherwise E[x*x] - E[x]^2 mixes two
    # different reductions. x * x has the same rank as x, so one rank normalizes both.
    mean_axes = {int(a) + rank if a < 0 else int(a) for a in reduce_op.axes.val}
    square_axes = {int(a) + rank if a < 0 else int(a) for a in reduce_op2.axes.val}
    if mean_axes != square_axes:
        return False
    ops_to_remove.append(reduce_op2)

    # check mul after 2nd reduce op: sum(x * x) * scalar
    mul_mean_op2 = _try_get_child_op_type(reduce_op2, "mul")
    if mul_mean_op2 is None:
        return False
    if mul_mean_op2.y.shape != ():
        return False
    ops_to_remove.append(mul_mean_op2)

    # check sub (variance): mean(x * x) - mean * mean
    sub_variance_op = _try_get_child_op_type(mul_mean_op2, "sub")
    if sub_variance_op is None:
        return False
    if sub_variance_op.y != mul_square_op.outputs[0]:
        return False
    ops_to_remove.append(sub_variance_op)

    # check add op (epsilon); eps may be either operand and must be a constant scalar
    add_eps_op = _try_get_child_op_type(sub_variance_op, "add")
    if add_eps_op is None:
        return False
    epsilon_var = add_eps_op.y if add_eps_op.x == sub_variance_op.outputs[0] else add_eps_op.x
    if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
        return False  # must be scalar
    ops_to_remove.append(add_eps_op)

    # check rsqrt
    rsqrt_op = _try_get_child_op_type(add_eps_op, "rsqrt")
    if rsqrt_op is None:
        return False
    ops_to_remove.append(rsqrt_op)

    # check mul (gamma); gamma may be either operand and must be a constant
    mul_gamma_op = _try_get_child_op_type(rsqrt_op, "mul")
    if mul_gamma_op is None:
        return False
    gamma_var = mul_gamma_op.y if mul_gamma_op.x == rsqrt_op.outputs[0] else mul_gamma_op.x
    if gamma_var.val is None:
        return False
    ops_to_remove.append(mul_gamma_op)

    # check the 2 muls after the gamma mul: one scales x (mul1), one scales mean (mul2)
    if not _check_child_op_types(mul_gamma_op, ["mul", "mul"]):
        return False
    scale_var = mul_gamma_op.outputs[0]
    mul_gamma_child_ops = list(scale_var.child_ops)
    mul_op1 = mul_gamma_child_ops[0]
    mul_op2 = mul_gamma_child_ops[1]
    mul_op1_other_var = mul_op1.x if mul_op1.y == scale_var else mul_op1.y
    mul_op2_other_var = mul_op2.x if mul_op2.y == scale_var else mul_op2.y
    if not (
        (mul_op1_other_var == root_var and mul_op2_other_var == mean_var)
        or (mul_op1_other_var == mean_var and mul_op2_other_var == root_var)
    ):
        return False
    # Normalize so that mul_op1 is x * scale and mul_op2 is mean * scale.
    if mul_op1_other_var != root_var:
        mul_op1, mul_op2 = mul_op2, mul_op1
    ops_to_remove.append(mul_op1)
    ops_to_remove.append(mul_op2)

    # check sub with beta: beta - mean * scale (beta must be a constant)
    sub_beta_op = _try_get_child_op_type(mul_op2, "sub")
    if sub_beta_op is None:
        return False
    if sub_beta_op.y != mul_op2.outputs[0]:
        return False
    beta_var = sub_beta_op.x
    if beta_var.val is None:
        return False
    ops_to_remove.append(sub_beta_op)

    # check last add op: (x * scale) + (beta - mean * scale), in either operand order
    add_op = _try_get_child_op_type(sub_beta_op, "add")
    if add_op is None:
        return False
    if not (
        (add_op.x == mul_op1.outputs[0] and add_op.y == sub_beta_op.outputs[0])
        or (add_op.y == mul_op1.outputs[0] and add_op.x == sub_beta_op.outputs[0])
    ):
        return False
    ops_to_remove.append(add_op)

    return _try_apply_transform(
        reduce_op, block, gamma_var, beta_var, epsilon_var, add_op, ops_to_remove
    )