def _alter_conv2d_layout()

in python/tvm/topi/x86/conv2d_alter_op.py


def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
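    """Alter the layout of a conv2d op for x86 targets.

    Queries the AutoTVM dispatch context for the implementation chosen for this
    workload and, depending on the TOPI template, rewrites the convolution to
    the blocked NCHWc / OIHWio layouts (pre-transforming the kernel for the
    winograd and int8 variants). Returns the rewritten Relay call, or None to
    keep the original conv2d unchanged.
    """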
    target = tvm.target.Target.current(allow_none=False)
    dispatch_ctx = autotvm.task.DispatchContext.current
    new_attrs = {k: attrs[k] for k in attrs.keys()}

    # Parse the attributes.
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs["data_layout"]
    kernel_layout = attrs["kernel_layout"]
    data_tensor, kernel_tensor = tinfos
    data_dtype = data_tensor.dtype
    kernel_dtype = kernel_tensor.dtype
    out_dtype = out_type.dtype

    if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest):
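        # ApplyGraphBest tracks the node being queried by itself, so query with
        # workload=None and read the matching workload back from the config.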
        cfg = dispatch_ctx.query(target, None)
        workload = cfg.workload
    else:
        impl, outs = relay.backend.compile_engine.select_implementation(
            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
        )
        workload = autotvm.task.get_workload(outs)
        if workload is None:
            # The best implementation is not an AutoTVM template;
            # it may come from the auto-scheduler.
            if impl.name.find("winograd") != -1:
                if dilation != (1, 1):
                    logger.warning("Weight pre-transform for dilated convolution is not supported.")
                    return None

                assert data_layout == "NHWC" and kernel_layout == "HWIO"
                N, H, W, CI = get_const_tuple(data_tensor.shape)
                KH, KW, _, CO = get_const_tuple(kernel_tensor.shape)

                # Pre-compute weight transformation in winograd
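                # tile_size 4 selects a 4x4 output tile (winograd F(4x4, 3x3) for
                # 3x3 kernels), giving alpha = tile_size + 3 - 1 = 6 below.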
                tile_size = 4
                # HWIO -> OIHW
                kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1])
                # alpha, alpha, CO, CI
                weight = relay.nn.contrib_conv2d_winograd_weight_transform(
                    kernel_transform, tile_size=tile_size
                )
                new_attrs["tile_size"] = tile_size
                new_attrs["channels"] = CO
                return relay.nn.contrib_conv2d_winograd_without_weight_transform(
                    inputs[0], weight, **new_attrs
                )
            return None

        cfg = dispatch_ctx.query(target, workload)

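    # The first element of an AutoTVM workload tuple is the TOPI template name.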
    topi_tmpl = workload[0]

    if topi_tmpl == "conv2d_NCHWc.x86":
        # we only convert conv2d_NCHW to conv2d_NCHWc for x86
        if data_layout == "NCHW" and kernel_layout == "OIHW":
            if cfg.is_fallback:
                _get_default_config(
                    cfg,
                    data_tensor,
                    kernel_tensor,
                    strides,
                    padding,
                    dilation,
                    out_dtype,
                    False,
                    data_layout,
                )
            batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
            out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
            ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]

            # update new attrs
            new_attrs["channels"] = out_channel
            new_attrs["data_layout"] = "NCHW%dc" % ic_bn
            # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
            new_attrs["kernel_layout"] = "OIHW%di%do" % (ic_bn, oc_bn)
            new_attrs["out_layout"] = "NCHW%dc" % oc_bn

            # Store altered operator's config
            new_data = te.placeholder(
                (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
            )
            new_kernel = te.placeholder(
                (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn, oc_bn),
                dtype=kernel_tensor.dtype,
            )
            new_workload = autotvm.task.args_to_workload(
                [
                    new_data,
                    new_kernel,
                    strides,
                    padding,
                    dilation,
                    new_attrs["data_layout"],
                    new_attrs["out_layout"],
                    out_dtype,
                ],
                topi_tmpl,
            )
            dispatch_ctx.update(target, new_workload, cfg)
        else:
            assert _NCHWc_matcher.match(data_layout)
            assert _OIHWio_matcher.match(kernel_layout)
        return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

    if topi_tmpl == "conv2d_NCHWc_int8.x86":
        # TODO(@icemelon9, @anijain2305): Need to support data layout NHWC with kernel layout HWIO
        assert data_layout == "NCHW" and kernel_layout == "OIHW"
        if cfg.is_fallback:
            _get_default_config_int8(
                cfg,
                data_tensor,
                kernel_tensor,
                strides,
                padding,
                dilation,
                out_dtype,
                False,
                data_layout,
            )

        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
        out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
        ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
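        # x86 int8 schedules reduce four int8 values into one int32 lane
        # (pmaddubsw/pmaddwd or VNNI style), so the innermost kernel axis is blocked by 4.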
        n_elems = 4

        # Convert the kernel from 4-D OIHW to the 7-D blocked layout
        # (OC//oc_bn, IC//ic_bn, KH, KW, ic_bn//n_elems, oc_bn, n_elems)
        # expected by conv2d_NCHWc_int8.
        data_expr, kernel_expr = inputs
        kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
        kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel // oc_bn, oc_bn))
        kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
        kernel_OHWoIi = relay.reshape(
            kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn)
        )
        kernel_OHWoIie = relay.reshape(
            kernel_OHWoIi,
            (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn // n_elems, n_elems),
        )
        kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))

        # update new attrs
        new_attrs["channels"] = out_channel
        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
        new_attrs["out_layout"] = "NCHW%dc" % oc_bn

        # Store altered operator's config.
        new_data = te.placeholder(
            (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
        )
        new_kernel = te.placeholder(
            (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
            dtype=kernel_dtype,
        )
        new_workload = autotvm.task.args_to_workload(
            [
                new_data,
                new_kernel,
                strides,
                padding,
                dilation,
                new_attrs["data_layout"],
                new_attrs["out_layout"],
                out_dtype,
            ],
            topi_tmpl,
        )
        dispatch_ctx.update(target, new_workload, cfg)

        return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs)

    if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
        if data_layout == "NCHW" and kernel_layout == "OIHW":
            if cfg.is_fallback:
                _get_default_config(
                    cfg,
                    data_tensor,
                    kernel_tensor,
                    strides,
                    padding,
                    dilation,
                    out_dtype,
                    True,
                    data_layout,
                )

            batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
            out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
            ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
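            # The depthwise NCHWc schedule only supports a channel multiplier of 1.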
            assert channel_multiplier == 1

            # update new attrs
            new_attrs["channels"] = out_channel
            new_attrs["data_layout"] = "NCHW%dc" % ic_bn
            new_attrs["kernel_layout"] = "OIHW1i%do" % oc_bn
            new_attrs["out_layout"] = "NCHW%dc" % oc_bn

            # Store altered operator's config.
            new_data = te.placeholder(
                (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
            )
            new_kernel = te.placeholder(
                (out_channel // oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype
            )
            new_workload = autotvm.task.args_to_workload(
                [
                    new_data,
                    new_kernel,
                    strides,
                    padding,
                    dilation,
                    new_attrs["data_layout"],
                    new_attrs["out_layout"],
                    out_dtype,
                ],
                topi_tmpl,
            )
            dispatch_ctx.update(target, new_workload, cfg)
        else:
            assert _NCHWc_matcher.match(data_layout)
            assert _OIHWio_matcher.match(kernel_layout)
        return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)

    return None
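

# ---------------------------------------------------------------------------
# Usage sketch (not part of conv2d_alter_op.py). This hook is registered for
# CPU targets and runs inside the AlterOpLayout pass, so a plain NCHW/OIHW
# conv2d compiled for an x86 target comes out as contrib_conv2d_NCHWc with
# blocked layouts such as "NCHW16c" / "OIHW16i16o" (the exact factors come
# from the AutoTVM config, or from the fallback defaults when no tuning log
# is loaded). The shapes and the "llvm -mcpu=core-avx2" target string below
# are illustrative assumptions, not values taken from this file, and the
# snippet targets the Relay-era TVM API.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NCHW", kernel_layout="OIHW",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

seq = tvm.transform.Sequential(
    [relay.transform.InferType(), relay.transform.AlterOpLayout()]
)
with tvm.target.Target("llvm -mcpu=core-avx2"):
    # AlterOpLayout is an opt_level 3 pass, so raise the PassContext level.
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
print(mod)  # the conv2d now appears as nn.contrib_conv2d_NCHWc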