in mlmodel/src/Validation/WordTaggerValidator.cpp [17:182]
Result validate<MLModelType_wordTagger>(const Specification::Model& format) {
    // Validates a word tagger model spec: the model type, the mapping between the spec's declared
    // output names and the model description's outputs, the input/output feature types, and the
    // basic model parameters (revision, tags, parameter data). Returns the first failure found,
    // otherwise the (good) result of the output feature-type validation.
    const auto& interface = format.description();
    const auto& inputs = interface.input();
    const auto& outputs = interface.output();

    // make sure model is a word tagger
    if (!format.has_wordtagger()) {
        return Result(ResultType::INVALID_MODEL_PARAMETERS, "Model not a word tagger.");
    }
    const auto& wordTagger = format.wordtagger();

    // Output feature names declared by the word tagger spec. Token tags is required; an empty name
    // means the corresponding optional output (tokens / locations / lengths) is not produced.
    const std::string& tokenTagsName = wordTagger.tokentagsoutputfeaturename();
    const std::string& tokensName = wordTagger.tokensoutputfeaturename();
    const std::string& tokenLocationsName = wordTagger.tokenlocationsoutputfeaturename();
    const std::string& tokenLengthsName = wordTagger.tokenlengthsoutputfeaturename();

    int numNonEmptyOutputFeatures = 0;
    if (!tokenTagsName.empty()) {
        numNonEmptyOutputFeatures++;
    }
    if (!tokensName.empty()) {
        numNonEmptyOutputFeatures++;
    }
    if (!tokenLocationsName.empty()) {
        numNonEmptyOutputFeatures++;
    }
    if (!tokenLengthsName.empty()) {
        numNonEmptyOutputFeatures++;
    }

    // The description must expose exactly one output per non-empty declared name.
    if (outputs.size() > numNonEmptyOutputFeatures) {
        return Result(ResultType::TOO_MANY_FEATURES_FOR_MODEL_TYPE,
                      "More model output features than the output features of the word tagger model.");
    }
    if (outputs.size() < numNonEmptyOutputFeatures) {
        // Same result type as the surplus case for compatibility with callers that inspect it;
        // the message states the actual condition.
        return Result(ResultType::TOO_MANY_FEATURES_FOR_MODEL_TYPE,
                      "Fewer model output features than the output features of the word tagger model.");
    }

    // A description output takes a declared role only when the declared name is non-empty;
    // otherwise an unset optional name ("") would match an output whose name is also empty.
    auto isDeclaredAs = [](const std::string& declaredName, const std::string& outputName) {
        return !declaredName.empty() && declaredName == outputName;
    };

    int tokensOutputIndex = -1;
    int tokenTagsOutputIndex = -1;
    int tokenLocationsOutputIndex = -1;
    int tokenLengthsOutputIndex = -1;
    // check that any interface output can be found from word tagger outputs
    for (int i = 0; i < outputs.size(); i++) {
        const std::string& name = outputs[i].name();
        bool present = false;
        if (isDeclaredAs(tokensName, name)) {
            tokensOutputIndex = i;
            present = true;
        }
        if (isDeclaredAs(tokenTagsName, name)) {
            tokenTagsOutputIndex = i;
            present = true;
        }
        if (isDeclaredAs(tokenLocationsName, name)) {
            tokenLocationsOutputIndex = i;
            present = true;
        }
        if (isDeclaredAs(tokenLengthsName, name)) {
            tokenLengthsOutputIndex = i;
            present = true;
        }
        if (!present) {
            return Result(ResultType::TOO_MANY_FEATURES_FOR_MODEL_TYPE,
                          "Output feature '" + name + "' was not required by the output features of the word tagger model.");
        }
    }

    // Error for a declared output name that no description output carries.
    auto missingOutput = [](const std::string& featureName, const char* specField) -> Result {
        return Result(ResultType::INTERFACE_FEATURE_NAME_MISMATCH,
                      "Expected feature '" + featureName + "' (defined by " + specField +
                      ") to the model is not present in the model description.");
    };

    // token tags is the required output feature name, while tokens/locations/lengths are optional
    if (tokenTagsOutputIndex == -1) {
        return missingOutput(tokenTagsName, "tokenTagsOutputFeatureName");
    }
    // Optional outputs that the spec does name must also be present. The count check above does
    // not guarantee this on its own when the description repeats an output name.
    if (!tokensName.empty() && tokensOutputIndex == -1) {
        return missingOutput(tokensName, "tokensOutputFeatureName");
    }
    if (!tokenLocationsName.empty() && tokenLocationsOutputIndex == -1) {
        return missingOutput(tokenLocationsName, "tokenLocationsOutputFeatureName");
    }
    if (!tokenLengthsName.empty() && tokenLengthsOutputIndex == -1) {
        return missingOutput(tokenLengthsName, "tokenLengthsOutputFeatureName");
    }

    Result result;
    // Validate the inputs: only one input with string type is allowed
    result = validateDescriptionsContainFeatureWithTypes(inputs, 1, {Specification::FeatureType::kStringType});
    if (!result.good()) {
        return result;
    }
    // Validate the outputs: only sequence type is allowed for any output
    result = validateDescriptionsContainFeatureWithTypes(outputs, outputs.size(), {Specification::FeatureType::kSequenceType});
    if (!result.good()) {
        return result;
    }

    // Checks that description output `index` is a sequence whose element type is `expected`.
    // Element cases are printed via MLFeatureTypeType_Name after a static_cast, which assumes the
    // SequenceFeatureType oneof case values line up with MLFeatureTypeType values.
    // NOTE(review): confirm this still holds if new sequence element types are added.
    auto checkSequenceElementType = [&outputs](int index,
                                               const std::string& featureName,
                                               Specification::SequenceFeatureType::TypeCase expected) -> Result {
        const auto actual = outputs[index].type().sequencetype().Type_case();
        if (actual == expected) {
            return Result();
        }
        std::stringstream out;
        out << "Unsupported type \"" << MLFeatureTypeType_Name(static_cast<MLFeatureTypeType>(actual))
            << "\" for feature \"" << featureName << "\". Should be: "
            << MLFeatureTypeType_Name(static_cast<MLFeatureTypeType>(expected));
        return Result(ResultType::FEATURE_TYPE_INVARIANT_VIOLATION, out.str());
    };

    // token tags has to be a sequence of strings
    Result elementCheck = checkSequenceElementType(tokenTagsOutputIndex, tokenTagsName,
                                                   Specification::SequenceFeatureType::kStringType);
    if (!elementCheck.good()) {
        return elementCheck;
    }
    // tokens (if present) has to be a sequence of strings
    if (tokensOutputIndex != -1) {
        elementCheck = checkSequenceElementType(tokensOutputIndex, tokensName,
                                                Specification::SequenceFeatureType::kStringType);
        if (!elementCheck.good()) {
            return elementCheck;
        }
    }
    // token locations (if present) has to be a sequence of integers
    if (tokenLocationsOutputIndex != -1) {
        elementCheck = checkSequenceElementType(tokenLocationsOutputIndex, tokenLocationsName,
                                                Specification::SequenceFeatureType::kInt64Type);
        if (!elementCheck.good()) {
            return elementCheck;
        }
    }
    // token lengths (if present) has to be a sequence of integers
    if (tokenLengthsOutputIndex != -1) {
        elementCheck = checkSequenceElementType(tokenLengthsOutputIndex, tokenLengthsName,
                                                Specification::SequenceFeatureType::kInt64Type);
        if (!elementCheck.good()) {
            return elementCheck;
        }
    }

    // Validate the model parameters
    if (wordTagger.revision() == 0) {
        return Result(ResultType::INVALID_MODEL_PARAMETERS, "Model revision number not set. Must be >= 1");
    }
    // Initialized so that an unset or unhandled Tags case is rejected below rather than reading an
    // indeterminate value. No `default:` label, so -Wswitch can still flag newly added cases.
    int numTags = -1;
    switch (wordTagger.Tags_case()) {
        case Specification::CoreMLModels::WordTagger::kStringTags:
            numTags = wordTagger.stringtags().vector_size();
            break;
        case Specification::CoreMLModels::WordTagger::TAGS_NOT_SET:
            break;
    }
    if (numTags <= 0) {
        return Result(ResultType::INVALID_MODEL_PARAMETERS, "Model output tags not set. Must have at least one tag");
    }
    if (wordTagger.modelparameterdata().empty()) {
        return Result(ResultType::INVALID_MODEL_PARAMETERS, "Model parameter data not set");
    }
    return result;
}