in tfjs-converter/python/tensorflowjs/converters/wizard.py
def run(dryrun):
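  """Interactive wizard that builds and runs a tensorflowjs_converter command.

  Walks the user through input/output formats, quantization and other
  conversion options, and the output directory, then generates the
  equivalent tensorflowjs_converter command line and, unless `dryrun` is
  set, executes it.

  Args:
    dryrun: If True, only print the generated command without running it.
  """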
print('Welcome to TensorFlow.js Converter.')
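  # Step 1: ask for the input path (or TFHub URL) first, so the input
  # format can be auto-detected before the remaining questions are built.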
input_path = [{
'type': 'input',
'name': common.INPUT_PATH,
      'message': 'Please provide the path of the model file or '
                 'the directory that contains the model files. \n'
                 'If you are converting a TFHub module, please provide '
                 'the URL.',
'filter': os.path.expanduser,
'validate':
lambda path: 'Please enter a valid path' if not path else True
}]
input_params = PyInquirer.prompt(input_path, style=prompt_style)
detected_input_format, normalized_path = detect_input_format(
input_params[common.INPUT_PATH])
input_params[common.INPUT_PATH] = normalized_path
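  # Step 2: confirm the (possibly auto-detected) input format and, for
  # inputs that support more than one target, ask for the output format.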
formats = [
{
'type': 'list',
'name': common.INPUT_FORMAT,
'message': input_format_message(detected_input_format),
'choices': input_formats(detected_input_format)
}, {
'type': 'list',
'name': common.OUTPUT_FORMAT,
'message': 'What is your output format?',
'choices': available_output_formats,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.KERAS_MODEL,
common.TFJS_LAYERS_MODEL))
}
]
format_params = PyInquirer.prompt(formats, input_params, style=prompt_style)
message = input_path_message(format_params)
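  # Step 3: format-specific conversion options; each question's 'when'
  # clause decides whether it applies to the chosen format combination.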
questions = [
{
'type': 'input',
'name': common.INPUT_PATH,
'message': message,
'filter': expand_input_path,
'validate': lambda value: validate_input_path(
value, format_params[common.INPUT_FORMAT]),
'when': lambda answers: (not detected_input_format)
},
{
'type': 'list',
'name': common.SAVED_MODEL_TAGS,
'choices': available_tags,
          'message': 'What are the tags for the SavedModel?',
'when': lambda answers: (is_saved_model(answers[common.INPUT_FORMAT])
and
(common.OUTPUT_FORMAT not in format_params
or format_params[common.OUTPUT_FORMAT] ==
common.TFJS_GRAPH_MODEL))
},
{
'type': 'list',
'name': common.SIGNATURE_NAME,
          'message': 'What is the signature name of the model?',
'choices': available_signature_names,
'when': lambda answers: (is_saved_model(answers[common.INPUT_FORMAT])
and
(common.OUTPUT_FORMAT not in format_params
or format_params[common.OUTPUT_FORMAT] ==
common.TFJS_GRAPH_MODEL))
},
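      # The quantization scheme chosen here determines which of the
      # layer-pattern follow-up questions below is shown.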
{
'type': 'list',
'name': 'quantize',
          'message': 'Do you want to compress the model? '
                     '(This will reduce model precision.)',
          'choices': [{
              'name': 'No compression (higher accuracy)',
              'value': None
          }, {
              'name': 'float16 quantization '
                      '(2x smaller, minimal accuracy loss)',
              'value': 'float16'
          }, {
              'name': 'uint16 affine quantization (2x smaller, accuracy loss)',
              'value': 'uint16'
          }, {
              'name': 'uint8 affine quantization (4x smaller, accuracy loss)',
              'value': 'uint8'
          }]
},
{
'type': 'input',
'name': common.QUANTIZATION_TYPE_FLOAT16,
'message': 'Please enter the layers to apply float16 quantization '
'(2x smaller, minimal accuracy tradeoff).\n'
'Supports wildcard expansion with *, e.g., conv/*/weights',
'default': '*',
          'when': lambda answers:
              value_in_list(answers, 'quantize', ('float16',))
},
{
'type': 'input',
'name': common.QUANTIZATION_TYPE_UINT8,
'message': 'Please enter the layers to apply affine 1-byte integer '
'quantization (4x smaller, accuracy tradeoff).\n'
'Supports wildcard expansion with *, e.g., conv/*/weights',
'default': '*',
          'when': lambda answers:
              value_in_list(answers, 'quantize', ('uint8',))
},
{
'type': 'input',
'name': common.QUANTIZATION_TYPE_UINT16,
'message': 'Please enter the layers to apply affine 2-byte integer '
'quantization (2x smaller, accuracy tradeoff).\n'
'Supports wildcard expansion with *, e.g., conv/*/weights',
'default': '*',
          'when': lambda answers:
              value_in_list(answers, 'quantize', ('uint16',))
},
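      # The 4 MiB default shard size keeps each weight file small enough
      # to be cached effectively by browsers.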
{
'type': 'input',
'name': common.WEIGHT_SHARD_SIZE_BYTES,
          'message': 'Please enter the shard size (in bytes) of the '
                     'weight files.',
'default': str(4 * 1024 * 1024),
'validate':
lambda size: ('Please enter a positive integer' if not
(size.isdigit() and int(size) > 0) else True),
'when': lambda answers: (value_in_list(answers, common.OUTPUT_FORMAT,
(common.TFJS_LAYERS_MODEL,
common.TFJS_GRAPH_MODEL)) or
value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL)))
},
{
'type': 'confirm',
'name': common.SPLIT_WEIGHTS_BY_LAYER,
'message': 'Do you want to split weights by layers?',
'default': False,
          'when': lambda answers: (value_in_list(answers, common.OUTPUT_FORMAT,
                                                 (common.TFJS_LAYERS_MODEL,)) and
value_in_list(answers, common.INPUT_FORMAT,
(common.KERAS_MODEL,
common.KERAS_SAVED_MODEL)))
},
{
'type': 'confirm',
'name': common.SKIP_OP_CHECK,
          'message': 'Do you want to skip op validation? \n'
                     'This will allow conversion of unsupported ops; \n'
                     'you can implement them as custom ops in tfjs-converter.',
'default': False,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
},
{
'type': 'confirm',
'name': common.STRIP_DEBUG_OPS,
'message': 'Do you want to strip debug ops? \n'
'This will improve model execution performance.',
'default': True,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
},
{
'type': 'confirm',
'name': common.CONTROL_FLOW_V2,
'message': 'Do you want to enable Control Flow V2 ops? \n'
'This will improve branch and loop execution performance.',
'default': True,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
},
{
'type': 'input',
'name': common.METADATA,
'message': 'Do you want to provide metadata? \n'
'Provide your own metadata in the form: \n'
'metadata_key:path/metadata.json \n'
                     'Separate multiple metadata entries with commas.'
}
]
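  # Carry the answers gathered so far into this prompt so that the 'when'
  # clauses above can reference earlier choices.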
params = PyInquirer.prompt(questions, format_params, style=prompt_style)
output_options = [
{
'type': 'input',
'name': common.OUTPUT_PATH,
'message': 'Which directory do you want to save '
'the converted model in?',
'filter': lambda path: update_output_path(path, params),
'validate': lambda path: len(path) > 0
},
{
'type': 'confirm',
          'message': 'The output directory already exists; '
                     'do you want to overwrite it?',
'name': 'overwrite_output_path',
'default': False,
'when': lambda ans: output_path_exists(ans[common.OUTPUT_PATH])
}
]
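  # Re-prompt until the user provides an output path that either does not
  # exist yet or that they have explicitly agreed to overwrite.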
  while (common.OUTPUT_PATH not in params or
         (output_path_exists(params[common.OUTPUT_PATH]) and
          not params['overwrite_output_path'])):
params = PyInquirer.prompt(output_options, params, style=prompt_style)
arguments = generate_arguments(params)
print('converter command generated:')
print('tensorflowjs_converter %s' % ' '.join(arguments))
print('\n\n')
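  # Any conversion error is appended to a log file in the system temp
  # directory so the stack trace can be inspected after the wizard exits.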
log_file = os.path.join(tempfile.gettempdir(), 'converter_error.log')
if not dryrun:
try:
converter.convert(arguments)
print('\n\nFile(s) generated by conversion:')
print("Filename {0:25} Size(bytes)".format(''))
total_size = 0
output_path = params[common.OUTPUT_PATH]
if os.path.isfile(output_path):
output_path = os.path.dirname(output_path)
for basename in sorted(os.listdir(output_path)):
filename = os.path.join(output_path, basename)
size = os.path.getsize(filename)
print("{0:35} {1}".format(basename, size))
total_size += size
print("Total size:{0:24} {1}".format('', total_size))
except BaseException:
exc_type, exc_value, exc_traceback = sys.exc_info()
lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
with open(log_file, 'a') as writer:
        writer.write(''.join(lines))
print('Conversion failed, please check error log file %s.' % log_file)
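
# A minimal sketch of a CLI entry point for this wizard, assuming run() is
# driven by a --dryrun flag that only prints the generated converter
# command. This is illustrative; the actual entry point in tfjs-converter
# may differ.
if __name__ == '__main__':
  import argparse

  parser = argparse.ArgumentParser(
      description='Interactive wizard for tensorflowjs_converter.')
  parser.add_argument(
      '--dryrun', action='store_true',
      help='Print the generated converter command without executing it.')
  args = parser.parse_args()
  run(args.dryrun)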