in coremltools/converters/mil/frontend/tensorflow/tf_graph_pass/fuse_dilation_conv.py
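# Module context (elided in this excerpt): `numpy as np` is imported at file
# scope, and the helpers `_try_same`, `replace_source`, and `delete_node` are
# defined or imported elsewhere in this file.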
def _pattern_match_and_rewrite(gddict, conv_op):
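    """
    Fuse a TF dilated-convolution idiom into a single convolution:
    SpaceToBatchND -> (Depthwise)Conv2D -> BatchToSpaceND becomes one
    (Depthwise)Conv2D with a rewritten `dilations` attribute. The Conv1D
    variant, which additionally wraps the pattern in an ExpandDims/Squeeze
    pair, is handled as well. Returns early whenever the pattern or its
    constant inputs do not match.
    """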
    node = gddict[conv_op]
    channel_first = node.attr["data_format"].startswith("NC")
    if len(node.inputs) == 0 or len(node.outputs) == 0:
        return
    prev_node = gddict[node.inputs[0]]
    next_node = gddict[node.outputs[0]]
    expand_node = None
    squeeze_node = None
    # Check for Conv1D cases
    if prev_node.op == "ExpandDims":
        # Conv1D is always lowered with an ExpandDims/Squeeze pair around
        # the 2-D convolution.
        if next_node.op != "Squeeze":
            return
        expand_node = prev_node
        squeeze_node = next_node
        if len(prev_node.inputs) == 0 or len(next_node.outputs) == 0:
            return
        prev_node = gddict[prev_node.inputs[0]]
        next_node = gddict[next_node.outputs[0]]
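    # TF lowers a dilated convolution (e.g. tf.nn.atrous_conv2d) into
    # SpaceToBatchND -> Conv2D(VALID) -> BatchToSpaceND; that is the
    # pattern matched below.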
    # Check if Conv1D/Conv2D is surrounded by SpaceToBatchND and BatchToSpaceND
    if prev_node.op != "SpaceToBatchND" or next_node.op != "BatchToSpaceND":
        return
    else:
        stb_node = prev_node
        bts_node = next_node
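    # The second input of SpaceToBatchND/BatchToSpaceND is the block shape,
    # which in this lowering is exactly the dilation factor of the fused conv.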
    dilation_node = gddict[stb_node.inputs[1]]
    if dilation_node.value is None:
        return
    dilation_factor = dilation_node.value.val
    if gddict[bts_node.inputs[1]].value is None or np.any(
        dilation_factor != gddict[bts_node.inputs[1]].value.val
    ):
        # If the block shapes of SpaceToBatchND and BatchToSpaceND do not
        # match, we do not fuse.
        return
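    # The third input holds the SpaceToBatchND paddings and the
    # BatchToSpaceND crops, respectively; both must be constant to fuse.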
    padding_node = gddict[stb_node.inputs[2]]
    if padding_node.value is None:
        return
    padding_val = padding_node.value.val.flatten()
    crop_node = gddict[bts_node.inputs[2]]
    if crop_node.value is None:
        return
    crop_val = crop_node.value.val.flatten()
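    # For the Conv1D case, prepend a unit dimension so the 1-D dilation,
    # padding, and crop values line up with the 2-D logic below.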
    if expand_node:
        dilation_factor = [1] + list(dilation_factor)
        padding_val = [0, 0] + list(padding_val)
        crop_val = [0, 0] + list(crop_val)
    # Try to invert the logic TF uses to generate the padding/cropping values
    # of SpaceToBatchND and BatchToSpaceND for the different Conv2D padding
    # modes. Logic extracted from TF's builders at:
    # tensorflow/python/ops/nn_ops.py and tensorflow/python/ops/array_ops.py
    is_same = False
    if np.any(padding_val != 0):
        input_shape = gddict[stb_node.inputs[0]].attr.get("_output_shapes", None)
        if input_shape is None:
            input_shape = gddict[stb_node.inputs[0]].attr.get("shape", None)
        else:
            input_shape = input_shape[0]
        W_node = gddict[node.inputs[1]]
        W_shape = None if W_node.op != "Const" else W_node.datatype.get_shape()
        if input_shape is None or W_shape is None:
            return
        W_h, W_w = W_shape[0], W_shape[1]
        HW = input_shape[2:] if channel_first else input_shape[1:-1]
        if expand_node:
            HW = [1] + list(HW)
        is_same = _try_same(
            HW[0], HW[1], W_h, W_w, dilation_factor, padding_val, crop_val
        )
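    # _try_same replays TF's SAME-padding arithmetic: it succeeds when the
    # observed paddings/crops are exactly what SAME padding would produce
    # for this input/filter/dilation combination.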
    # Re-wire the nodes to skip SpaceToBatchND.
    # We change BatchToSpaceND to Identity since it might be a terminal op.
    deleted_nodes = set()
    if expand_node:
        replace_source(gddict, stb_node, expand_node, stb_node.inputs[0])
    else:
        replace_source(gddict, stb_node, node, stb_node.inputs[0])
    bts_node.op = "Identity"
    bts_node.attr = {}
    deleted_nodes.update(stb_node.inputs[1:])
    deleted_nodes.update([stb_node.name])
    deleted_nodes.update(bts_node.inputs[1:])
    # Rewrite the dilation attribute for (Depthwise)Conv2D
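    # `dilations` follows the data_format layout: [1, 1, dH, dW] for NCHW
    # and [1, dH, dW, 1] for NHWC.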
    dilation_val = (
        [1, 1] + list(dilation_factor)
        if node.attr["data_format"] == "NCHW"
        else [1] + list(dilation_factor) + [1]
    )
    node.attr["dilations"] = dilation_val
    # Rewrite the padding attribute for (Depthwise)Conv2D.
    # TF always plugs VALID padding into the Conv2D that follows a
    # SpaceToBatchND; if the original convolution used SAME padding, TF bakes
    # that padding into SpaceToBatchND instead, so we restore SAME here.
    if is_same:
        node.attr["padding"] = "SAME"
    # Removing stale attributes for nodes.
    if expand_node and "_output_shapes" in expand_node.attr:
        del expand_node.attr["_output_shapes"]
    if squeeze_node and "_output_shapes" in squeeze_node.attr:
        del squeeze_node.attr["_output_shapes"]
    if "_output_shapes" in node.attr:
        del node.attr["_output_shapes"]
    if expand_node and "shape" in expand_node.attr:
        del expand_node.attr["shape"]
    if squeeze_node and "shape" in squeeze_node.attr:
        del squeeze_node.attr["shape"]
    if "shape" in node.attr:
        del node.attr["shape"]
    for d in deleted_nodes:
        delete_node(gddict, d)
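
# For context, a minimal sketch (not part of this file, assuming TF 1.x-style
# graph construction) of the idiom this pass matches: tf.nn.atrous_conv2d
# emits the SpaceToBatchND -> Conv2D(VALID) -> BatchToSpaceND subgraph that
# gets fused above.
#
#     import numpy as np
#     import tensorflow as tf
#
#     g = tf.Graph()
#     with g.as_default():
#         x = tf.compat.v1.placeholder(tf.float32, [1, 64, 64, 3], name="input")
#         w = tf.constant(np.random.rand(3, 3, 3, 8).astype(np.float32))
#         y = tf.nn.atrous_conv2d(x, w, rate=2, padding="SAME")
#     print([op.type for op in g.get_operations()
#            if "Batch" in op.type or op.type == "Conv2D"])
#     # Expected to include: SpaceToBatchND, Conv2D, BatchToSpaceND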