in winml/lib/Api/FeatureValues.h [204:456]
/// Builds an ILearningModelFeatureValue from an arbitrary IInspectable supplied by the caller.
///
/// Resolution order (first match wins):
///   1. The object already is an ILearningModelFeatureValue (e.g. a Tensor* or ImageFeatureValue).
///   2. Images: IVector<VideoFrame> (plus IVectorView<VideoFrame> for inputs only), then a single
///      VideoFrame. An empty frame collection yields nullptr.
///   3. Maps keyed by hstring or int64_t with float/double/int64_t/hstring values: IMap always,
///      IMapView only for inputs.
///   4. If the descriptor is a Sequence: IVector (plus IVectorView for inputs only) of
///      IMap<hstring,float>, IMap<int64_t,float>, or any Tensor* type.
///   5. If the descriptor is a Tensor: an IIterable<IBuffer> of batched buffers is copied into a
///      tensor of the descriptor's kind and shape; if that does not produce a value, each tensor
///      creator is tried against the object in turn.
///
/// @param bindingType  Input or output binding. Inputs additionally accept read-only views
///                     (IVectorView / IMapView) because they are never written to.
/// @param inspectable  The caller-supplied object to interpret.
/// @param descriptor   Descriptor of the feature being bound. Its Kind() is only consulted after
///                     the feature-value, image, and map checks have failed.
/// @return The created feature value, or nullptr if the object could not be interpreted.
inline winml::ILearningModelFeatureValue CreateFeatureValueFromInspectable(
    _winml::BindingType bindingType,
    const wf::IInspectable& inspectable,
    const winml::ILearningModelFeatureDescriptor& descriptor) {
  // Read-only views are only acceptable when binding inputs.
  const bool isInput = bindingType == _winml::BindingType::kInput;

  // Tensor and ImageFeatureValue types are passed in directly as feature values
  if (auto featureValue = inspectable.try_as<winml::ILearningModelFeatureValue>()) {
    return featureValue;
  }
  if (auto videoFrames = inspectable.try_as<wfc::IVector<wm::VideoFrame>>()) {
    return (0 == videoFrames.Size()) ? nullptr : winrt::make<winmlp::ImageFeatureValue>(videoFrames);
  }
  if (isInput) {
    // Allows to bind IVectorView<VideoFrame> as input.
    if (auto videoFrames = inspectable.try_as<wfc::IVectorView<wm::VideoFrame>>()) {
      return (0 == videoFrames.Size()) ? nullptr : winrt::make<winmlp::ImageFeatureValue>(videoFrames);
    }
  }
  // ImageFeatureValues Types can be implicitly inferred from the VideoFrame object
  if (auto videoFrame = inspectable.try_as<wm::VideoFrame>()) {
    return winrt::make<winmlp::ImageFeatureValue>(videoFrame);
  }

  // MapFeatureValues Types are implicitly inferred from the iinspectable object
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, float>>()) {
    return winmlp::MapStringToFloat::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, double>>()) {
    return winmlp::MapStringToDouble::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, int64_t>>()) {
    return winmlp::MapStringToInt64Bit::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, winrt::hstring>>()) {
    return winmlp::MapStringToString::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, float>>()) {
    return winmlp::MapInt64BitToFloat::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, double>>()) {
    return winmlp::MapInt64BitToDouble::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, int64_t>>()) {
    return winmlp::MapInt64BitToInt64Bit::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, winrt::hstring>>()) {
    return winmlp::MapInt64BitToString::Create(map);
  }
  if (isInput) {
    // Feature inputs should be more permissive, and allow for views to be bound since they are read only
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, float>>()) {
      return winmlp::MapStringToFloat::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, double>>()) {
      return winmlp::MapStringToDouble::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, int64_t>>()) {
      return winmlp::MapStringToInt64Bit::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, winrt::hstring>>()) {
      return winmlp::MapStringToString::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, float>>()) {
      return winmlp::MapInt64BitToFloat::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, double>>()) {
      return winmlp::MapInt64BitToDouble::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, int64_t>>()) {
      return winmlp::MapInt64BitToInt64Bit::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, winrt::hstring>>()) {
      return winmlp::MapInt64BitToString::Create(map);
    }
  }

  // Query the descriptor kind once. It is deliberately read only here, after the checks above,
  // so that those paths never touch the descriptor.
  const auto featureKind = descriptor.Kind();

  if (featureKind == winml::LearningModelFeatureKind::Sequence) {
    // SequenceFeatureValues Types are implicitly inferred from the iinspectable object
    if (auto sequence = inspectable.try_as<wfc::IVector<wfc::IMap<winrt::hstring, float>>>()) {
      return winmlp::SequenceMapStringFloat::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<wfc::IMap<int64_t, float>>>()) {
      return winmlp::SequenceMapInt64BitFloat::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorFloat>>()) {
      return winmlp::SequenceTensorFloat::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorBoolean>>()) {
      return winmlp::SequenceTensorBoolean::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorDouble>>()) {
      return winmlp::SequenceTensorDouble::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt8Bit>>()) {
      return winmlp::SequenceTensorInt8Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt8Bit>>()) {
      return winmlp::SequenceTensorUInt8Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt16Bit>>()) {
      return winmlp::SequenceTensorUInt16Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt16Bit>>()) {
      return winmlp::SequenceTensorInt16Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt32Bit>>()) {
      return winmlp::SequenceTensorUInt32Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt32Bit>>()) {
      return winmlp::SequenceTensorInt32Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt64Bit>>()) {
      return winmlp::SequenceTensorUInt64Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt64Bit>>()) {
      return winmlp::SequenceTensorInt64Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorFloat16Bit>>()) {
      return winmlp::SequenceTensorFloat16Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorString>>()) {
      return winmlp::SequenceTensorString::Create(sequence);
    }
    if (isInput) {
      // Feature inputs should be more permissive, and allow for views to be bound since they are read only
      if (auto sequence = inspectable.try_as<wfc::IVectorView<wfc::IMap<winrt::hstring, float>>>()) {
        return winmlp::SequenceMapStringFloat::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<wfc::IMap<int64_t, float>>>()) {
        return winmlp::SequenceMapInt64BitFloat::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorFloat>>()) {
        return winmlp::SequenceTensorFloat::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorBoolean>>()) {
        return winmlp::SequenceTensorBoolean::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorDouble>>()) {
        return winmlp::SequenceTensorDouble::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt8Bit>>()) {
        return winmlp::SequenceTensorInt8Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt8Bit>>()) {
        return winmlp::SequenceTensorUInt8Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt16Bit>>()) {
        return winmlp::SequenceTensorUInt16Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt16Bit>>()) {
        return winmlp::SequenceTensorInt16Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt32Bit>>()) {
        return winmlp::SequenceTensorUInt32Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt32Bit>>()) {
        return winmlp::SequenceTensorInt32Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt64Bit>>()) {
        return winmlp::SequenceTensorUInt64Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt64Bit>>()) {
        return winmlp::SequenceTensorInt64Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorFloat16Bit>>()) {
        return winmlp::SequenceTensorFloat16Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorString>>()) {
        return winmlp::SequenceTensorString::Create(sequence);
      }
    }
  } else if (featureKind == winml::LearningModelFeatureKind::Tensor) {
    auto tensorDescriptor = descriptor.as<winml::ITensorFeatureDescriptor>();
    // Vector of IBuffer Input should be copied into the appropriate Tensor
    if (auto buffers = inspectable.try_as<wfc::IIterable<wss::IBuffer>>()) {
      // Query the tensor kind once instead of once per comparison.
      switch (tensorDescriptor.TensorKind()) {
        case winml::TensorKind::Boolean:
          return winmlp::TensorBoolean::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Float:
          return winmlp::TensorFloat::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Double:
          return winmlp::TensorDouble::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Float16:
          return winmlp::TensorFloat16Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt8:
          return winmlp::TensorUInt8Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int8:
          return winmlp::TensorInt8Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt16:
          return winmlp::TensorUInt16Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int16:
          return winmlp::TensorInt16Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt32:
          return winmlp::TensorUInt32Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int32:
          return winmlp::TensorInt32Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt64:
          return winmlp::TensorUInt64Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int64:
          return winmlp::TensorInt64Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        default:
          // No batched-buffer path for any other kind (e.g. String); fall through to the
          // collection-based creators below.
          break;
      }
    }

    using TensorCreator = winml::ILearningModelFeatureValue (*)(
        _winml::BindingType, const wf::IInspectable& inspectable, const winml::ITensorFeatureDescriptor& descriptor);
    // Raw WinRT collections have no distinct int8 or float16 element type. The Int8 and Float16
    // entries therefore reuse the uint8_t and float element types of the UInt8 and Float entries.
    // NOTE(review): this presumably relies on CreateTensorValueFromInspectable rejecting entries
    // whose tensor kind does not match the descriptor; otherwise the earlier UInt8/Float entries
    // would always win. Confirm against its definition.
    static constexpr std::array<TensorCreator, 13> creators = {
        CreateTensorValueFromInspectable<winmlp::TensorBoolean, bool>,
        CreateTensorValueFromInspectable<winmlp::TensorFloat, float>,
        CreateTensorValueFromInspectable<winmlp::TensorDouble, double>,
        CreateTensorValueFromInspectable<winmlp::TensorUInt8Bit, uint8_t>,
        CreateTensorValueFromInspectable<winmlp::TensorInt8Bit, uint8_t>,
        CreateTensorValueFromInspectable<winmlp::TensorUInt16Bit, uint16_t>,
        CreateTensorValueFromInspectable<winmlp::TensorInt16Bit, int16_t>,
        CreateTensorValueFromInspectable<winmlp::TensorUInt32Bit, uint32_t>,
        CreateTensorValueFromInspectable<winmlp::TensorInt32Bit, int32_t>,
        CreateTensorValueFromInspectable<winmlp::TensorUInt64Bit, uint64_t>,
        CreateTensorValueFromInspectable<winmlp::TensorInt64Bit, int64_t>,
        CreateTensorValueFromInspectable<winmlp::TensorFloat16Bit, float>,
        CreateTensorValueFromInspectable<winmlp::TensorString, winrt::hstring>};
    // First creator that recognizes the object wins.
    for (const auto& tensorCreator : creators) {
      if (auto createdTensor = tensorCreator(bindingType, inspectable, tensorDescriptor)) {
        return createdTensor;
      }
    }
  }
  return nullptr;
}