in Tools/WinMLRunner/src/BindingUtilities.cpp [716:857]
// Creates the tensor to bind for the model input described by `description`.
//
// ProcessDescriptor supplies the tensor shape and TensorKind. The source data is then prepared
// according to `args`:
//   * args.IsCSVInput():   the values are read from args.CsvPath() into a temporary planar float
//                          staging buffer sized from the tensor shape (assumed to be 'NCHW').
//   * args.IsImageInput(): the image is loaded with LoadImageFile and inputBufferDesc is pointed at
//                          the pixel data of the resulting SoftwareBitmap.
//   * otherwise:           inputBufferDesc is passed on exactly as ProcessDescriptor left it.
// Finally the work is handed to CreateTensor<kind> for the TensorKind of the input.
//
// Parameters:
//   description         - feature descriptor of the model input to create a tensor for.
//   imagePath           - path of the input image; forwarded to LoadImageFile for image input.
//   inputBindingType    - forwarded to CreateTensor.
//   inputDataType       - forwarded to LoadImageFile.
//   args                - command line arguments; select the input source and are forwarded to
//                         LoadImageFile and CreateTensor.
//   iterationNum        - forwarded to LoadImageFile.
//   colorManagementMode - forwarded to LoadImageFile.
//
// Returns the tensor produced by CreateTensor.
//
// Throws:
//   hresult_invalid_argument - the TensorKind is Undefined, the image has an unknown
//                              BitmapPixelFormat, or (CSV input) the shape has fewer than two
//                              dimensions, a non-positive dimension, or needs more bytes than a
//                              32-bit size can describe.
//   hresult_not_implemented  - the TensorKind has no CreateTensor dispatch below.
//   Whatever winrt::check_hresult throws when the bitmap buffer cannot be obtained.
ITensor CreateBindableTensor(const ILearningModelFeatureDescriptor& description, const std::wstring& imagePath,
                             const InputBindingType inputBindingType, const InputDataType inputDataType,
                             const CommandLineArgs& args, uint32_t iterationNum,
                             ColorManagementMode colorManagementMode)
{
    InputBufferDesc inputBufferDesc = {};
    std::vector<int64_t> shape = {};
    TensorKind tensorKind = TensorKind::Undefined;
    ProcessDescriptor(description, shape, tensorKind, inputBufferDesc);

    // Everything inputBufferDesc.elements may point at is owned by these function-scope objects so
    // that the memory (and, for images, the buffer lock) stays alive until CreateTensor has
    // returned, and is released on every exit path, including exceptions.
    // NOTE(review): this relies on CreateTensor not keeping `elements` after it returns - confirm.
    // The image path already depended on that, since its pixels belong to the local softwareBitmap.
    std::vector<uint8_t> csvStagingBuffer;
    SoftwareBitmap softwareBitmap(nullptr);
    BitmapBuffer sbBitmapBuffer(nullptr);
    winrt::Windows::Foundation::IMemoryBufferReference sbReference(nullptr);

    if (args.IsCSVInput())
    {
        // Assumes shape is in the format of 'NCHW'; shape[1] is read below, so it has to exist.
        if (shape.size() < 2)
        {
            throw hresult_invalid_argument(L"CSV input requires a tensor shape with at least two dimensions.");
        }

        // Compute the buffer size in 64 bits and validate it before narrowing: the descriptor only
        // has a 32-bit size, and a negative or huge dimension must not silently wrap around.
        const uint64_t maxBufferSizeInBytes = 0xFFFFFFFFull;
        uint64_t requiredSizeInBytes = sizeof(float_t);
        for (const int64_t dimension : shape)
        {
            if (dimension <= 0)
            {
                throw hresult_invalid_argument(L"CSV input requires every tensor dimension to be positive.");
            }
            // requiredSizeInBytes is never zero here, so the division is safe.
            if (static_cast<uint64_t>(dimension) > maxBufferSizeInBytes / requiredSizeInBytes)
            {
                throw hresult_invalid_argument(L"CSV input tensor is too large for a 32-bit buffer size.");
            }
            requiredSizeInBytes *= static_cast<uint64_t>(dimension);
        }

        inputBufferDesc.channelFormat = TensorKind::Float;
        inputBufferDesc.isPlanar = true;
        inputBufferDesc.numChannelsPerElement = static_cast<uint32_t>(shape[1]);
        // Assumes no gaps in the input csv file
        inputBufferDesc.elementStrideInBytes = inputBufferDesc.numChannelsPerElement * sizeof(float_t);
        inputBufferDesc.totalSizeInBytes = static_cast<uint32_t>(requiredSizeInBytes);

        // The vector zero-fills the staging area and frees it automatically.
        csvStagingBuffer.resize(inputBufferDesc.totalSizeInBytes);
        inputBufferDesc.elements = csvStagingBuffer.data();
        ReadCSVIntoBuffer(args.CsvPath(), inputBufferDesc);
    }
    else if (args.IsImageInput())
    {
        softwareBitmap =
            LoadImageFile(description, inputDataType, imagePath.c_str(), args, iterationNum, colorManagementMode);

        // Get Pointers to the SoftwareBitmap data buffers. The lock and the reference are kept in
        // function scope because the pointer obtained here is consumed by CreateTensor further down.
        sbBitmapBuffer = softwareBitmap.LockBuffer(BitmapBufferAccessMode::Read);
        sbReference = sbBitmapBuffer.CreateReference();
        auto sbByteAccess = sbReference.as<::Windows::Foundation::IMemoryBufferByteAccess>();
        winrt::check_hresult(sbByteAccess->GetBuffer(&inputBufferDesc.elements, &inputBufferDesc.totalSizeInBytes));

        inputBufferDesc.isPlanar = false;
        inputBufferDesc.elementFormat = softwareBitmap.BitmapPixelFormat();

        // Describe how one pixel ("element") is laid out for the pixel format of the image.
        switch (inputBufferDesc.elementFormat)
        {
            case BitmapPixelFormat::Gray8:
                inputBufferDesc.channelFormat = TensorKind::UInt8;
                inputBufferDesc.numChannelsPerElement = 1;
                inputBufferDesc.elementStrideInBytes = sizeof(uint8_t);
                break;
            case BitmapPixelFormat::Gray16:
                inputBufferDesc.channelFormat = TensorKind::UInt16;
                inputBufferDesc.numChannelsPerElement = 1;
                inputBufferDesc.elementStrideInBytes = sizeof(uint16_t);
                break;
            // The four-channel formats are described as 3 channels with a 4-channel stride
            // (presumably so the alpha channel is skipped by the consumer - confirm).
            case BitmapPixelFormat::Bgra8:
            case BitmapPixelFormat::Rgba8:
                inputBufferDesc.channelFormat = TensorKind::UInt8;
                inputBufferDesc.numChannelsPerElement = 3;
                inputBufferDesc.elementStrideInBytes = 4 * sizeof(uint8_t);
                break;
            case BitmapPixelFormat::Rgba16:
                inputBufferDesc.channelFormat = TensorKind::UInt16;
                inputBufferDesc.numChannelsPerElement = 3;
                inputBufferDesc.elementStrideInBytes = 4 * sizeof(uint16_t);
                break;
            default:
                throw hresult_invalid_argument(L"Unknown BitmapPixelFormat in input image.");
        }
    }

    // Dispatch to the CreateTensor instantiation matching the TensorKind of the model input.
    switch (tensorKind)
    {
        case TensorKind::Undefined:
        {
            std::cout << "BindingUtilities: TensorKind is undefined." << std::endl;
            throw hresult_invalid_argument();
        }
        case TensorKind::Float:
            return CreateTensor<TensorKind::Float>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::Float16:
            return CreateTensor<TensorKind::Float16>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::Double:
            return CreateTensor<TensorKind::Double>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::Int8:
            return CreateTensor<TensorKind::Int8>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::UInt8:
            return CreateTensor<TensorKind::UInt8>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::Int16:
            return CreateTensor<TensorKind::Int16>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::UInt16:
            return CreateTensor<TensorKind::UInt16>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::Int32:
            return CreateTensor<TensorKind::Int32>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::UInt32:
            return CreateTensor<TensorKind::UInt32>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::Int64:
            return CreateTensor<TensorKind::Int64>(args, shape, inputBindingType, inputBufferDesc);
        case TensorKind::UInt64:
            return CreateTensor<TensorKind::UInt64>(args, shape, inputBindingType, inputBufferDesc);
    }

    // Reached only for TensorKind values without a case above.
    std::cout << "BindingUtilities: TensorKind has not been implemented." << std::endl;
    throw hresult_not_implemented();
}