def main()

in src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py


# Module-level imports assumed from the top of onnx_export.py; the argparse
# `parser` (with the flags referenced below) is also defined there and is not
# reproduced in this excerpt.
import numpy as np
import onnx
import torch

import geffnet


def main():
    args = parser.parse_args()

    args.pretrained = True
    if args.checkpoint:
        args.pretrained = False

    print("==> Creating PyTorch {} model".format(args.model))
    # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
    # for models using SAME padding
    model = geffnet.create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        exportable=True)

    model.eval()

    example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True)

    # Run the model once before the export trace; this sets the padding for models with
    # Conv2dSameExport. As a result, the padding for such models (most models with the
    # tf_ prefix) is fixed for the input img_size specified in this script.
    # Opset >= 11 should allow for dynamic padding, but I cannot get it to work due to
    # issues in the tracing of the dynamic padding, or errors attempting to export the
    # model after jit scripting it (an approach that should work). Perhaps in a future
    # PyTorch or ONNX version...
    model(example_input)

    print("==> Exporting model to ONNX format at '{}'".format(args.output))
    input_names = ["input0"]
    output_names = ["output0"]
    dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
    if args.dynamic_size:
        dynamic_axes['input0'][2] = 'height'
        dynamic_axes['input0'][3] = 'width'
    if args.aten_fallback:
        export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    else:
        export_type = torch.onnx.OperatorExportTypes.ONNX

    # NOTE: the private torch.onnx._export is used here (rather than torch.onnx.export),
    # presumably because it returns the traced model output, which is reused for the
    # numeric comparison below.
    torch_out = torch.onnx._export(
        model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
        output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
        opset_version=args.opset, operator_export_type=export_type)

    print("==> Loading and checking exported model from '{}'".format(args.output))
    onnx_model = onnx.load(args.output)
    onnx.checker.check_model(onnx_model)  # assuming throw on error
    print("==> Passed")

    if args.keep_init and args.aten_fallback:
        import caffe2.python.onnx.backend as onnx_caffe2
        # Caffe2 loading only works properly in newer PyTorch/ONNX combos when
        # keep_initializers_as_inputs and aten_fallback are set to True.
        print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output))
        caffe2_backend = onnx_caffe2.prepare(onnx_model)
        B = {onnx_model.graph.input[0].name: example_input.data.numpy()}
        c2_out = caffe2_backend.run(B)[0]
        np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5)
        print("==> Passed")