ITensor CreateBindableTensor()

in Tools/WinMLRunner/src/BindingUtilities.cpp [716:857]


    // Builds the tensor to bind for one model input feature.
    //
    // ProcessDescriptor(description, ...) supplies the tensor shape, the TensorKind and the initial
    // InputBufferDesc. The source data is then selected by `args`:
    //   * CSV input   - the file at args.CsvPath() is read as planar float data into a buffer owned
    //                   by this call (shape assumed to be NCHW).
    //   * image input - imagePath is loaded with LoadImageFile and its pixel buffer is read in place.
    //   * otherwise   - inputBufferDesc is passed on exactly as ProcessDescriptor left it.
    // The tensor is produced by CreateTensor<tensorKind>(args, shape, inputBindingType, inputBufferDesc).
    //
    // description          Feature descriptor of the model input.
    // imagePath            Image file to load when args.IsImageInput().
    // inputBindingType     Forwarded to CreateTensor.
    // inputDataType        Forwarded to LoadImageFile.
    // args                 Selects the input source (CSV / image) and is forwarded to the helpers.
    // iterationNum         Forwarded to LoadImageFile.
    // colorManagementMode  Forwarded to LoadImageFile.
    //
    // Throws hresult_invalid_argument when the TensorKind is undefined, the image has an unsupported
    // BitmapPixelFormat, or (CSV input) the shape has fewer than two dimensions, has a negative
    // dimension, or describes more bytes than fit in 32 bits. Throws hresult_not_implemented for
    // TensorKinds without a CreateTensor instantiation below.
    //
    // NOTE(review): all buffers referenced by inputBufferDesc.elements are released when this function
    // returns, so this relies on CreateTensor copying the data out rather than retaining the pointer
    // (the image buffer has always been released on return) - confirm if CreateTensor changes.
    ITensor CreateBindableTensor(const ILearningModelFeatureDescriptor& description, const std::wstring& imagePath,
                                 const InputBindingType inputBindingType, const InputDataType inputDataType,
                                 const CommandLineArgs& args, uint32_t iterationNum,
                                 ColorManagementMode colorManagementMode)
    {
        InputBufferDesc inputBufferDesc = {};

        std::vector<int64_t> shape = {};
        TensorKind tensorKind = TensorKind::Undefined;
        ProcessDescriptor(description, shape, tensorKind, inputBufferDesc);

        // Owners of whatever inputBufferDesc.elements ends up pointing at. They are declared at function
        // scope (not inside the branches below) so the memory stays valid until CreateTensor has read it.
        // Locals are destroyed in reverse declaration order, so the memory reference is released before
        // the buffer lock, which is released before the bitmap itself.
        SoftwareBitmap softwareBitmap(nullptr);
        BitmapBuffer bitmapBuffer(nullptr);
        winrt::Windows::Foundation::IMemoryBufferReference bitmapReference(nullptr);
        std::vector<uint8_t> csvBuffer;

        if (args.IsCSVInput())
        {
            // Assumes shape is in the format of 'NCHW'; only the channel dimension (index 1) is read here.
            if (shape.size() < 2)
            {
                throw hresult_invalid_argument(
                    L"BindingUtilities: CSV input requires a tensor shape with at least two dimensions.");
            }

            // Compute the dense byte size in 64 bits so a free (negative) or oversized dimension is
            // rejected instead of silently wrapping the 32-bit size field.
            uint64_t totalSizeInBytes = sizeof(float);
            for (const int64_t dim : shape)
            {
                if (dim < 0)
                {
                    throw hresult_invalid_argument(
                        L"BindingUtilities: CSV input does not support negative (free) tensor dimensions.");
                }
                const auto extent = static_cast<uint64_t>(dim);
                if (extent != 0 && totalSizeInBytes > UINT32_MAX / extent)
                {
                    throw hresult_invalid_argument(
                        L"BindingUtilities: CSV input tensor is too large for a 32-bit buffer.");
                }
                totalSizeInBytes *= extent;
            }

            inputBufferDesc.channelFormat = TensorKind::Float;
            inputBufferDesc.isPlanar = true;
            inputBufferDesc.numChannelsPerElement = static_cast<uint32_t>(shape[1]);

            // Assumes no gaps in the input csv file
            inputBufferDesc.elementStrideInBytes = inputBufferDesc.numChannelsPerElement * sizeof(float);
            inputBufferDesc.totalSizeInBytes = static_cast<uint32_t>(totalSizeInBytes);

            // Zero-initialised storage that is freed automatically on return (the previous raw new[]
            // allocation was never released).
            csvBuffer.resize(static_cast<size_t>(totalSizeInBytes));
            inputBufferDesc.elements = csvBuffer.data();

            ReadCSVIntoBuffer(args.CsvPath(), inputBufferDesc);
        }
        else if (args.IsImageInput())
        {
            softwareBitmap =
                LoadImageFile(description, inputDataType, imagePath.c_str(), args, iterationNum, colorManagementMode);

            // Get a pointer to the SoftwareBitmap data buffer. The pointer handed out by GetBuffer is only
            // valid while the BitmapBuffer lock and the IMemoryBufferReference are alive, which is why both
            // are held in the function-scope locals declared above.
            bitmapBuffer = softwareBitmap.LockBuffer(BitmapBufferAccessMode::Read);
            bitmapReference = bitmapBuffer.CreateReference();
            auto byteAccess = bitmapReference.as<::Windows::Foundation::IMemoryBufferByteAccess>();
            winrt::check_hresult(byteAccess->GetBuffer(&inputBufferDesc.elements, &inputBufferDesc.totalSizeInBytes));

            inputBufferDesc.isPlanar = false;
            inputBufferDesc.elementFormat = softwareBitmap.BitmapPixelFormat();
            switch (inputBufferDesc.elementFormat)
            {
                case BitmapPixelFormat::Gray8:
                    inputBufferDesc.channelFormat = TensorKind::UInt8;
                    inputBufferDesc.numChannelsPerElement = 1;
                    inputBufferDesc.elementStrideInBytes = sizeof(uint8_t);
                    break;
                case BitmapPixelFormat::Gray16:
                    inputBufferDesc.channelFormat = TensorKind::UInt16;
                    inputBufferDesc.numChannelsPerElement = 1;
                    inputBufferDesc.elementStrideInBytes = sizeof(uint16_t);
                    break;
                // The 4-channel formats below are strided as 4 channels per pixel, but only 3 channels
                // per element are consumed (presumably the alpha channel is skipped).
                case BitmapPixelFormat::Bgra8:
                    inputBufferDesc.channelFormat = TensorKind::UInt8;
                    inputBufferDesc.numChannelsPerElement = 3;
                    inputBufferDesc.elementStrideInBytes = 4 * sizeof(uint8_t);
                    break;
                case BitmapPixelFormat::Rgba8:
                    inputBufferDesc.channelFormat = TensorKind::UInt8;
                    inputBufferDesc.numChannelsPerElement = 3;
                    inputBufferDesc.elementStrideInBytes = 4 * sizeof(uint8_t);
                    break;
                case BitmapPixelFormat::Rgba16:
                    inputBufferDesc.channelFormat = TensorKind::UInt16;
                    inputBufferDesc.numChannelsPerElement = 3;
                    inputBufferDesc.elementStrideInBytes = 4 * sizeof(uint16_t);
                    break;
                default:
                    throw hresult_invalid_argument(L"Unknown BitmapPixelFormat in input image.");
            }
        }

        // Dispatch to the CreateTensor instantiation matching the model's TensorKind.
        switch (tensorKind)
        {
            case TensorKind::Undefined:
                std::cout << "BindingUtilities: TensorKind is undefined." << std::endl;
                throw hresult_invalid_argument();
            case TensorKind::Float:
                return CreateTensor<TensorKind::Float>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::Float16:
                return CreateTensor<TensorKind::Float16>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::Double:
                return CreateTensor<TensorKind::Double>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::Int8:
                return CreateTensor<TensorKind::Int8>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::UInt8:
                return CreateTensor<TensorKind::UInt8>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::Int16:
                return CreateTensor<TensorKind::Int16>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::UInt16:
                return CreateTensor<TensorKind::UInt16>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::Int32:
                return CreateTensor<TensorKind::Int32>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::UInt32:
                return CreateTensor<TensorKind::UInt32>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::Int64:
                return CreateTensor<TensorKind::Int64>(args, shape, inputBindingType, inputBufferDesc);
            case TensorKind::UInt64:
                return CreateTensor<TensorKind::UInt64>(args, shape, inputBindingType, inputBufferDesc);
            default:
                // Remaining kinds (e.g. Boolean, String) fall through to "not implemented" below.
                break;
        }
        std::cout << "BindingUtilities: TensorKind has not been implemented." << std::endl;
        throw hresult_not_implemented();
    }