in mmdnn/conversion/_script/IRToCode.py [0:0]
def _convert(args):
    """Generate target-framework code (and optionally weights) from an IR model.

    Picks the emitter matching ``args.dstFramework``, builds it from the IR
    model file (plus the IR weight file when one is given) and calls its
    ``run`` method.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``dstFramework``, ``IRModelPath``, ``IRWeightPath``,
            ``dstModelPath``, ``dstWeightPath`` and ``phase``.

    Returns:
        0 after ``emitter.run`` has returned.

    Raises:
        ValueError: if ``dstFramework`` is not a supported framework, or a
            required destination weight path is missing (caffe/mxnet when IR
            weights are given, pytorch always).
        NotImplementedError: for ``coreml``, or for ``onnx`` without IR weights.
    """
    framework = args.dstFramework

    # Argument checks are done before each framework import so that bad
    # arguments are reported without first importing that framework's emitter.
    if framework == 'caffe':
        if args.IRWeightPath is not None and not args.dstWeightPath:
            # Previously an `assert`, which is stripped under `python -O`.
            raise ValueError("Caffe emitter needs argument [dstWeightPath(dw)] when IR weights are given.")
        from mmdnn.conversion.caffe.caffe_emitter import CaffeEmitter
        if args.IRWeightPath is None:
            emitter = CaffeEmitter(args.IRModelPath)
        else:
            emitter = CaffeEmitter((args.IRModelPath, args.IRWeightPath))

    elif framework == 'keras':
        from mmdnn.conversion.keras.keras2_emitter import Keras2Emitter
        # NOTE(review): the tuple is passed even when IRWeightPath is None;
        # presumably Keras2Emitter handles a None weight path — confirm.
        emitter = Keras2Emitter((args.IRModelPath, args.IRWeightPath))

    elif framework == 'tensorflow':
        from mmdnn.conversion.tensorflow.tensorflow_emitter import TensorflowEmitter
        if args.IRWeightPath is None:
            # Convert network architecture only
            emitter = TensorflowEmitter(args.IRModelPath)
        else:
            emitter = TensorflowEmitter((args.IRModelPath, args.IRWeightPath))

    elif framework == 'cntk':
        from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
        if args.IRWeightPath is None:
            emitter = CntkEmitter(args.IRModelPath)
        else:
            emitter = CntkEmitter((args.IRModelPath, args.IRWeightPath))

    elif framework == 'coreml':
        raise NotImplementedError("CoreML emitter is not finished yet.")

    elif framework == 'pytorch':
        # The PyTorch emitter always needs both IR weights and a target weight file.
        if not args.dstWeightPath or not args.IRWeightPath:
            raise ValueError("Need to set a target weight filename.")
        from mmdnn.conversion.pytorch.pytorch_emitter import PytorchEmitter
        emitter = PytorchEmitter((args.IRModelPath, args.IRWeightPath))

    elif framework == 'mxnet':
        if args.IRWeightPath is not None and args.dstWeightPath is None:
            raise ValueError("MXNet emitter needs argument [dstWeightPath(dw)], like -dw mxnet_converted-0000.param")
        from mmdnn.conversion.mxnet.mxnet_emitter import MXNetEmitter
        if args.IRWeightPath is None:
            emitter = MXNetEmitter(args.IRModelPath)
        else:
            # The MXNet emitter also receives the destination weight path.
            emitter = MXNetEmitter((args.IRModelPath, args.IRWeightPath, args.dstWeightPath))

    elif framework == 'onnx':
        if args.IRWeightPath is None:
            raise NotImplementedError("ONNX emitter needs IR weight file")
        from mmdnn.conversion.onnx.onnx_emitter import OnnxEmitter
        # Unlike the other emitters, OnnxEmitter takes two positional arguments.
        emitter = OnnxEmitter(args.IRModelPath, args.IRWeightPath)

    else:
        # Previously `assert False`: under `python -O` that fell through to an
        # UnboundLocalError on `emitter` below.
        raise ValueError("Unsupported destination framework: {}".format(framework))

    emitter.run(args.dstModelPath, args.dstWeightPath, args.phase)
    return 0