void HlslDispatchable::Bind()

in DxDispatch/src/dxdispatch/HlslDispatchable.cpp [319:449]


// Writes one buffer view descriptor into m_descriptorHeap for each entry in `bindings`, then
// sets the root signature, pipeline state, descriptor heap, and root descriptor table (root
// parameter 0) on the device's command list.
//
// bindings: maps a shader input name to its source resources. Only the first source of each
//           entry is used (see TODO below).
//
// Throws std::invalid_argument when:
//   - a source resource is not described by a Model::BufferDesc,
//   - the target name has no entry in m_bindPoints,
//   - an explicit view format conflicts with a structured/raw bind point,
//   - the buffer size is not suitably aligned for a raw view or a constant buffer view,
//   - the bind point has a descriptor type other than UAV/SRV/CBV.
void HlslDispatchable::Bind(const Bindings& bindings)
{
    const uint32_t descriptorIncrementSize = m_device->D3D()->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);

    for (const auto& binding : bindings)
    {
        const auto& targetName = binding.first;
        const auto& sources = binding.second;
        assert(sources.size() == 1); // TODO: support multiple
        const auto& source = sources[0];

        assert(source.resource != nullptr);
        assert(source.resourceDesc != nullptr);

        if (!std::holds_alternative<Model::BufferDesc>(source.resourceDesc->value))
        {
            throw std::invalid_argument("HLSL operators currently only support buffer bindings");
        }
        const auto& sourceBufferDesc = std::get<Model::BufferDesc>(source.resourceDesc->value);

        // find() returns the iterator by value; hold it by value rather than binding a non-const
        // lvalue reference to the temporary (which only compiles as a non-conforming extension).
        const auto bindPointIterator = m_bindPoints.find(targetName);
        if (bindPointIterator == m_bindPoints.end())
        {
            throw std::invalid_argument(fmt::format("Attempting to bind shader input '{}', which does not exist (or was optimized away) in the shader.", targetName));
        }
        const auto& bindPoint = bindPointIterator->second;

        // Descriptor slot for this bind point, offset from the start of the heap.
        CD3DX12_CPU_DESCRIPTOR_HANDLE cpuHandle{
            m_descriptorHeap->GetCPUDescriptorHandleForHeapStart(), 
            static_cast<int>(bindPoint.offsetInDescriptorsFromTableStart), 
            descriptorIncrementSize
        };

        // Fills the Buffer/Format fields shared by UAV and SRV descriptions. Generic so the same
        // logic serves both D3D12_UNORDERED_ACCESS_VIEW_DESC and D3D12_SHADER_RESOURCE_VIEW_DESC.
        auto FillViewDesc = [&](auto& viewDesc)
        {
            // NOTE(review): the stride is copied for every view type; presumably it is 0 for
            // typed and raw bind points - confirm where m_bindPoints is populated.
            viewDesc.Buffer.StructureByteStride = bindPoint.structureByteStride;
            viewDesc.Buffer.NumElements = source.elementCount;
            viewDesc.Buffer.FirstElement = source.elementOffset;

            if (bindPoint.viewType == BufferViewType::Typed)
            {
                if (source.format)
                {
                    viewDesc.Format = *source.format;
                }
                else
                {
                    // If the binding doesn't specify, assume the data type used to initialize the buffer.
                    viewDesc.Format = Device::GetDxgiFormatFromDmlTensorDataType(sourceBufferDesc.initialValuesDataType);
                }
            }
            else if (bindPoint.viewType == BufferViewType::Structured)
            {
                if (source.format && *source.format != DXGI_FORMAT_UNKNOWN)
                {
                    throw std::invalid_argument(fmt::format("'{}' is a structured buffer, so the format must be omitted or UNKNOWN.", targetName));
                }
                viewDesc.Format = DXGI_FORMAT_UNKNOWN;
            }
            else if (bindPoint.viewType == BufferViewType::Raw)
            {
                if (source.format && *source.format != DXGI_FORMAT_R32_TYPELESS)
                {
                    throw std::invalid_argument(fmt::format("'{}' is a raw buffer, so the format must be omitted or R32_TYPELESS.", targetName));
                }

                // The check is against the whole buffer's size, not the bound element range.
                if (sourceBufferDesc.sizeInBytes % D3D12_RAW_UAV_SRV_BYTE_ALIGNMENT != 0)
                {
                    throw std::invalid_argument(fmt::format(
                        "Attempting to bind '{}' as a raw buffer, but its size ({} bytes) is not aligned to {} bytes", 
                        source.resourceDesc->name,
                        sourceBufferDesc.sizeInBytes,
                        D3D12_RAW_UAV_SRV_BYTE_ALIGNMENT));
                }

                viewDesc.Format = DXGI_FORMAT_R32_TYPELESS;
                // The RAW flag enum differs between UAV and SRV descriptions, so pick it by the
                // concrete type this generic lambda was instantiated with.
                if constexpr (std::is_same_v<decltype(viewDesc), D3D12_UNORDERED_ACCESS_VIEW_DESC&>)
                {
                    viewDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
                }
                if constexpr (std::is_same_v<decltype(viewDesc), D3D12_SHADER_RESOURCE_VIEW_DESC&>)
                {
                    viewDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_RAW;
                }
            }
        };

        if (bindPoint.descriptorType == D3D12_DESCRIPTOR_RANGE_TYPE_UAV)
        {
            D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
            uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
            FillViewDesc(uavDesc);
            uavDesc.Buffer.CounterOffsetInBytes = source.counterOffsetBytes;
            m_device->D3D()->CreateUnorderedAccessView(source.resource, source.counterResource, &uavDesc, cpuHandle);
        }
        else if (bindPoint.descriptorType == D3D12_DESCRIPTOR_RANGE_TYPE_SRV)
        {
            D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {};
            srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER;
            srvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING;
            FillViewDesc(srvDesc);
            m_device->D3D()->CreateShaderResourceView(source.resource, &srvDesc, cpuHandle);
        }
        else if (bindPoint.descriptorType == D3D12_DESCRIPTOR_RANGE_TYPE_CBV)
        {
            if (sourceBufferDesc.sizeInBytes % D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT != 0)
            {
                throw std::invalid_argument(fmt::format(
                    "Attempting to bind '{}' as a constant buffer, but its size ({} bytes) is not aligned to {} bytes", 
                    source.resourceDesc->name,
                    sourceBufferDesc.sizeInBytes,
                    D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT));
            }

            // The constant buffer view covers the whole buffer, starting at its base GPU address.
            D3D12_CONSTANT_BUFFER_VIEW_DESC cbvDesc = {};
            cbvDesc.BufferLocation = source.resource->GetGPUVirtualAddress();
            cbvDesc.SizeInBytes = sourceBufferDesc.sizeInBytes;
            m_device->D3D()->CreateConstantBufferView(&cbvDesc, cpuHandle);
        }
        else
        {
            throw std::invalid_argument("Unexpected binding type");
        }
    }

    // All descriptors are written; record the compute state that refers to them.
    m_device->GetCommandList()->SetComputeRootSignature(m_rootSignature.Get());
    m_device->GetCommandList()->SetPipelineState(m_pipelineState.Get());
    ID3D12DescriptorHeap* descriptorHeaps[] = { m_descriptorHeap.Get() };
    m_device->GetCommandList()->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps);
    m_device->GetCommandList()->SetComputeRootDescriptorTable(0, m_descriptorHeap->GetGPUDescriptorHandleForHeapStart());
}