def get_nodes_json_string()

in python/tvm/contrib/mrvl.py [0:0]


def get_nodes_json_string(graph_json):
    """This takes the graph_json string from MrvlJSONSerializer and adds / modifies
    the json string to a form suitable for the Marvell Backend.

    The conversion performs, in order:

    * ``kernel`` nodes become ``tvm_op`` nodes named ``tvmgen_mrvl_main_<k>`` (``k`` is
      0-based) and receive a 1-based ``mrvl_nodes_idx`` attribute plus empty
      ``kernel/bias/beta/gamma/var/mean_const`` placeholder attributes.
    * ``input`` nodes get ``layer_name = ["input"]``, empty ``inputs``, an ``input_id``
      parsed from the text after the last ``_i`` in their name, one level of
      ``dtype``/``shape`` nesting removed, and ``data_layout`` ``NC`` (2-D shape) or
      ``NCHW`` (otherwise). The graph-level ``batch_size`` is the leading dimension of the
      last input's shape.
    * ``tvm_op`` inputs, ``arg_nodes`` and ``heads`` are renumbered: an input keeps its
      parsed ``input_id``, and the ``tvm_op`` with Marvell index ``k`` becomes
      ``k + num_inputs - 1`` (presumably inputs are named ``..._i0`` .. ``..._i<n-1>``
      so ops follow them -- TODO confirm against the serializer). References to constant
      nodes are dropped, and the constant nodes themselves are removed.
    * ``num_subgraph_inputs``/``num_subgraph_outputs`` are added, ``node_row_ptr`` is
      removed, and graph-level ``attrs`` is replaced by ``shape``/``dltype`` lists.

    Parameters
    ----------
    graph_json: String
        This is the graph_json string from the MrvlJSONSerializer

    Returns
    -------
    nodes_json_string: string
        This returns the nodes_json string which can be accepted by the Marvell backend.
    """
    dictionary = json.loads(graph_json)
    nodes = dictionary["nodes"]

    # Pass 1: rename "kernel" nodes to "tvm_op" and normalize "input" nodes.
    const_placeholder_keys = (
        "kernel_const",
        "bias_const",
        "beta_const",
        "gamma_const",
        "var_const",
        "mean_const",
    )
    mrvl_idx = 1
    num_in = 0
    for node in nodes:
        if node["op"] == "kernel":
            attrs = node["attrs"]
            node["op"] = "tvm_op"
            attrs["mrvl_nodes_idx"] = [mrvl_idx]
            for key in const_placeholder_keys:
                attrs[key] = {}
            node["name"] = f"tvmgen_mrvl_main_{mrvl_idx - 1}"
            mrvl_idx += 1
        elif node["op"] == "input":
            attrs = node["attrs"]
            attrs["layer_name"] = ["input"]
            node["inputs"] = []
            # The input index is whatever follows the last "_i" in the node name.
            node["input_id"] = [node["name"].split("_i")[-1]]
            attrs["dtype"] = attrs["dtype"][0]
            attrs["shape"] = attrs["shape"][0]
            attrs["data_layout"] = ["NC"] if len(attrs["shape"][0]) == 2 else ["NCHW"]
            # Infer the batch size from the input shape; with several inputs the
            # last one processed wins.
            dictionary["batch_size"] = f"{attrs['shape'][0][0]}"
            num_in += 1

    # Pass 2: keep only references to producing tvm_op / input nodes (constant
    # inputs are dropped) and renumber them into the Marvell index space.
    # Only other nodes' "op", "attrs" and "input_id" are read here, none of which
    # this pass modifies, so rewiring in a single pass is safe.
    for node in nodes:
        if node["op"] != "tvm_op":
            continue
        node_prev = []
        for prev in node["inputs"]:
            src = nodes[prev[0]]
            if src["op"] == "tvm_op":
                node_prev.append([src["attrs"]["mrvl_nodes_idx"][0] + num_in - 1, 0, 0])
            elif src["op"] == "input":
                node_prev.append([int(src["input_id"][0]), 0, 0])
        # Delete and re-insert (rather than assign) so "inputs" ends up as the
        # last key of the node, keeping the emitted JSON key order stable.
        del node["inputs"]
        node["inputs"] = node_prev

    # Remove unneeded fields (tolerate a graph that does not carry it).
    dictionary.pop("node_row_ptr", None)

    # Patch up arg_nodes to remove references to constant inputs.
    arg_nodes = []
    for node_idx in dictionary["arg_nodes"]:
        node = nodes[node_idx]
        if node["op"] == "const":
            continue
        if node["op"] == "input":
            # Reuse the index parsed in pass 1 so inputs are numbered identically
            # everywhere they are referenced.
            arg_nodes.append(int(node["input_id"][0]))
        else:
            # NOTE(review): unlike inputs/heads, no "num_in - 1" offset is applied
            # here; kept as-is from the original -- confirm whether intended.
            arg_nodes.append(node["attrs"]["mrvl_nodes_idx"][0])
    dictionary["arg_nodes"] = arg_nodes

    # Add additional data required by the runtime such as number of inputs
    # and number of outputs to the subgraph.
    dictionary["num_subgraph_inputs"] = str(len(arg_nodes))

    # Only tvm_op heads are outputs of the subgraph; constant and input heads
    # (which have no Marvell index) are skipped.
    heads = []
    for head in dictionary["heads"]:
        node = nodes[head[0]]
        if node["op"] == "tvm_op":
            heads.append([node["attrs"]["mrvl_nodes_idx"][0] + num_in - 1, 0, 0])
    dictionary["heads"] = heads
    dictionary["num_subgraph_outputs"] = str(len(heads))

    # Delete the constant nodes, these are not required for the constants file.
    nodes = [node for node in nodes if "const" not in node["op"]]
    dictionary["nodes"] = nodes

    # Remove un-needed array nesting from non-input node attributes.
    keep_nested = {
        "num_inputs",
        "num_outputs",
        "mrvl_nodes_idx",
        "mean_const",
        "var_const",
        "beta_const",
        "kernel_const",
        "bias_const",
        "gamma_const",
        "input_const",
    }
    for node in nodes:
        if node["op"] == "input":
            continue
        attrs = node["attrs"]
        # Only existing keys are reassigned, so iterating the dict is safe.
        for key in attrs:
            if key not in keep_nested:
                attrs[key] = attrs[key][0]

    # Now create the dltype and dlshape attributes.
    dictionary["attrs"] = {
        "shape": ["list_shape", [node["attrs"]["shape"][0] for node in nodes]],
        "dltype": ["list_str", [node["attrs"]["dtype"][0] for node in nodes]],
    }

    return json.dumps(dictionary)