in DxDispatch/src/dxdispatch/HlslDispatchable.cpp [319:449]
void HlslDispatchable::Bind(const Bindings& bindings)
{
uint32_t descriptorIncrementSize = m_device->D3D()->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
for (auto& binding : bindings)
{
auto& targetName = binding.first;
auto& sources = binding.second;
assert(sources.size() == 1); // TODO: support multiple
auto& source = sources[0];
assert(source.resource != nullptr);
assert(source.resourceDesc != nullptr);
if (!std::holds_alternative<Model::BufferDesc>(source.resourceDesc->value))
{
throw std::invalid_argument("HLSL operators currently only support buffer bindings");
}
auto& sourceBufferDesc = std::get<Model::BufferDesc>(source.resourceDesc->value);
auto& bindPointIterator = m_bindPoints.find(targetName);
if (bindPointIterator == m_bindPoints.end())
{
throw std::invalid_argument(fmt::format("Attempting to bind shader input '{}', which does not exist (or was optimized away) in the shader.", targetName));
}
auto& bindPoint = bindPointIterator->second;
CD3DX12_CPU_DESCRIPTOR_HANDLE cpuHandle{
m_descriptorHeap->GetCPUDescriptorHandleForHeapStart(),
static_cast<int>(bindPoint.offsetInDescriptorsFromTableStart),
descriptorIncrementSize
};
auto FillViewDesc = [&](auto& viewDesc)
{
viewDesc.Buffer.StructureByteStride = bindPoint.structureByteStride;
viewDesc.Buffer.NumElements = source.elementCount;
viewDesc.Buffer.FirstElement = source.elementOffset;
if (bindPoint.viewType == BufferViewType::Typed)
{
if (source.format)
{
viewDesc.Format = *source.format;
}
else
{
// If the binding doesn't specify, assume the data type used to initialize the buffer.
viewDesc.Format = Device::GetDxgiFormatFromDmlTensorDataType(sourceBufferDesc.initialValuesDataType);
}
}
else if (bindPoint.viewType == BufferViewType::Structured)
{
if (source.format && *source.format != DXGI_FORMAT_UNKNOWN)
{
throw std::invalid_argument(fmt::format("'{}' is a structured buffer, so the format must be omitted or UNKNOWN.", targetName));
}
viewDesc.Format = DXGI_FORMAT_UNKNOWN;
}
else if (bindPoint.viewType == BufferViewType::Raw)
{
if (source.format && *source.format != DXGI_FORMAT_R32_TYPELESS)
{
throw std::invalid_argument(fmt::format("'{}' is a raw buffer, so the format must be omitted or R32_TYPELESS.", targetName));
}
if (sourceBufferDesc.sizeInBytes % D3D12_RAW_UAV_SRV_BYTE_ALIGNMENT != 0)
{
throw std::invalid_argument(fmt::format(
"Attempting to bind '{}' as a raw buffer, but its size ({} bytes) is not aligned to {} bytes",
source.resourceDesc->name,
sourceBufferDesc.sizeInBytes,
D3D12_RAW_UAV_SRV_BYTE_ALIGNMENT));
}
viewDesc.Format = DXGI_FORMAT_R32_TYPELESS;
if constexpr (std::is_same_v<decltype(viewDesc), D3D12_UNORDERED_ACCESS_VIEW_DESC&>)
{
viewDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
}
if constexpr (std::is_same_v<decltype(viewDesc), D3D12_SHADER_RESOURCE_VIEW_DESC&>)
{
viewDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_RAW;
}
}
};
if (bindPoint.descriptorType == D3D12_DESCRIPTOR_RANGE_TYPE_UAV)
{
D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
FillViewDesc(uavDesc);
uavDesc.Buffer.CounterOffsetInBytes = source.counterOffsetBytes;
m_device->D3D()->CreateUnorderedAccessView(source.resource, source.counterResource, &uavDesc, cpuHandle);
}
else if (bindPoint.descriptorType == D3D12_DESCRIPTOR_RANGE_TYPE_SRV)
{
D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {};
srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER;
srvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING;
FillViewDesc(srvDesc);
m_device->D3D()->CreateShaderResourceView(source.resource, &srvDesc, cpuHandle);
}
else if (bindPoint.descriptorType == D3D12_DESCRIPTOR_RANGE_TYPE_CBV)
{
if (sourceBufferDesc.sizeInBytes % D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT != 0)
{
throw std::invalid_argument(fmt::format(
"Attempting to bind '{}' as a constant buffer, but its size ({} bytes) is not aligned to {} bytes",
source.resourceDesc->name,
sourceBufferDesc.sizeInBytes,
D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT));
}
D3D12_CONSTANT_BUFFER_VIEW_DESC cbvDesc = {};
cbvDesc.BufferLocation = source.resource->GetGPUVirtualAddress();
cbvDesc.SizeInBytes = sourceBufferDesc.sizeInBytes;
m_device->D3D()->CreateConstantBufferView(&cbvDesc, cpuHandle);
}
else
{
throw std::invalid_argument("Unexpected binding type");
}
}
m_device->GetCommandList()->SetComputeRootSignature(m_rootSignature.Get());
m_device->GetCommandList()->SetPipelineState(m_pipelineState.Get());
ID3D12DescriptorHeap* descriptorHeaps[] = { m_descriptorHeap.Get() };
m_device->GetCommandList()->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps);
m_device->GetCommandList()->SetComputeRootDescriptorTable(0, m_descriptorHeap->GetGPUDescriptorHandleForHeapStart());
}