def export_onnx()

in src/export.py [0:0]


def _write_dataset_info(train_config, out_dir):
    """Write dataset/view parameters as ``name = value`` lines to ``{out_dir}/dataset_info.txt``."""
    dataset_info = train_config.dataset_info
    view = dataset_info.view
    entries = [
        ("view_cell_center", view.view_cell_center),
        ("view_cell_size", view.view_cell_size),
        ("depth_range", dataset_info.depth_range),
        ("fov", view.fov),
        ("focal", view.focal),
        ("camera_scale", view.camera_scale),
        ("max_depth", dataset_info.depth_max),
    ]
    with open(f"{out_dir}/dataset_info.txt", "w") as f:
        # `!s` forces str(); plain format() would render e.g. 0-d tensors or
        # numpy float32 scalars differently from the original str() output.
        f.writelines(f"{name} = {value!s}\n" for name, value in entries)


def export_onnx(train_config=None, out_dir=None):
    """Export every model in ``train_config.models`` to ONNX.

    Args:
        train_config: A prepared TrainConfig. When None, one is built from
            ``Config.init()`` and weights are loaded: the checkpoint named by
            ``config.checkPointName`` if set (non-None, non-empty), otherwise
            the latest weights.
        out_dir: Output directory. When None, ``config.logDir`` is used as-is;
            otherwise the directory is created if missing.

    Side effects:
        Writes ``dataset_info.txt``, one ``model{idx}.onnx`` per model, and an
        (empty) ``feature_sample.txt`` into ``out_dir``. Imports the training
        dataset on ``train_config`` if it has not been imported yet.
    """
    print("exporting onnx started ... ")

    if train_config is None:
        config = Config.init()
        train_config = TrainConfig()
        train_config.initialize(config)
        # An explicitly named checkpoint wins; otherwise use the newest weights.
        if config.checkPointName is not None and config.checkPointName != "":
            train_config.load_specific_weights(config.checkPointName)
        else:
            train_config.load_latest_weights()
    else:
        config = train_config.config_file

    if out_dir is None:
        out_dir = config.logDir
    else:
        os.makedirs(out_dir, exist_ok=True)

    _write_dataset_info(train_config, out_dir)

    if train_config.train_dataset is None:
        train_config.import_train_dataset()

    img_samples = create_sample_wrapper(train_config.train_dataset[0], train_config, True)
    input_names = ["input_1"]
    output_names = ["output1"]

    # Only one batch is needed: it supplies example inputs for tracing.
    first_batch = next(iter(img_samples.batches(128)), None)
    if first_batch is None:
        return

    _, inference_dicts = train_config.inference(first_batch)
    # Assumes inference_dicts[i] holds the example input for train_config.models[i].
    input_feature_batches = [d['InputFeatureBatch'] for d in inference_dicts]

    # Drop the remaining inference outputs so only the feature batches stay alive.
    del inference_dicts
    torch.cuda.empty_cache()

    if train_config.models:
        # NOTE(review): the original code (re)created this file once per model
        # but never wrote to it and never closed the handle. It is still
        # created (empty) so consumers expecting it keep working.
        # TODO decide which feature sample data should be written here.
        with open(f"{out_dir}/feature_sample.txt", "w"):
            pass

    for model_idx, model in enumerate(train_config.models):
        torch.onnx.export(
            model,
            input_feature_batches[model_idx],
            f"{out_dir}/model{model_idx}.onnx",
            verbose=True,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            # Axis 0 of input and output is dynamic; "-1" is just the symbolic dim name.
            dynamic_axes={'input_1': {0: '-1'}, 'output1': {0: '-1'}},
        )

        # Release this model's example batch before exporting the next one.
        input_feature_batches[model_idx] = None
        torch.cuda.empty_cache()