in src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py [0:0]
def main():
    """Export a geffnet model to ONNX, validate it, and optionally cross-check with Caffe2.

    Configuration is read from the module-level ``parser`` (argparse). Steps:

    1. Build ``args.model`` via ``geffnet.create_model`` with ``exportable=True``,
       using pretrained weights unless ``args.checkpoint`` is given.
    2. Run one forward pass on a random input of shape
       ``(args.batch_size, 3, img_size, img_size)`` (``img_size`` defaults to 224).
       This fixes SAME padding for Conv2dSameExport layers, and its output is kept
       as the PyTorch reference.
    3. Export to ``args.output`` with ``torch.onnx.export`` and validate the file
       with ``onnx.checker.check_model``.
    4. If both ``args.keep_init`` and ``args.aten_fallback`` are set, run the
       exported model in the Caffe2 backend. Assert that its first output matches
       the PyTorch reference to 5 decimal places.
    """
    args = parser.parse_args()
    # Pretrained weights are used only when no explicit checkpoint is supplied.
    args.pretrained = not args.checkpoint
    print("==> Creating PyTorch {} model".format(args.model))
    # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
    # for models using SAME padding
    model = geffnet.create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        exportable=True)
    model.eval()

    img_size = args.img_size or 224
    example_input = torch.randn((args.batch_size, 3, img_size, img_size), requires_grad=True)

    # Run model once before export trace, sets padding for models with Conv2dSameExport. This means
    # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
    # the input img_size specified in this script.
    # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
    # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
    # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
    # The output is kept as the PyTorch reference for the backend comparison below. The public
    # torch.onnx.export does not return it, unlike the private torch.onnx._export used previously.
    torch_out = model(example_input)

    print("==> Exporting model to ONNX format at '{}'".format(args.output))
    input_names = ["input0"]
    output_names = ["output0"]
    dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
    if args.dynamic_size:
        dynamic_axes['input0'][2] = 'height'
        dynamic_axes['input0'][3] = 'width'
    if args.aten_fallback:
        export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    else:
        export_type = torch.onnx.OperatorExportTypes.ONNX

    torch.onnx.export(
        model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
        output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
        opset_version=args.opset, operator_export_type=export_type)

    print("==> Loading and checking exported model from '{}'".format(args.output))
    onnx_model = onnx.load(args.output)
    onnx.checker.check_model(onnx_model)  # assuming throw on error
    print("==> Passed")

    if args.keep_init and args.aten_fallback:
        import caffe2.python.onnx.backend as onnx_caffe2
        # Caffe2 loading only works properly in newer PyTorch/ONNX combos when
        # keep_initializers_as_inputs and aten_fallback are set to True.
        print("==> Loading model into Caffe2 backend and comparing forward pass.")
        caffe2_backend = onnx_caffe2.prepare(onnx_model)
        # Feed the same tensor used for the PyTorch reference run. The original code
        # referenced an undefined name `x` here, which raised NameError.
        B = {onnx_model.graph.input[0].name: example_input.detach().numpy()}
        c2_out = caffe2_backend.run(B)[0]
        np.testing.assert_almost_equal(torch_out.detach().numpy(), c2_out, decimal=5)
        print("==> Passed")