in PyTorchClassification/onnx_export.py [0:0]
def main():
    # Assumes the module-level imports of onnx_export.py: argparse, numpy as np,
    # torch, torch.nn as nn, plus ClassificationModel, ModelType, ModelWrapper,
    # and data_loader from this repository.
    parser = argparse.ArgumentParser(description='PyTorch model ONNX export.')
    parser.add_argument('--model_type', default=ModelType.resnext101,
                        metavar='ARCH', type=ModelType.from_string, choices=list(ModelType),
                        help='model architecture: ' + ' | '.join([m.name for m in ModelType]) +
                             ' (default: resnext101)')
    parser.add_argument('--image_size', default=[224], nargs='+',
                        type=int, metavar='RESOLUTION',
                        help='The side length of the CNN input image (default: 224). '
                             'For ensembles, provide one resolution for each network.')
    parser.add_argument('--weights_path', default=None, nargs='+',
                        type=str, metavar='PATH',
                        help='Path to a checkpoint to load from. Multiple paths can be '
                             'given when combining several checkpoints into an ensemble.')
    parser.add_argument('--num_classes', default=5089,
                        type=int, metavar='NUM_CLASSES',
                        help='Number of output elements of the network.')
    parser.add_argument('--output_prefix', default='onnx_export',
                        type=str, metavar='OUTPUT',
                        help='Prefix for all output files. A full path can precede the '
                             'prefix, e.g. /path/to/output_prefix')
    parser.add_argument('--class_mask', default=None, metavar='CLASS_MASK.txt',
                        type=str,
                        help='Path to a text file containing a list of 0s and 1s, one entry '
                             'per predicted class. If the i-th entry is 0, that class is '
                             'removed from the exported ONNX model.')
    args = parser.parse_args()
    # Build the classification model from the given checkpoint(s).
    model = ClassificationModel(args.weights_path, args.image_size, True, args.model_type)

    # Optionally restrict the classifier to a subset of classes: rebuild the final
    # fully connected layer with only the weight rows and bias entries of the
    # selected classes.
    if args.class_mask is not None:
        selected_classes = np.loadtxt(args.class_mask).astype(bool)
        old_fc = model.state_dict()['model.last_linear.weight'].detach().cpu().numpy()[selected_classes]
        old_bias = model.state_dict()['model.last_linear.bias'].detach().cpu().numpy()[selected_classes]
        model.model.last_linear = nn.Linear(model.model.last_linear.in_features,
                                            int(np.sum(selected_classes))).cuda()
        model.state_dict()['model.last_linear.weight'].data.copy_(torch.from_numpy(old_fc))
        model.state_dict()['model.last_linear.bias'].data.copy_(torch.from_numpy(old_bias))
    model.eval()
    # Recursively disable ceil_mode on all pooling layers before export;
    # older ONNX opsets do not support ceil_mode=True.
    def disable_ceil_mode(m):
        for mc in list(m.children()):
            if type(mc) in [torch.nn.MaxPool2d, torch.nn.AvgPool2d]:
                mc.ceil_mode = False
            disable_ceil_mode(mc)
    disable_ceil_mode(model)
    # Trace the wrapped model with a dummy input at the largest requested
    # resolution and export it to ONNX.
    model_wrapper = ModelWrapper(model)
    dummy_input = torch.randn(1, 3, max(args.image_size), max(args.image_size)).cuda()
    print(model_wrapper(dummy_input))
    torch.onnx.export(model_wrapper, dummy_input, args.output_prefix + '_model.onnx', verbose=False)
    data_loader.save_model({
        'epoch': -1,
        'args': args,
        'state_dict': model.state_dict(),
        'best_prec1': -1,
        'best_prec3': -1,
        'best_prec5': -1,
        'classnames': model.get_classnames(),
        'num_classes': args.num_classes,
        'model_type': args.model_type,
    }, False, filename=args.output_prefix + '_model.pytorch')
    # Export class names
    classname_list = [model.get_classnames()[cid] for cid in range(args.num_classes)]
    with open(args.output_prefix + '_classes.txt', 'w', encoding='utf-8') as outfile:
        outfile.write('\n'.join(classname_list))
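
A quick way to sanity-check the exported file is to load it with onnxruntime and confirm that it accepts the expected input shape. The sketch below is not part of onnx_export.py: it assumes the onnxruntime package is installed, the default --output_prefix of onnx_export, and a 224-pixel input; the checkpoint filename in the example invocation is hypothetical.

# Example invocation (checkpoint filename is hypothetical):
#   python onnx_export.py --model_type resnext101 --weights_path my_checkpoint.pth.tar \
#       --image_size 224 --output_prefix onnx_export

import numpy as np
import onnxruntime as ort

# Load the exported graph on CPU and look up the name of its input tensor.
session = ort.InferenceSession('onnx_export_model.onnx',
                               providers=['CPUExecutionProvider'])
input_name = session.get_inputs()[0].name

# Run one random image through the graph; the first output's shape should match
# the PyTorch output printed by main() (e.g. (1, num_classes) for class scores).
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy})
print(outputs[0].shape)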