in Tools/WinMLDashboard/src/cpp/DebugRunner/debug_cpu.cpp [171:251]
// Kernel entry point for the debug operator.
//
// Reads input tensor 0 and output tensor 0 and checks that they agree on rank,
// total element count and element type. When both live in CPU memory, it
// dispatches to ComputeInternal<T> with the element type matching the tensor
// data type. ComputeInternal is defined elsewhere; presumably it copies the
// input into the output and writes the debug dump to m_filePath in the
// m_fileType format (confirm).
//
// Returns:
//   S_OK          on success, or when either tensor is not CPU data (nothing is done then).
//   E_UNEXPECTED  if a tensor is missing or input/output rank, element count or type differ.
//   E_NOTIMPL     for element types this operator has no ComputeInternal mapping for.
//   Any other     HRESULT from a failing API call or exception, via winrt::to_hresult().
HRESULT DebugOperator::Compute(IMLOperatorKernelContext* context)
{
    try
    {
        // Fetch the single input and output tensors. These calls report failure
        // through their HRESULT; check_hresult turns a failure into an exception
        // that the catch block below converts back into an HRESULT.
        winrt::com_ptr<IMLOperatorTensor> inputTensor;
        winrt::check_hresult(context->GetInputTensor(0, inputTensor.put()));
        winrt::com_ptr<IMLOperatorTensor> outputTensor;
        winrt::check_hresult(context->GetOutputTensor(0, outputTensor.put()));
        // Guard against null tensors before dereferencing them below.
        if (!inputTensor || !outputTensor)
        {
            return E_UNEXPECTED;
        }

        // Input and output must have the same rank...
        const uint32_t inputDimsSize = inputTensor->GetDimensionCount();
        const uint32_t outputDimsSize = outputTensor->GetDimensionCount();
        if (inputDimsSize != outputDimsSize)
        {
            return E_UNEXPECTED;
        }

        std::vector<uint32_t> inputDims(inputDimsSize);
        winrt::check_hresult(inputTensor->GetShape(inputDimsSize, inputDims.data()));
        std::vector<uint32_t> outputDims(outputDimsSize);
        winrt::check_hresult(outputTensor->GetShape(outputDimsSize, outputDims.data()));

        // ...and the same total number of elements. The initial value is a
        // uint32_t so std::accumulate carries the product in uint32_t; a plain
        // `1` literal would make the accumulator (and result) an int.
        const uint32_t outputDataSize = std::accumulate(outputDims.begin(), outputDims.end(), uint32_t{1}, std::multiplies<uint32_t>());
        const uint32_t inputDataSize = std::accumulate(inputDims.begin(), inputDims.end(), uint32_t{1}, std::multiplies<uint32_t>());
        if (outputDataSize != inputDataSize)
        {
            return E_UNEXPECTED;
        }

        // ...and the same element type.
        const MLOperatorTensorDataType type = inputTensor->GetTensorDataType();
        if (outputTensor->GetTensorDataType() != type)
        {
            return E_UNEXPECTED;
        }

        // Only CPU-resident tensors are handled; otherwise nothing is done.
        if (outputTensor->IsCpuData() && inputTensor->IsCpuData())
        {
            switch (type)
            {
            case MLOperatorTensorDataType::Float:
                ComputeInternal<float>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Float16:
                // Float16 elements are 2 bytes wide. Instantiating with float
                // (4 bytes) mis-sizes every element, presumably making
                // ComputeInternal access twice as many bytes as the buffers
                // hold. A 16-bit type keeps the element size correct.
                // NOTE(review): any dump written by ComputeInternal will show
                // raw half-precision bit patterns, not decoded float values.
                ComputeInternal<unsigned short int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Bool:
                ComputeInternal<bool>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Double:
                ComputeInternal<double>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::UInt8:
                ComputeInternal<unsigned char>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Int8:
                ComputeInternal<char>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::UInt16:
                ComputeInternal<unsigned short int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Int16:
                ComputeInternal<short int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Int32:
                ComputeInternal<int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::UInt32:
                ComputeInternal<unsigned int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::Int64:
                ComputeInternal<long long int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            case MLOperatorTensorDataType::UInt64:
                ComputeInternal<unsigned long long int>(inputTensor.get(), outputTensor.get(), inputDataSize, inputDims, m_filePath, m_fileType);
                break;
            default:
                // Previously these types fell through and reported S_OK without
                // ever producing the output; fail explicitly instead.
                return E_NOTIMPL;
            }
        }
        return S_OK;
    }
    catch (...)
    {
        return winrt::to_hresult();
    }
}