void ReadModelFlagsFromCommandLineFlags()

in tensorflow/tensorflow/lite/toco/model_cmdline_flags.cc [181:419]


void ReadModelFlagsFromCommandLineFlags(
    const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
  toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");

  // Load proto containing the initial model flags.
  // Additional flags specified on the command line will overwrite the values.
  if (parsed_model_flags.model_flags_file.specified()) {
    string model_flags_file_contents;
    QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
                                   &model_flags_file_contents,
                                   port::file::Defaults())
               .ok())
        << "Specified --model_flags_file="
        << parsed_model_flags.model_flags_file.value()
        << " was not found or could not be read";
    QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
                                             model_flags))
        << "Specified --model_flags_file="
        << parsed_model_flags.model_flags_file.value()
        << " could not be parsed";
  }

#ifdef PLATFORM_GOOGLE
  CHECK(!((base::SpecifiedOnCommandLine("batch") &&
           parsed_model_flags.variable_batch.specified())))
      << "The --batch and --variable_batch flags are mutually exclusive.";
#endif
  CHECK(!(parsed_model_flags.output_array.specified() &&
          parsed_model_flags.output_arrays.specified()))
      << "The --output_array and --vs flags are mutually exclusive.";

  if (parsed_model_flags.output_array.specified()) {
    model_flags->add_output_arrays(parsed_model_flags.output_array.value());
  }

  if (parsed_model_flags.output_arrays.specified()) {
    std::vector<string> output_arrays =
        absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
    for (const string& output_array : output_arrays) {
      model_flags->add_output_arrays(output_array);
    }
  }

  const bool uses_single_input_flags =
      parsed_model_flags.input_array.specified() ||
      parsed_model_flags.mean_value.specified() ||
      parsed_model_flags.std_value.specified() ||
      parsed_model_flags.input_shape.specified();

  const bool uses_multi_input_flags =
      parsed_model_flags.input_arrays.specified() ||
      parsed_model_flags.mean_values.specified() ||
      parsed_model_flags.std_values.specified() ||
      parsed_model_flags.input_shapes.specified();

  QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
      << "Use either the singular-form input flags (--input_array, "
         "--input_shape, --mean_value, --std_value) or the plural form input "
         "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
         "but not both forms within the same command line.";

  if (parsed_model_flags.input_array.specified()) {
    QCHECK(uses_single_input_flags);
    model_flags->add_input_arrays()->set_name(
        parsed_model_flags.input_array.value());
  }
  if (parsed_model_flags.input_arrays.specified()) {
    QCHECK(uses_multi_input_flags);
    for (const auto& input_array :
         absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
      model_flags->add_input_arrays()->set_name(string(input_array));
    }
  }
  if (parsed_model_flags.mean_value.specified()) {
    QCHECK(uses_single_input_flags);
    model_flags->mutable_input_arrays(0)->set_mean_value(
        parsed_model_flags.mean_value.value());
  }
  if (parsed_model_flags.mean_values.specified()) {
    QCHECK(uses_multi_input_flags);
    std::vector<string> mean_values =
        absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
    QCHECK(mean_values.size() == model_flags->input_arrays_size());
    for (size_t i = 0; i < mean_values.size(); ++i) {
      char* last = nullptr;
      model_flags->mutable_input_arrays(i)->set_mean_value(
          strtod(mean_values[i].data(), &last));
      CHECK(last != mean_values[i].data());
    }
  }
  if (parsed_model_flags.std_value.specified()) {
    QCHECK(uses_single_input_flags);
    model_flags->mutable_input_arrays(0)->set_std_value(
        parsed_model_flags.std_value.value());
  }
  if (parsed_model_flags.std_values.specified()) {
    QCHECK(uses_multi_input_flags);
    std::vector<string> std_values =
        absl::StrSplit(parsed_model_flags.std_values.value(), ',');
    QCHECK(std_values.size() == model_flags->input_arrays_size());
    for (size_t i = 0; i < std_values.size(); ++i) {
      char* last = nullptr;
      model_flags->mutable_input_arrays(i)->set_std_value(
          strtod(std_values[i].data(), &last));
      CHECK(last != std_values[i].data());
    }
  }
  if (parsed_model_flags.input_data_type.specified()) {
    QCHECK(uses_single_input_flags);
    IODataType type;
    QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
    model_flags->mutable_input_arrays(0)->set_data_type(type);
  }
  if (parsed_model_flags.input_data_types.specified()) {
    QCHECK(uses_multi_input_flags);
    std::vector<string> input_data_types =
        absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
    QCHECK(input_data_types.size() == model_flags->input_arrays_size());
    for (size_t i = 0; i < input_data_types.size(); ++i) {
      IODataType type;
      QCHECK(IODataType_Parse(input_data_types[i], &type));
      model_flags->mutable_input_arrays(i)->set_data_type(type);
    }
  }
  if (parsed_model_flags.input_shape.specified()) {
    QCHECK(uses_single_input_flags);
    if (model_flags->input_arrays().empty()) {
      model_flags->add_input_arrays();
    }
    auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
    shape->clear_dims();
    const IntList& list = parsed_model_flags.input_shape.value();
    for (auto& dim : list.elements) {
      shape->add_dims(dim);
    }
  }
  if (parsed_model_flags.input_shapes.specified()) {
    QCHECK(uses_multi_input_flags);
    std::vector<string> input_shapes =
        absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
    QCHECK(input_shapes.size() == model_flags->input_arrays_size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
      shape->clear_dims();
      // Treat an empty input shape as a scalar.
      if (input_shapes[i].empty()) {
        continue;
      }
      for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
        int size;
        CHECK(absl::SimpleAtoi(dim_str, &size))
            << "Failed to parse input_shape: " << input_shapes[i];
        shape->add_dims(size);
      }
    }
  }

