def convert()

in onnxconverter_common/onnx2py.py [0:0]


def convert(model, out_path):
    """Convert an ONNX model into a standalone Python script that rebuilds it.

    The script is written to ``<out_path>.py`` (a trailing ``.py`` on
    ``out_path`` is stripped first). The directory ``out_path`` is cleared
    if it already exists and is registered as the constant-data directory.
    Presumably ``convert_field`` writes externalized tensor data there (TODO
    confirm). When that directory exists after conversion, the script
    defines ``DATA_DIR`` relative to its own location.

    Args:
        model: the ONNX object to convert; handed to ``convert_field``.
        out_path: destination path for the generated script, with or
            without a ``.py`` suffix.

    Returns:
        The traced object returned by ``convert_field(model)``. Its ``repr``
        is the expression assigned to ``model`` in the generated script.

    Raises:
        MissingHandlerException: if any types were recorded in the
            module-level ``needed_types`` set during conversion. The script
            is still written first, so the partial output can be inspected.
    """
    global needed_types, const_dir, const_counter, DATA_DIR_TRACED
    needed_types = set()
    if out_path.endswith(".py"):
        out_path = out_path[:-3]
    # Start from an empty data directory so leftovers from a previous run
    # do not end up next to the newly generated script.
    if os.path.exists(out_path):
        clear_directory(out_path)
    const_dir = out_path
    const_dir_name = os.path.basename(out_path)
    const_counter = 0
    # Reset the usage counters so that only helpers actually used by *this*
    # model get their source embedded in the generated script.
    TracingObject.reset_cnt(clear_field_traced)
    TracingObject.reset_cnt(make_external_tensor_traced)
    DATA_DIR_TRACED = TracingObject("DATA_DIR", const_dir)

    model_trace = convert_field(model)

    uses_external_data = bool(TracingObject.get_cnt(make_external_tensor_traced))

    # Build the complete import list before terminating the line; appending
    # ", external_data_helper" after a newline produced invalid Python.
    onnx_imports = "helper, numpy_helper, TensorProto"
    if uses_external_data:
        onnx_imports += ", external_data_helper"

    parts = [
        FILE_HEADER % os.path.basename(out_path) + "\n",
        "\nfrom onnx import %s\n" % onnx_imports,
        "import onnx\n",
        "import numpy as np\n",
        "import sys\n",
    ]
    if os.path.exists(const_dir):
        parts.append("import os\n")
        parts.append("\nDATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), %r)\n" % const_dir_name)
    if TracingObject.get_cnt(clear_field_traced):
        parts.append("\n" + inspect.getsource(clear_field))
    parts.append("\n" + inspect.getsource(order_repeated_field))
    if uses_external_data:
        parts.append("\n" + inspect.getsource(make_external_tensor))
    parts.append("\n" + inspect.getsource(make_node))
    parts.append("\n" + inspect.getsource(make_graph))
    parts.append("\nmodel = " + repr(model_trace) + "\n")
    parts.append("\nif __name__ == '__main__' and len(sys.argv) == 2:\n")
    parts.append("    _, out_path = sys.argv\n")
    if uses_external_data:
        # NOTE(review): presumably the model is serialized by hand so that
        # onnx.save does not re-process tensors that already point into
        # DATA_DIR -- confirm against onnx.save_model's external-data logic.
        parts.append("    with open(out_path, 'wb') as f:\n")
        parts.append("        f.write(model.SerializeToString())\n")
    else:
        parts.append("    onnx.save(model, out_path)\n")

    with open(out_path + ".py", "wt", encoding='utf8') as script_file:
        script_file.write("".join(parts))

    if needed_types:
        # Order by str() so the message is stable across runs; the entries
        # are not guaranteed to be mutually orderable.
        raise MissingHandlerException("Missing handler for types: %s" % sorted(needed_types, key=str))
    return model_trace