in python/tvm/relay/frontend/tensorflow_ops.py [0:0]
def _conv(opname):
    """Return a converter for TensorFlow 2-D convolution operators.

    Parameters
    ----------
    opname : str
        ``"conv"`` for a regular convolution, ``"conv_transpose"`` for a
        transposed convolution (whose inputs are ``(output_shape, filter,
        data)``), or a depthwise variant. Any value other than ``"conv"`` /
        ``"conv_transpose"`` takes the depthwise weight/channel handling; only
        ``"depthwise"`` additionally sets ``groups``.

    Returns
    -------
    function
        ``_impl(inputs, attr, params, mod)`` which builds the Relay
        convolution for one TF node. ``attr`` is rewritten in place into the
        Relay-style attributes handed to ``AttrCvt``.
    """

    def _to_nchw_order(values):
        # Reorder a 4-entry per-axis sequence (strides, dilations, shape) from
        # NHWC to NCHW order. Returns a new list; ``values`` is left untouched.
        return [values[ii] for ii in (0, 3, 1, 2)]

    def _impl(inputs, attr, params, mod):
        """Convert one TF convolution node to a Relay expression.

        When TF uses NHWC but the Relay op is emitted in NCHW, data and
        weights are transposed on the way in and the result is transposed
        back to NHWC. This happens for every NHWC ``conv_transpose``, and for
        every op when ``attr["_target_layout"]`` is ``"NCHW"``. ``params`` is
        not used here; it is part of the converter signature.

        Raises
        ------
        tvm.error.OpAttributeInvalid
            If ``data_format`` or ``padding`` has an unsupported value, or if
            ``explicit_paddings`` does not contain exactly 8 values.
        """
        attr["data_format"] = attr["data_format"].decode("utf-8")
        is_transpose = opname == "conv_transpose"
        # Set whenever the Relay op runs in NCHW while the TF node is NHWC:
        # the output is flipped back at the end, and attributes that are not
        # reordered here (explicit_paddings) are still read in NHWC order.
        flip_layout = False

        # conv_transpose inputs are (output_shape, filter, data).
        data = inputs[2] if is_transpose else inputs[0]
        weights = inputs[1]

        # Static output shape of conv_transpose (when known), used for SAME
        # padding. Held in a local so the caller's ``_output_shapes`` list is
        # never mutated when it is reordered below.
        out_shape_hint = None
        if is_transpose:
            output_shapes = attr["_output_shapes"]
            if len(output_shapes) > 0 and output_shapes[0]:
                out_shape_hint = output_shapes[0]

        if is_transpose and attr["data_format"] == "NHWC":
            # Transform to NCHW for TVM backend compatibility; flip_layout
            # makes the output flip back to NHWC.
            data = _op.transpose(data, axes=(0, 3, 1, 2))
            attr["strides"] = _to_nchw_order(attr["strides"])
            # Dilations must be reordered exactly like strides, otherwise the
            # NCHW branch below would pick up the W and C entries.
            if "dilations" in attr:
                attr["dilations"] = _to_nchw_order(attr["dilations"])
            if out_shape_hint is not None:
                out_shape_hint = _to_nchw_order(out_shape_hint)
            attr["data_format"] = "NCHW"
            flip_layout = True

        weights_shape = _infer_shape(weights, mod)
        input_shape = _infer_shape(data, mod)
        if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC":
            input_shape = _to_nchw_order(input_shape)
            data = _op.transpose(data, axes=(0, 3, 1, 2))
            attr["data_format"] = "NCHW"
            attr["strides"] = _to_nchw_order(attr["strides"])
            # Same reasoning as above: keep dilations in step with strides.
            if "dilations" in attr:
                attr["dilations"] = _to_nchw_order(attr["dilations"])
            flip_layout = True

        # NCHW layout requires the TF filter to be transposed; conv and
        # conv_transpose filters use one permutation, depthwise another.
        if attr["data_format"] == "NCHW":
            if opname in ("conv", "conv_transpose"):
                weight_axes = (3, 2, 0, 1)
            else:
                weight_axes = (2, 3, 0, 1)
            weights_shape = [weights_shape[ii] for ii in weight_axes]
            weights = _op.transpose(weights, axes=weight_axes)

        data_format = attr["data_format"]
        if data_format == "NHWC":
            h_axis, w_axis, c_axis = 1, 2, 3
            kernel_h, kernel_w, _, depth_mult = weights_shape
            if opname == "conv":
                channels = weights_shape[3]
            elif is_transpose:
                channels = weights_shape[2]
            else:
                channels = input_shape[c_axis] * depth_mult
        elif data_format == "NCHW":
            h_axis, w_axis, c_axis = 2, 3, 1
            _, depth_mult, kernel_h, kernel_w = weights_shape
            if opname == "conv":
                channels = weights_shape[0]
            elif is_transpose:
                channels = weights_shape[1]
            else:
                channels = input_shape[c_axis] * depth_mult
                # NOTE(review): presumably compensates for a negative
                # (unknown) channel dimension from shape inference — confirm.
                if channels < 0:
                    channels *= -1
        else:
            msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid."
            raise tvm.error.OpAttributeInvalid(msg.format(data_format))

        attr["kernel_shape"] = (kernel_h, kernel_w)
        attr["channels"] = channels
        # Keep only the spatial (H, W) entries of the per-axis attributes.
        if "dilations" in attr:
            attr["dilations"] = (attr["dilations"][h_axis], attr["dilations"][w_axis])
        attr["strides"] = (attr["strides"][h_axis], attr["strides"][w_axis])

        if opname == "depthwise":
            attr["groups"] = input_shape[c_axis]

        # Fix padding
        padding = attr["padding"].decode("utf-8")
        if padding == "VALID":
            attr["padding"] = [0, 0]
        elif padding == "SAME":
            stride_h, stride_w = attr["strides"]
            # A transposed conv is padded relative to its (static) output shape.
            pdata_shape = out_shape_hint if out_shape_hint is not None else input_shape
            in_h = pdata_shape[h_axis]
            in_w = pdata_shape[w_axis]
            # Treat a missing dilations attribute as no dilation, matching the
            # "dilations" in attr guards above instead of raising KeyError.
            dilation_h, dilation_w = attr["dilations"] if "dilations" in attr else (1, 1)
            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
            attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
        elif padding == "EXPLICIT":
            # explicit_paddings holds (before, after) pairs per axis in the
            # node's original TF data_format, which is NHWC whenever
            # flip_layout is set. Relay wants [top, left, bottom, right].
            paddings = attr["explicit_paddings"]
            if len(paddings) != 8:
                msg = (
                    'Attribute "explicit_paddings" of operator Conv must have 8 values, '
                    "got {}."
                )
                raise tvm.error.OpAttributeInvalid(msg.format(len(paddings)))
            if flip_layout or data_format == "NHWC":
                attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]]
            else:
                attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]]
        else:
            msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
            raise tvm.error.OpAttributeInvalid(msg.format(padding))

        if "kernel_layout" not in attr:
            if opname in ["conv", "conv_transpose"]:
                attr["kernel_layout"] = "HWIO" if data_format == "NHWC" else "OIHW"
            else:
                attr["kernel_layout"] = "HWOI" if data_format == "NHWC" else "OIHW"

        # Ignore the new attributes from TF2.0, for now.
        out = AttrCvt(
            op_name=_dimension_picker("conv", surfix="_transpose" if is_transpose else ""),
            ignores=["explicit_paddings"],
            transforms={
                "kernel_shape": "kernel_size",
                "data_format": "data_layout",
                "dilations": ("dilation", (0, 0)),
                "group": ("groups", 1),
            },
            custom_check=_dimension_constraint(),
        )([data, weights], attr)

        if flip_layout:
            # Restore the TF-visible NHWC layout.
            out = _op.transpose(out, axes=(0, 2, 3, 1))
        return out

    return _impl