HRESULT DebugOperator::Compute()

in Tools/WinMLDashboard/src/cpp/DebugRunner/debug_cpu.cpp [171:251]


// Debug kernel entry point: validates that the input and output tensors agree
// on rank, element count, and data type, then dispatches to ComputeInternal<T>
// (which dumps/forwards the tensor data using m_filePath / m_fileType).
//
// Returns:
//   S_OK          - on success, or when the tensors are not CPU-resident
//                   (non-CPU data is silently skipped, as before).
//   E_UNEXPECTED  - rank, element-count, or data-type mismatch.
//   other HRESULT - propagated from the kernel-context/tensor accessors.
HRESULT DebugOperator::Compute(IMLOperatorKernelContext* context)
{
	try
	{
		// Acquire input and output tensors; propagate failures instead of
		// ignoring the returned HRESULTs (the original code dropped them).
		winrt::com_ptr<IMLOperatorTensor> inputTensor;
		HRESULT hr = context->GetInputTensor(0, inputTensor.put());
		if (FAILED(hr))
		{
			return hr;
		}
		winrt::com_ptr<IMLOperatorTensor> outputTensor;
		hr = context->GetOutputTensor(0, outputTensor.put());
		if (FAILED(hr))
		{
			return hr;
		}

		// Ranks must match before we can compare shapes element-wise.
		const uint32_t inputDimsSize = inputTensor->GetDimensionCount();
		const uint32_t outputDimsSize = outputTensor->GetDimensionCount();
		if (inputDimsSize != outputDimsSize)
		{
			return E_UNEXPECTED;
		}

		// Fetch both shapes, checking the HRESULTs.
		std::vector<uint32_t> inputDims(inputDimsSize);
		hr = inputTensor->GetShape(inputDimsSize, inputDims.data());
		if (FAILED(hr))
		{
			return hr;
		}
		std::vector<uint32_t> outputDims(outputDimsSize);
		hr = outputTensor->GetShape(outputDimsSize, outputDims.data());
		if (FAILED(hr))
		{
			return hr;
		}

		// Total element counts must match. Seed with 1u so the product is
		// accumulated in uint32_t throughout (the original seeded with a
		// signed int 1, mixing signed/unsigned arithmetic).
		const uint32_t inputDataSize = std::accumulate(inputDims.begin(), inputDims.end(), 1u, std::multiplies<uint32_t>());
		const uint32_t outputDataSize = std::accumulate(outputDims.begin(), outputDims.end(), 1u, std::multiplies<uint32_t>());
		if (outputDataSize != inputDataSize)
		{
			return E_UNEXPECTED;
		}

		const MLOperatorTensorDataType type = inputTensor->GetTensorDataType();
		if (outputTensor->GetTensorDataType() != type)
		{
			return E_UNEXPECTED;
		}

		// Only CPU-resident tensors can be accessed directly; anything else is
		// skipped and reported as success (preserving the original behavior).
		if (outputTensor->IsCpuData() && inputTensor->IsCpuData())
		{
			switch (type)
			{
			case MLOperatorTensorDataType::Float:
				ComputeInternal<float>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Float16:
				// BUG FIX: Float16 elements are 2 bytes. The original code
				// dispatched to ComputeInternal<float>, which strides 4 bytes
				// per element and therefore reads (and writes) past the end of
				// the tensor buffers. Dispatch on a 16-bit type so the element
				// size matches the buffer; the values handled are the raw
				// IEEE-754 half bit patterns rather than decoded floats.
				ComputeInternal<unsigned short int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Bool:
				ComputeInternal<bool>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Double:
				ComputeInternal<double>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::UInt8:
				ComputeInternal<unsigned char>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Int8:
				// signed char is explicit about signedness; plain char's
				// signedness is implementation-defined.
				ComputeInternal<signed char>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::UInt16:
				ComputeInternal<unsigned short int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Int16:
				ComputeInternal<short int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Int32:
				ComputeInternal<int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::UInt32:
				ComputeInternal<unsigned int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::Int64:
				ComputeInternal<long long int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			case MLOperatorTensorDataType::UInt64:
				ComputeInternal<unsigned long long int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
				break;
			default:
				// NOTE(review): unsupported element types (String, Complex,
				// Undefined, ...) are silently skipped and S_OK is returned,
				// matching the original behavior — the output tensor is left
				// unwritten in that case. Consider returning E_INVALIDARG.
				break;
			}
		}
		return S_OK;
	}
	catch (...)
	{
		// Convert any WinRT/C++ exception back into an HRESULT for the caller.
		return winrt::to_hresult();
	}
}