def merge_transform_to_mxnet_model()

in apps/deploy/resnet_export.py [0:0]


def merge_transform_to_mxnet_model(mod, *, mean=(123., 117., 104.),
                                   std=(58.395, 57.12, 57.37),
                                   data_shape=(224, 224, 3)):
    """Add image pre-processing logic in front of the model's first conv2d.

    The rewritten ``main`` takes one unbatched HWC float32 input ``data``
    (``data_shape``, default ``(224, 224, 3)``). Before the first conv2d it:

    1. adds a leading batch axis -> ``(1, H, W, C)``;
    2. subtracts ``mean`` and then divides by ``std``, broadcast along the
       last (channel) axis;
    3. transposes ``(0, 3, 1, 2)`` -> ``(1, C, H, W)``.

    The first conv2d of the original ``main`` is rebuilt with the same
    weights and copied attributes but fed from this chain. Every parameter
    of the original ``main`` except the first (the original input) is kept
    in order.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` body, after A-normal-form conversion, starts
        with a Let binding of an ``nn.conv2d(input, weight)`` call.
        ``mod["main"]`` is replaced in place.
    mean, std : sequence of float
        Per-channel values subtracted from / divided into the input. The
        defaults are presumably ImageNet RGB statistics on a 0-255 scale;
        TODO confirm against the model's training pipeline.
    data_shape : tuple of int
        Static shape of the new ``data`` input variable.

    Returns
    -------
    tvm.IRModule
        The same ``mod`` object, with ``main`` replaced.

    Raises
    ------
    ValueError
        If the first binding of ``main`` is not an ``nn.conv2d`` call.
    """
    # Built as float64 arrays and cast to float32 in the graph, matching
    # the input dtype.
    sub_data = relay.Constant(
        tvm.nd.array(np.array(mean, dtype="float64"))).astype("float32")
    divide_data = relay.Constant(
        tvm.nd.array(np.array(std, dtype="float64"))).astype("float32")

    data = relay.var("data", relay.TensorType(data_shape, "float32"))

    simple_net = relay.expand_dims(data, axis=0, num_newaxis=1)
    # TODO: Relay does not support dynamic input shapes here yet. Once it
    # does, add a resize step (e.g. relay.image.resize to 224x224, "NHWC",
    # "bilinear", "align_corners") so arbitrarily sized images are accepted.
    simple_net = relay.subtract(simple_net, sub_data)
    simple_net = relay.divide(simple_net, divide_data)
    simple_net = relay.transpose(simple_net, (0, 3, 1, 2))

    # Merge the transform into the pretrained model network.
    entry = mod["main"]
    anf = run_opt_pass(entry.body, transform.ToANormalForm())
    # In A-normal form the body is a chain of Let bindings; the first bound
    # value must be the conv2d that consumes the original input.
    if not isinstance(anf, relay.expr.Let):
        raise ValueError(
            "merge_transform_to_mxnet_model expects main's body to start "
            "with a Let binding of an nn.conv2d call")
    call = anf.value
    if (not isinstance(call, relay.expr.Call)
            or getattr(call.op, "name", None) != "nn.conv2d"):
        raise ValueError(
            "merge_transform_to_mxnet_model expects the first binding of "
            "main to be an nn.conv2d call")
    _, weights = call.args

    # NOTE(review): layout attributes are not copied, so the rebuilt conv
    # uses the default NCHW/OIHW layouts. Confirm this holds for new models.
    first_op = op.nn.conv2d(
        simple_net,
        weights,
        strides=call.attrs.strides,
        padding=call.attrs.padding,
        dilation=call.attrs.dilation,
        groups=call.attrs.groups,
        channels=call.attrs.channels,
        kernel_size=call.attrs.kernel_size,
        out_dtype=call.attrs.out_dtype)
    net = relay.expr.Let(anf.var, first_op, anf.body)

    # Replace the original first parameter (the batched (1, 224, 224, 3)
    # input) with the new unbatched ``data`` var. Keep all other params so
    # the new function has no free variables.
    new_params = [data] + list(entry.params)[1:]

    func = tvm.relay.Function(new_params,
                              net,
                              entry.ret_type,
                              entry.type_params,
                              entry.attrs)
    func = run_opt_pass(func, transform.ToGraphNormalForm())

    mod['main'] = func
    return mod