in tensorflow/tensorflow/lite/toco/model_cmdline_flags.cc [181:419]
void ReadModelFlagsFromCommandLineFlags(
const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
// Load proto containing the initial model flags.
// Additional flags specified on the command line will overwrite the values.
if (parsed_model_flags.model_flags_file.specified()) {
string model_flags_file_contents;
QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
&model_flags_file_contents,
port::file::Defaults())
.ok())
<< "Specified --model_flags_file="
<< parsed_model_flags.model_flags_file.value()
<< " was not found or could not be read";
QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
model_flags))
<< "Specified --model_flags_file="
<< parsed_model_flags.model_flags_file.value()
<< " could not be parsed";
}
#ifdef PLATFORM_GOOGLE
CHECK(!((base::SpecifiedOnCommandLine("batch") &&
parsed_model_flags.variable_batch.specified())))
<< "The --batch and --variable_batch flags are mutually exclusive.";
#endif
CHECK(!(parsed_model_flags.output_array.specified() &&
parsed_model_flags.output_arrays.specified()))
<< "The --output_array and --vs flags are mutually exclusive.";
if (parsed_model_flags.output_array.specified()) {
model_flags->add_output_arrays(parsed_model_flags.output_array.value());
}
if (parsed_model_flags.output_arrays.specified()) {
std::vector<string> output_arrays =
absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
for (const string& output_array : output_arrays) {
model_flags->add_output_arrays(output_array);
}
}
const bool uses_single_input_flags =
parsed_model_flags.input_array.specified() ||
parsed_model_flags.mean_value.specified() ||
parsed_model_flags.std_value.specified() ||
parsed_model_flags.input_shape.specified();
const bool uses_multi_input_flags =
parsed_model_flags.input_arrays.specified() ||
parsed_model_flags.mean_values.specified() ||
parsed_model_flags.std_values.specified() ||
parsed_model_flags.input_shapes.specified();
QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
<< "Use either the singular-form input flags (--input_array, "
"--input_shape, --mean_value, --std_value) or the plural form input "
"flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
"but not both forms within the same command line.";
if (parsed_model_flags.input_array.specified()) {
QCHECK(uses_single_input_flags);
model_flags->add_input_arrays()->set_name(
parsed_model_flags.input_array.value());
}
if (parsed_model_flags.input_arrays.specified()) {
QCHECK(uses_multi_input_flags);
for (const auto& input_array :
absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
model_flags->add_input_arrays()->set_name(string(input_array));
}
}
if (parsed_model_flags.mean_value.specified()) {
QCHECK(uses_single_input_flags);
model_flags->mutable_input_arrays(0)->set_mean_value(
parsed_model_flags.mean_value.value());
}
if (parsed_model_flags.mean_values.specified()) {
QCHECK(uses_multi_input_flags);
std::vector<string> mean_values =
absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
QCHECK(mean_values.size() == model_flags->input_arrays_size());
for (size_t i = 0; i < mean_values.size(); ++i) {
char* last = nullptr;
model_flags->mutable_input_arrays(i)->set_mean_value(
strtod(mean_values[i].data(), &last));
CHECK(last != mean_values[i].data());
}
}
if (parsed_model_flags.std_value.specified()) {
QCHECK(uses_single_input_flags);
model_flags->mutable_input_arrays(0)->set_std_value(
parsed_model_flags.std_value.value());
}
if (parsed_model_flags.std_values.specified()) {
QCHECK(uses_multi_input_flags);
std::vector<string> std_values =
absl::StrSplit(parsed_model_flags.std_values.value(), ',');
QCHECK(std_values.size() == model_flags->input_arrays_size());
for (size_t i = 0; i < std_values.size(); ++i) {
char* last = nullptr;
model_flags->mutable_input_arrays(i)->set_std_value(
strtod(std_values[i].data(), &last));
CHECK(last != std_values[i].data());
}
}
if (parsed_model_flags.input_data_type.specified()) {
QCHECK(uses_single_input_flags);
IODataType type;
QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
model_flags->mutable_input_arrays(0)->set_data_type(type);
}
if (parsed_model_flags.input_data_types.specified()) {
QCHECK(uses_multi_input_flags);
std::vector<string> input_data_types =
absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
QCHECK(input_data_types.size() == model_flags->input_arrays_size());
for (size_t i = 0; i < input_data_types.size(); ++i) {
IODataType type;
QCHECK(IODataType_Parse(input_data_types[i], &type));
model_flags->mutable_input_arrays(i)->set_data_type(type);
}
}
if (parsed_model_flags.input_shape.specified()) {
QCHECK(uses_single_input_flags);
if (model_flags->input_arrays().empty()) {
model_flags->add_input_arrays();
}
auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
shape->clear_dims();
const IntList& list = parsed_model_flags.input_shape.value();
for (auto& dim : list.elements) {
shape->add_dims(dim);
}
}
if (parsed_model_flags.input_shapes.specified()) {
QCHECK(uses_multi_input_flags);
std::vector<string> input_shapes =
absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
QCHECK(input_shapes.size() == model_flags->input_arrays_size());
for (size_t i = 0; i < input_shapes.size(); ++i) {
auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
shape->clear_dims();
// Treat an empty input shape as a scalar.
if (input_shapes[i].empty()) {
continue;
}
for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
int size;
CHECK(absl::SimpleAtoi(dim_str, &size))
<< "Failed to parse input_shape: " << input_shapes[i];
shape->add_dims(size);
}
}
}
#define READ_MODEL_FLAG(name) \
do { \
if (parsed_model_flags.name.specified()) { \
model_flags->set_##name(parsed_model_flags.name.value()); \
} \
} while (false)
READ_MODEL_FLAG(variable_batch);
#undef READ_MODEL_FLAG
for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
auto* rnn_state_proto = model_flags->add_rnn_states();
for (const auto& kv_pair : element) {
const string& key = kv_pair.first;
const string& value = kv_pair.second;
if (key == "state_array") {
rnn_state_proto->set_state_array(value);
} else if (key == "back_edge_source_array") {
rnn_state_proto->set_back_edge_source_array(value);
} else if (key == "size") {
int32 size = 0;
CHECK(absl::SimpleAtoi(value, &size));
CHECK_GT(size, 0);
rnn_state_proto->set_size(size);
} else {
LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
}
}
CHECK(rnn_state_proto->has_state_array() &&
rnn_state_proto->has_back_edge_source_array() &&
rnn_state_proto->has_size())
<< "--rnn_states must include state_array, back_edge_source_array and "
"size.";
}
for (const auto& element : parsed_model_flags.model_checks.value().elements) {
auto* model_check_proto = model_flags->add_model_checks();
for (const auto& kv_pair : element) {
const string& key = kv_pair.first;
const string& value = kv_pair.second;
if (key == "count_type") {
model_check_proto->set_count_type(value);
} else if (key == "count_min") {
int32 count = 0;
CHECK(absl::SimpleAtoi(value, &count));
CHECK_GE(count, -1);
model_check_proto->set_count_min(count);
} else if (key == "count_max") {
int32 count = 0;
CHECK(absl::SimpleAtoi(value, &count));
CHECK_GE(count, -1);
model_check_proto->set_count_max(count);
} else {
LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
}
}
}
if (!model_flags->has_allow_nonascii_arrays()) {
model_flags->set_allow_nonascii_arrays(
parsed_model_flags.allow_nonascii_arrays.value());
}
if (!model_flags->has_allow_nonexistent_arrays()) {
model_flags->set_allow_nonexistent_arrays(
parsed_model_flags.allow_nonexistent_arrays.value());
}
if (!model_flags->has_change_concat_input_ranges()) {
model_flags->set_change_concat_input_ranges(
parsed_model_flags.change_concat_input_ranges.value());
}
if (parsed_model_flags.arrays_extra_info_file.specified()) {
string arrays_extra_info_file_contents;
CHECK(port::file::GetContents(
parsed_model_flags.arrays_extra_info_file.value(),
&arrays_extra_info_file_contents, port::file::Defaults())
.ok());
ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
model_flags->mutable_arrays_extra_info());
}
}