inline winml::ILearningModelFeatureValue CreateFeatureValueFromInspectable()

in winml/lib/Api/FeatureValues.h [204:456]


// Converts an arbitrary bound object into an ILearningModelFeatureValue.
//
// The conversions are attempted in this order; the first one that matches wins:
//   1. The object already is an ILearningModelFeatureValue and is returned unchanged.
//   2. IVector<VideoFrame> (plus IVectorView<VideoFrame> for inputs) -> ImageFeatureValue.
//      An empty collection yields nullptr.
//   3. A single VideoFrame -> ImageFeatureValue.
//   4. IMap<K, V> (plus IMapView<K, V> for inputs) -> the matching map feature value.
//   5. When the descriptor is a Sequence: IVector<T> (plus IVectorView<T> for inputs)
//      -> the matching sequence feature value.
//   6. When the descriptor is a Tensor: an iterable of IBuffer is turned into a tensor of the
//      descriptor's TensorKind via CreateFromBatchedBuffers; otherwise each entry of the
//      CreateTensorValueFromInspectable table is tried in turn.
//
// Parameters:
//   bindingType - kInput additionally enables the read-only view (IVectorView / IMapView) probes.
//   inspectable - the object to convert.
//   descriptor  - the feature descriptor; only consulted for the Sequence and Tensor cases.
//
// Returns the created feature value, or nullptr when no conversion applies.
inline winml::ILearningModelFeatureValue CreateFeatureValueFromInspectable(
    _winml::BindingType bindingType,
    const wf::IInspectable& inspectable,
    const winml::ILearningModelFeatureDescriptor& descriptor) {

  // Tensor and ImageFeatureValue types are passed in directly as feature values
  if (auto featureValue = inspectable.try_as<winml::ILearningModelFeatureValue>()) {
    return featureValue;
  }

  if (auto videoFrames = inspectable.try_as<wfc::IVector<wm::VideoFrame>>()) {
    return (0 == videoFrames.Size()) ? nullptr : winrt::make<winmlp::ImageFeatureValue>(videoFrames);
  }

  if (bindingType == _winml::BindingType::kInput) {
    // Allows to bind IVectorView<VideoFrame> as input.
    if (auto videoFrames = inspectable.try_as<wfc::IVectorView<wm::VideoFrame>>()) {
      return (0 == videoFrames.Size()) ? nullptr : winrt::make<winmlp::ImageFeatureValue>(videoFrames);
    }
  }

  // ImageFeatureValues Types can be implicitly inferred from the VideoFrame object
  if (auto videoFrame = inspectable.try_as<wm::VideoFrame>()) {
    return winrt::make<winmlp::ImageFeatureValue>(videoFrame);
  }

  // MapFeatureValues Types are implicitly inferred from the iinspectable object
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, float>>()) {
    return winmlp::MapStringToFloat::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, double>>()) {
    return winmlp::MapStringToDouble::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, int64_t>>()) {
    return winmlp::MapStringToInt64Bit::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<winrt::hstring, winrt::hstring>>()) {
    return winmlp::MapStringToString::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, float>>()) {
    return winmlp::MapInt64BitToFloat::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, double>>()) {
    return winmlp::MapInt64BitToDouble::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, int64_t>>()) {
    return winmlp::MapInt64BitToInt64Bit::Create(map);
  }
  if (auto map = inspectable.try_as<wfc::IMap<int64_t, winrt::hstring>>()) {
    return winmlp::MapInt64BitToString::Create(map);
  }

  if (bindingType == _winml::BindingType::kInput) {
    // Feature inputs should be more permissive, and allow for views to be bound since they are read only
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, float>>()) {
      return winmlp::MapStringToFloat::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, double>>()) {
      return winmlp::MapStringToDouble::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, int64_t>>()) {
      return winmlp::MapStringToInt64Bit::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<winrt::hstring, winrt::hstring>>()) {
      return winmlp::MapStringToString::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, float>>()) {
      return winmlp::MapInt64BitToFloat::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, double>>()) {
      return winmlp::MapInt64BitToDouble::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, int64_t>>()) {
      return winmlp::MapInt64BitToInt64Bit::Create(map);
    }
    if (auto map = inspectable.try_as<wfc::IMapView<int64_t, winrt::hstring>>()) {
      return winmlp::MapInt64BitToString::Create(map);
    }
  }

  // Query the descriptor kind once; it selects between the sequence and tensor conversions below.
  const auto featureKind = descriptor.Kind();

  if (featureKind == winml::LearningModelFeatureKind::Sequence) {
    // SequenceFeatureValues Types are implicitly inferred from the iinspectable object
    if (auto sequence = inspectable.try_as<wfc::IVector<wfc::IMap<winrt::hstring, float>>>()) {
      return winmlp::SequenceMapStringFloat::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<wfc::IMap<int64_t, float>>>()) {
      return winmlp::SequenceMapInt64BitFloat::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorFloat>>()) {
      return winmlp::SequenceTensorFloat::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorBoolean>>()) {
      return winmlp::SequenceTensorBoolean::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorDouble>>()) {
      return winmlp::SequenceTensorDouble::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt8Bit>>()) {
      return winmlp::SequenceTensorInt8Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt8Bit>>()) {
      return winmlp::SequenceTensorUInt8Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt16Bit>>()) {
      return winmlp::SequenceTensorUInt16Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt16Bit>>()) {
      return winmlp::SequenceTensorInt16Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt32Bit>>()) {
      return winmlp::SequenceTensorUInt32Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt32Bit>>()) {
      return winmlp::SequenceTensorInt32Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorUInt64Bit>>()) {
      return winmlp::SequenceTensorUInt64Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorInt64Bit>>()) {
      return winmlp::SequenceTensorInt64Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorFloat16Bit>>()) {
      return winmlp::SequenceTensorFloat16Bit::Create(sequence);
    }
    if (auto sequence = inspectable.try_as<wfc::IVector<winml::TensorString>>()) {
      return winmlp::SequenceTensorString::Create(sequence);
    }

    if (bindingType == _winml::BindingType::kInput) {
      // Feature inputs should be more permissive, and allow for views to be bound since they are read only
      if (auto sequence = inspectable.try_as<wfc::IVectorView<wfc::IMap<winrt::hstring, float>>>()) {
        return winmlp::SequenceMapStringFloat::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<wfc::IMap<int64_t, float>>>()) {
        return winmlp::SequenceMapInt64BitFloat::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorFloat>>()) {
        return winmlp::SequenceTensorFloat::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorBoolean>>()) {
        return winmlp::SequenceTensorBoolean::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorDouble>>()) {
        return winmlp::SequenceTensorDouble::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt8Bit>>()) {
        return winmlp::SequenceTensorInt8Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt8Bit>>()) {
        return winmlp::SequenceTensorUInt8Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt16Bit>>()) {
        return winmlp::SequenceTensorUInt16Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt16Bit>>()) {
        return winmlp::SequenceTensorInt16Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt32Bit>>()) {
        return winmlp::SequenceTensorUInt32Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt32Bit>>()) {
        return winmlp::SequenceTensorInt32Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorUInt64Bit>>()) {
        return winmlp::SequenceTensorUInt64Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorInt64Bit>>()) {
        return winmlp::SequenceTensorInt64Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorFloat16Bit>>()) {
        return winmlp::SequenceTensorFloat16Bit::Create(sequence);
      }
      if (auto sequence = inspectable.try_as<wfc::IVectorView<winml::TensorString>>()) {
        return winmlp::SequenceTensorString::Create(sequence);
      }
    }
  }
  else if (featureKind == winml::LearningModelFeatureKind::Tensor) {
    auto tensorDescriptor = descriptor.as<winml::ITensorFeatureDescriptor>();

    // Vector of IBuffer Input should be copied into the appropriate Tensor
    if (auto buffers = inspectable.try_as<wfc::IIterable<wss::IBuffer>>()) {
      // Read the element kind once and dispatch on it. Kinds without a case here
      // (for example String) fall through to the collection-based creators below.
      const auto tensorKind = tensorDescriptor.TensorKind();
      switch (tensorKind) {
        case winml::TensorKind::Boolean:
          return winmlp::TensorBoolean::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Float:
          return winmlp::TensorFloat::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Double:
          return winmlp::TensorDouble::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Float16:
          return winmlp::TensorFloat16Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt8:
          return winmlp::TensorUInt8Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int8:
          return winmlp::TensorInt8Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt16:
          return winmlp::TensorUInt16Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int16:
          return winmlp::TensorInt16Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt32:
          return winmlp::TensorUInt32Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int32:
          return winmlp::TensorInt32Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::UInt64:
          return winmlp::TensorUInt64Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        case winml::TensorKind::Int64:
          return winmlp::TensorInt64Bit::CreateFromBatchedBuffers(tensorDescriptor.Shape(), buffers);
        default:
          break;
      }
    }

    using TensorCreator = winml::ILearningModelFeatureValue (*)(BindingType, const wf::IInspectable& inspectable, const winml::ITensorFeatureDescriptor& descriptor);
    constexpr std::array<TensorCreator, 13> creators =
        {
            // Vector and VectorViews of float16 and int8 collide with float and uint8 respectively,
            // so the TensorFloat16Bit and TensorInt8Bit entries below use float and uint8_t as
            // their collection element types.
            // NOTE(review): an earlier comment said these two were omitted, yet both are listed;
            // presumably CreateTensorValueFromInspectable uses the descriptor to disambiguate - confirm.
            CreateTensorValueFromInspectable<winmlp::TensorBoolean, bool>,
            CreateTensorValueFromInspectable<winmlp::TensorFloat, float>,
            CreateTensorValueFromInspectable<winmlp::TensorDouble, double>,
            CreateTensorValueFromInspectable<winmlp::TensorUInt8Bit, uint8_t>,
            CreateTensorValueFromInspectable<winmlp::TensorInt8Bit, uint8_t>,
            CreateTensorValueFromInspectable<winmlp::TensorUInt16Bit, uint16_t>,
            CreateTensorValueFromInspectable<winmlp::TensorInt16Bit, int16_t>,
            CreateTensorValueFromInspectable<winmlp::TensorUInt32Bit, uint32_t>,
            CreateTensorValueFromInspectable<winmlp::TensorInt32Bit, int32_t>,
            CreateTensorValueFromInspectable<winmlp::TensorUInt64Bit, uint64_t>,
            CreateTensorValueFromInspectable<winmlp::TensorInt64Bit, int64_t>,
            CreateTensorValueFromInspectable<winmlp::TensorFloat16Bit, float>,
            CreateTensorValueFromInspectable<winmlp::TensorString, winrt::hstring>
        };

    // The first creator that produces a non-null tensor wins.
    for (const auto& tensorCreator : creators) {
      if (auto createdTensor = tensorCreator(bindingType, inspectable, tensorDescriptor)) {
        return createdTensor;
      }
    }
  }

  // Nothing matched: the object cannot be converted to a feature value.
  return nullptr;
}