def run()

in tfjs-converter/python/tensorflowjs/converters/wizard.py [0:0]


def _print_generated_files(output_path):
  """Print the name and size of every entry in the conversion output.

  Args:
    output_path: Output path chosen by the user. If it names an existing
      file, the entries of its parent directory are listed instead.
  """
  print('\n\nFile(s) generated by conversion:')

  print("Filename {0:25} Size(bytes)".format(''))
  if os.path.isfile(output_path):
    # A bare filename has an empty dirname; list the current directory then.
    output_path = os.path.dirname(output_path) or os.curdir
  total_size = 0
  for basename in sorted(os.listdir(output_path)):
    size = os.path.getsize(os.path.join(output_path, basename))
    print("{0:35} {1}".format(basename, size))
    total_size += size
  print("Total size:{0:24} {1}".format('', total_size))


def run(dryrun):
  """Interactively collect converter options and optionally run the converter.

  Prompts (via PyInquirer) for the input path, input/output formats,
  model-specific options and the output directory. It then prints the
  equivalent `tensorflowjs_converter` command line.

  Args:
    dryrun: If truthy, only print the generated command. Otherwise also call
      `converter.convert` with the generated arguments and list the produced
      files. On failure, the traceback is appended to `converter_error.log`
      in the system temp directory.

  If any prompt returns an empty answer dict (which is what PyInquirer does
  when the user cancels with Ctrl-C), the wizard returns without converting.
  """
  print('Welcome to TensorFlow.js Converter.')
  input_path = [{
      'type': 'input',
      'name': common.INPUT_PATH,
      'message': 'Please provide the path of model file or '
                 'the directory that contains model files. \n'
                 'If you are converting TFHub module please provide the URL.',
      'filter': os.path.expanduser,
      'validate':
          lambda path: 'Please enter a valid path' if not path else True
  }]

  input_params = PyInquirer.prompt(input_path, style=prompt_style)
  if not input_params:
    # Prompt was cancelled; nothing to convert.
    return
  detected_input_format, normalized_path = detect_input_format(
      input_params[common.INPUT_PATH])
  input_params[common.INPUT_PATH] = normalized_path

  formats = [
      {
          'type': 'list',
          'name': common.INPUT_FORMAT,
          'message': input_format_message(detected_input_format),
          'choices': input_formats(detected_input_format)
      }, {
          'type': 'list',
          'name': common.OUTPUT_FORMAT,
          'message': 'What is your output format?',
          'choices': available_output_formats,
          'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
                                                (common.KERAS_MODEL,
                                                 common.TFJS_LAYERS_MODEL))
      }
  ]
  format_params = PyInquirer.prompt(formats, input_params, style=prompt_style)
  if not format_params:
    return
  message = input_path_message(format_params)

  def saved_model_to_graph_model(answers):
    # Tag/signature questions apply only when is_saved_model() accepts the
    # input format and the target is a graph model (or no output format was
    # asked for).
    return (is_saved_model(answers[common.INPUT_FORMAT]) and
            (common.OUTPUT_FORMAT not in format_params or
             format_params[common.OUTPUT_FORMAT] == common.TFJS_GRAPH_MODEL))

  # NOTE: single-value choices below are written as 1-tuples ('x',) so that
  # value_in_list checks membership, not substring containment in a string.
  questions = [
      {
          'type': 'input',
          'name': common.INPUT_PATH,
          'message': message,
          'filter': expand_input_path,
          'validate': lambda value: validate_input_path(
              value, format_params[common.INPUT_FORMAT]),
          # Re-ask for the path only when its format could not be detected.
          'when': lambda answers: (not detected_input_format)
      },
      {
          'type': 'list',
          'name': common.SAVED_MODEL_TAGS,
          'choices': available_tags,
          'message': 'What is tags for the saved model?',
          'when': saved_model_to_graph_model
      },
      {
          'type': 'list',
          'name': common.SIGNATURE_NAME,
          'message': 'What is signature name of the model?',
          'choices': available_signature_names,
          'when': saved_model_to_graph_model
      },
      {
          'type': 'list',
          'name': 'quantize',
          'message': 'Do you want to compress the model? '
                     '(this will decrease the model precision.)',
          'choices': [{
              'name': 'No compression (Higher accuracy)',
              'value': None
          }, {
              'name': 'float16 quantization '
                      '(2x smaller, Minimal accuracy loss)',
              'value': 'float16'
          }, {
              'name': 'uint16 affine quantization (2x smaller, Accuracy loss)',
              'value': 'uint16'
          }, {
              'name': 'uint8 affine quantization (4x smaller, Accuracy loss)',
              'value': 'uint8'
          }]
      },
      {
          'type': 'input',
          'name': common.QUANTIZATION_TYPE_FLOAT16,
          'message': 'Please enter the layers to apply float16 quantization '
                     '(2x smaller, minimal accuracy tradeoff).\n'
                     'Supports wildcard expansion with *, e.g., conv/*/weights',
          'default': '*',
          'when': lambda answers:
                  value_in_list(answers, 'quantize', ('float16',))
      },
      {
          'type': 'input',
          'name': common.QUANTIZATION_TYPE_UINT8,
          'message': 'Please enter the layers to apply affine 1-byte integer '
                     'quantization (4x smaller, accuracy tradeoff).\n'
                     'Supports wildcard expansion with *, e.g., conv/*/weights',
          'default': '*',
          'when': lambda answers:
                  value_in_list(answers, 'quantize', ('uint8',))
      },
      {
          'type': 'input',
          'name': common.QUANTIZATION_TYPE_UINT16,
          'message': 'Please enter the layers to apply affine 2-byte integer '
                     'quantization (2x smaller, accuracy tradeoff).\n'
                     'Supports wildcard expansion with *, e.g., conv/*/weights',
          'default': '*',
          'when': lambda answers:
                  value_in_list(answers, 'quantize', ('uint16',))
      },
      {
          'type': 'input',
          'name': common.WEIGHT_SHARD_SIZE_BYTES,
          'message': 'Please enter shard size (in bytes) of the weight files?',
          'default': str(4 * 1024 * 1024),
          'validate':
              lambda size: ('Please enter a positive integer' if not
                            (size.isdigit() and int(size) > 0) else True),
          'when': lambda answers: (value_in_list(answers, common.OUTPUT_FORMAT,
                                                 (common.TFJS_LAYERS_MODEL,
                                                  common.TFJS_GRAPH_MODEL)) or
                                   value_in_list(answers, common.INPUT_FORMAT,
                                                 (common.TF_SAVED_MODEL,
                                                  common.TF_HUB_MODEL)))
      },
      {
          'type': 'confirm',
          'name': common.SPLIT_WEIGHTS_BY_LAYER,
          'message': 'Do you want to split weights by layers?',
          'default': False,
          'when': lambda answers: (value_in_list(answers, common.OUTPUT_FORMAT,
                                                 (common.TFJS_LAYERS_MODEL,)) and
                                   value_in_list(answers, common.INPUT_FORMAT,
                                                 (common.KERAS_MODEL,
                                                  common.KERAS_SAVED_MODEL)))
      },
      {
          'type': 'confirm',
          'name': common.SKIP_OP_CHECK,
          'message': 'Do you want to skip op validation? \n'
                     'This will allow conversion of unsupported ops, \n'
                     'you can implement them as custom ops in tfjs-converter.',
          'default': False,
          'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
                                                (common.TF_SAVED_MODEL,
                                                 common.TF_HUB_MODEL))
      },
      {
          'type': 'confirm',
          'name': common.STRIP_DEBUG_OPS,
          'message': 'Do you want to strip debug ops? \n'
                     'This will improve model execution performance.',
          'default': True,
          'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
                                                (common.TF_SAVED_MODEL,
                                                 common.TF_HUB_MODEL))
      },
      {
          'type': 'confirm',
          'name': common.CONTROL_FLOW_V2,
          'message': 'Do you want to enable Control Flow V2 ops? \n'
                     'This will improve branch and loop execution performance.',
          'default': True,
          'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
                                                (common.TF_SAVED_MODEL,
                                                 common.TF_HUB_MODEL))
      },
      {
          'type': 'input',
          'name': common.METADATA,
          'message': 'Do you want to provide metadata? \n'
                     'Provide your own metadata in the form: \n'
                     'metadata_key:path/metadata.json \n'
                     'Separate multiple metadata by comma.'
      }
  ]
  params = PyInquirer.prompt(questions, format_params, style=prompt_style)
  if not params:
    return

  output_options = [
      {
          'type': 'input',
          'name': common.OUTPUT_PATH,
          'message': 'Which directory do you want to save '
                     'the converted model in?',
          # Late-bound: always sees the current `params` dict.
          'filter': lambda path: update_output_path(path, params),
          'validate': lambda path: len(path) > 0
      },
      {
          'type': 'confirm',
          'message': 'The output directory already exists, '
                     'do you want to overwrite it?',
          'name': 'overwrite_output_path',
          'default': False,
          'when': lambda ans: output_path_exists(ans[common.OUTPUT_PATH])
      }
  ]

  # Keep asking until we have an output path that either does not exist or
  # that the user agreed to overwrite.
  while (common.OUTPUT_PATH not in params or
         (output_path_exists(params[common.OUTPUT_PATH]) and
          not params.get('overwrite_output_path'))):
    params = PyInquirer.prompt(output_options, params, style=prompt_style)
    if not params:
      return

  arguments = generate_arguments(params)
  print('converter command generated:')
  print('tensorflowjs_converter %s' % ' '.join(arguments))
  print('\n\n')

  log_file = os.path.join(tempfile.gettempdir(), 'converter_error.log')
  if not dryrun:
    try:
      converter.convert(arguments)
      _print_generated_files(params[common.OUTPUT_PATH])
    # BaseException (not Exception) so SystemExit/KeyboardInterrupt raised
    # during conversion are also logged -- NOTE(review): presumably to catch
    # argparse exits inside converter.convert; confirm before narrowing.
    except BaseException:
      exc_type, exc_value, exc_traceback = sys.exc_info()
      lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
      with open(log_file, 'a') as writer:
        writer.write(''.join(lines))
      print('Conversion failed, please check error log file %s.' % log_file)