ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue()

in winml/lib/Api/LearningModelBinding.cpp [241:378]


/// Creates an empty ("unbound") feature value whose WinML projected type matches
/// the engine value produced for an output that the caller did not bind.
///
/// The engine value is probed in order: tensors (by element kind), then maps
/// (by key/value kind), then sequences of maps, then sequences of tensors. The
/// first match determines the type that is created.
///
/// @param value       Engine-side value produced by evaluation; only queried, never modified.
/// @param descriptor  Descriptor of the output. Its Kind() decides whether a float
///                    tensor becomes an ImageFeatureValue. Its Name() is used in error messages.
/// @return A newly created, empty feature value of the matching type.
/// @throws Via WINML_THROW_HR_IF_TRUE_MSG:
///         - E_UNEXPECTED when the value matches none of the supported types.
///         - An error when an image output's tensor shape cannot be read or has
///           fewer than 4 dimensions. The failing HRESULT is used if the shape
///           query failed; otherwise E_UNEXPECTED.
ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue(
    const winrt::com_ptr<_winml::IValue> value,
    ILearningModelFeatureDescriptor& descriptor) {
  // Each probe asks the engine value a yes/no question. A failed HRESULT is treated
  // the same as "no", so probing simply moves on to the next candidate type.
  auto isTensor = [&value]() {
    bool out = false;
    return SUCCEEDED(value->IsTensor(&out)) && out;
  };
  auto isTensorOf = [&value](TensorKind kind) {
    bool out = false;
    return SUCCEEDED(value->IsOfTensorType(kind, &out)) && out;
  };
  auto isMapOf = [&value](TensorKind keyKind, TensorKind valueKind) {
    bool out = false;
    return SUCCEEDED(value->IsOfMapType(keyKind, valueKind, &out)) && out;
  };
  auto isSequenceOfMapOf = [&value](TensorKind keyKind, TensorKind valueKind) {
    bool out = false;
    return SUCCEEDED(value->IsOfVectorMapType(keyKind, valueKind, &out)) && out;
  };
  auto isSequenceOfTensorOf = [&value](TensorKind kind) {
    bool out = false;
    return SUCCEEDED(value->IsOfVectorTensorType(kind, &out)) && out;
  };

  // Tensors
  if (isTensor()) {
    if (isTensorOf(TensorKind::Float)) {
      if (descriptor.Kind() == LearningModelFeatureKind::Image) {
        // TODO: this format for unbound output needs more discussion
        wgi::BitmapPixelFormat format = descriptor.as<ImageFeatureDescriptor>()->BitmapPixelFormat();
        std::vector<int64_t> shape;
        auto hr = value->GetTensorShape(shape);
        // The indexing below reads the shape as [batch, (index 1 unused), height, width].
        // Check the query result and the rank first, so that a failure or a tensor of
        // rank < 4 raises an error instead of reading past the end of `shape`.
        if (!SUCCEEDED(hr) || shape.size() < 4) {
          auto utf8_name = _winml::Strings::UTF8FromHString(descriptor.Name());
          WINML_THROW_HR_IF_TRUE_MSG(
              (SUCCEEDED(hr) ? E_UNEXPECTED : hr),
              true,
              "The engine produced an image output with an unreadable or lower-than-4D tensor shape for unbound output variable %s.",
              utf8_name.c_str());
        }
        uint32_t width = static_cast<uint32_t>(shape[3]);
        uint32_t height = static_cast<uint32_t>(shape[2]);
        uint32_t batchSize = static_cast<uint32_t>(shape[0]);
        return winmlp::ImageFeatureValue::Create(batchSize, format, width, height);
      }
      return winmlp::TensorFloat::Create();
    }
    if (isTensorOf(TensorKind::Double)) {
      return winmlp::TensorDouble::Create();
    }
    if (isTensorOf(TensorKind::String)) {
      return winmlp::TensorString::Create();
    }
    if (isTensorOf(TensorKind::UInt8)) {
      return winmlp::TensorUInt8Bit::Create();
    }
    if (isTensorOf(TensorKind::Int8)) {
      return winmlp::TensorInt8Bit::Create();
    }
    if (isTensorOf(TensorKind::UInt16)) {
      return winmlp::TensorUInt16Bit::Create();
    }
    if (isTensorOf(TensorKind::Int16)) {
      return winmlp::TensorInt16Bit::Create();
    }
    if (isTensorOf(TensorKind::UInt32)) {
      return winmlp::TensorUInt32Bit::Create();
    }
    if (isTensorOf(TensorKind::Int32)) {
      return winmlp::TensorInt32Bit::Create();
    }
    if (isTensorOf(TensorKind::UInt64)) {
      return winmlp::TensorUInt64Bit::Create();
    }
    if (isTensorOf(TensorKind::Int64)) {
      return winmlp::TensorInt64Bit::Create();
    }
    if (isTensorOf(TensorKind::Boolean)) {
      return winmlp::TensorBoolean::Create();
    }
    if (isTensorOf(TensorKind::Float16)) {
      return winmlp::TensorFloat16Bit::Create();
    }
  }

  // Maps
  if (isMapOf(TensorKind::String, TensorKind::String)) {
    return winmlp::MapStringToString::Create();
  }
  if (isMapOf(TensorKind::String, TensorKind::Int64)) {
    return winmlp::MapStringToInt64Bit::Create();
  }
  if (isMapOf(TensorKind::String, TensorKind::Float)) {
    return winmlp::MapStringToFloat::Create();
  }
  if (isMapOf(TensorKind::String, TensorKind::Double)) {
    return winmlp::MapStringToDouble::Create();
  }
  if (isMapOf(TensorKind::Int64, TensorKind::String)) {
    return winmlp::MapInt64BitToString::Create();
  }
  if (isMapOf(TensorKind::Int64, TensorKind::Int64)) {
    return winmlp::MapInt64BitToInt64Bit::Create();
  }
  if (isMapOf(TensorKind::Int64, TensorKind::Float)) {
    return winmlp::MapInt64BitToFloat::Create();
  }
  if (isMapOf(TensorKind::Int64, TensorKind::Double)) {
    return winmlp::MapInt64BitToDouble::Create();
  }

  // Sequences of maps
  if (isSequenceOfMapOf(TensorKind::String, TensorKind::Float)) {
    return winmlp::SequenceMapStringFloat::Create();
  }
  if (isSequenceOfMapOf(TensorKind::Int64, TensorKind::Float)) {
    return winmlp::SequenceMapInt64BitFloat::Create();
  }

  // Sequences of tensors
  if (isSequenceOfTensorOf(TensorKind::Float)) {
    return winmlp::SequenceTensorFloat::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Double)) {
    return winmlp::SequenceTensorDouble::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::String)) {
    return winmlp::SequenceTensorString::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::UInt8)) {
    return winmlp::SequenceTensorUInt8Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Int8)) {
    return winmlp::SequenceTensorInt8Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::UInt16)) {
    return winmlp::SequenceTensorUInt16Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Int16)) {
    return winmlp::SequenceTensorInt16Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::UInt32)) {
    return winmlp::SequenceTensorUInt32Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Int32)) {
    return winmlp::SequenceTensorInt32Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::UInt64)) {
    return winmlp::SequenceTensorUInt64Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Int64)) {
    return winmlp::SequenceTensorInt64Bit::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Boolean)) {
    return winmlp::SequenceTensorBoolean::Create();
  }
  if (isSequenceOfTensorOf(TensorKind::Float16)) {
    return winmlp::SequenceTensorFloat16Bit::Create();
  }

  // No supported type matched.
  auto utf8_name = _winml::Strings::UTF8FromHString(descriptor.Name());
  WINML_THROW_HR_IF_TRUE_MSG(
      E_UNEXPECTED,
      true,
      "The engine produced an unexpected evaluation output for unbound output variable %s.",
      utf8_name.c_str());

  // Not reached: the macro above throws. Kept so every path returns a value.
  return nullptr;
}