#define READ_MODEL_FLAG(name)                                   \
  do {                                                          \
    if (parsed_model_flags.name.specified()) {                  \
      model_flags->set_##name(parsed_model_flags.name.value()); \
    }                                                           \
  } while (false)

  READ_MODEL_FLAG(variable_batch);

#undef READ_MODEL_FLAG

  for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
    auto* rnn_state_proto = model_flags->add_rnn_states();
    for (const auto& kv_pair : element) {
      const string& key = kv_pair.first;
      const string& value = kv_pair.second;
      if (key == "state_array") {
        rnn_state_proto->set_state_array(value);
      } else if (key == "back_edge_source_array") {
        rnn_state_proto->set_back_edge_source_array(value);
      } else if (key == "size") {
        int32 size = 0;
        CHECK(absl::SimpleAtoi(value, &size));
        CHECK_GT(size, 0);
        rnn_state_proto->set_size(size);
      } else {
        LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
      }
    }
    CHECK(rnn_state_proto->has_state_array() &&
          rnn_state_proto->has_back_edge_source_array() &&
          rnn_state_proto->has_size())
        << "--rnn_states must include state_array, back_edge_source_array and "
           "size.";
  }

  for (const auto& element : parsed_model_flags.model_checks.value().elements) {
    auto* model_check_proto = model_flags->add_model_checks();
    for (const auto& kv_pair : element) {
      const string& key = kv_pair.first;
      const string& value = kv_pair.second;
      if (key == "count_type") {
        model_check_proto->set_count_type(value);
      } else if (key == "count_min") {
        int32 count = 0;
        CHECK(absl::SimpleAtoi(value, &count));
        CHECK_GE(count, -1);
        model_check_proto->set_count_min(count);
      } else if (key == "count_max") {
        int32 count = 0;
        CHECK(absl::SimpleAtoi(value, &count));
        CHECK_GE(count, -1);
        model_check_proto->set_count_max(count);
      } else {
        LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
      }
    }
  }

  if (!model_flags->has_allow_nonascii_arrays()) {
    model_flags->set_allow_nonascii_arrays(
        parsed_model_flags.allow_nonascii_arrays.value());
  }
  if (!model_flags->has_allow_nonexistent_arrays()) {
    model_flags->set_allow_nonexistent_arrays(
        parsed_model_flags.allow_nonexistent_arrays.value());
  }
  if (!model_flags->has_change_concat_input_ranges()) {
    model_flags->set_change_concat_input_ranges(
        parsed_model_flags.change_concat_input_ranges.value());
  }

  if (parsed_model_flags.arrays_extra_info_file.specified()) {
    string arrays_extra_info_file_contents;
    CHECK(port::file::GetContents(
              parsed_model_flags.arrays_extra_info_file.value(),
              &arrays_extra_info_file_contents, port::file::Defaults())
              .ok());
    ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
                                      model_flags->mutable_arrays_extra_info());
  }
}