in Python/src/device.cpp [119:314]
// Executes a compiled DirectML operator once and returns its outputs.
//
// Steps performed, in order:
//   1. Lay out every execution-time input and every output back-to-back (with
//      alignment padding) inside a single inputs buffer and a single outputs buffer.
//   2. Grow the upload/readback heaps, the input/output/temporary buffers and the
//      descriptor heap to the sizes this dispatch needs.
//   3. Clear the GPU buffers, copy the input tensor data through the upload heap
//      into the inputs buffer.
//   4. Reset the binding table, bind inputs/outputs/persistent/temporary resources.
//   5. Record the dispatch and the output readback, execute, and wait.
//
// Parameters:
//   op      - the compiled operator to dispatch.
//   inputs  - one entry per operator input; a null entry is treated as an omitted
//             optional tensor and is left unbound.
//   outputs - one entry per operator output; a null entry is treated as an omitted
//             optional tensor and is left unbound.
//
// Returns whatever DownloadFromReadBackHeap produces for the given outputs and
// output bindings.
//
// NOTE(review): failures from Map / binding-table Reset are routed through
// ThrowIfFailed, which presumably throws on a failing HRESULT - confirm.
std::vector<pydml::TensorData*> Device::DispatchOperator(
    IDMLCompiledOperator* op,
    std::vector<pydml::Binding*>& inputs,
    std::vector<dml::Expression*>& outputs
    )
{
    // Compute the offset/size of each execution-time input inside the inputs resource.
    std::vector<DmlBufferBinding> inputBindings(inputs.size());
    uint64_t inputsResourceSize = 0;

    for (size_t i = 0; i < inputs.size(); ++i)
    {
        auto input = inputs[i];

        if (!input)
        {
            continue; // null optional tensor
        }

        DmlBufferTensorDesc desc = *input->desc.AsPtr<DML_BUFFER_TENSOR_DESC>();

        // If OWNED_BY_DML is *not* set, this input must be bound at execution.
        // The parentheses are required: `!` binds tighter than `&`, so without them
        // the test would be `(!desc.flags) & FLAG`, which is only accidentally right
        // while OWNED_BY_DML is the sole flag bit in use.
        if (!(desc.flags & DML_TENSOR_FLAG_OWNED_BY_DML))
        {
            uint32_t requiredAlignment = std::max(desc.guaranteedBaseOffsetAlignment, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);

            // Bind to the end of the inputs resource (with appropriate alignment)
            inputBindings[i].offset = RoundUpToMultiple(inputsResourceSize, static_cast<uint64_t>(requiredAlignment));
            inputBindings[i].sizeInBytes = desc.totalTensorSizeInBytes;

            inputsResourceSize = inputBindings[i].offset + desc.totalTensorSizeInBytes;
        }
    }

    // Compute the offset/size of each output inside the outputs resource.
    std::vector<DmlBufferBinding> outputBindings(outputs.size());
    uint64_t outputsResourceSize = 0;

    for (size_t i = 0; i < outputs.size(); ++i)
    {
        auto output = outputs[i];

        if (!output)
        {
            continue; // null optional tensor
        }

        dml::TensorDesc desc = output->GetOutputDesc();
        DmlBufferTensorDesc bufferDesc = *desc.AsPtr<DML_BUFFER_TENSOR_DESC>();

        uint32_t requiredAlignment = std::max(bufferDesc.guaranteedBaseOffsetAlignment, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);

        // Bind to the end of the outputs resource (with appropriate alignment)
        outputBindings[i].offset = RoundUpToMultiple(outputsResourceSize, static_cast<uint64_t>(requiredAlignment));
        outputBindings[i].sizeInBytes = bufferDesc.totalTensorSizeInBytes;

        outputsResourceSize = outputBindings[i].offset + outputBindings[i].sizeInBytes;
    }

    DML_BINDING_PROPERTIES bindingProps = op->GetBindingProperties();

    // Make sure every heap/buffer used below is large enough for this dispatch.
    EnsureUploadHeapSize(inputsResourceSize);
    EnsureCpuOrDefaultBufferSize(inputsResourceSize, m_inputsResource);
    EnsureReadBackHeapSize(outputsResourceSize);
    EnsureCpuOrDefaultBufferSize(outputsResourceSize, m_outputsResource);
    EnsureDefaultBufferSize(bindingProps.TemporaryResourceSize, m_temporaryResource);
    EnsureDescriptorHeapSize(bindingProps.RequiredDescriptorCount);

    // Set up input and output bindings to point to their respective buffers.
    // This happens after the Ensure* calls above, which receive the resources by
    // name and so may replace them. Bindings whose size is zero (null tensors,
    // or inputs owned by DML) keep a null buffer and are therefore left unbound.
    for (auto& binding : inputBindings)
    {
        if (binding.sizeInBytes != 0)
        {
            binding.buffer = m_inputsResource->GetResource();
        }
    }

    for (auto& binding : outputBindings)
    {
        if (binding.sizeInBytes != 0)
        {
            binding.buffer = m_outputsResource->GetResource();
        }
    }

    // The persistent resource should have already been initialized when the operator was initialized
    assert(m_persistentResource->GetResource()->GetDesc().Width >= bindingProps.PersistentResourceSize);

    // Clear the input, temporary and output buffers before they are (re)used
    std::vector<ID3D12Resource*> buffersToClear =
    {
        m_inputsResource->GetResource(),
        m_temporaryResource->GetResource(),
        m_outputsResource->GetResource()
    };

    ClearGpuBuffers(buffersToClear);

    // Upload inputs for execution
    if (inputsResourceSize)
    {
        // Copy the data into the upload heap
        byte* uploadHeapData = nullptr;
        ThrowIfFailed(m_uploadHeap->Map(0, nullptr, reinterpret_cast<void**>(&uploadHeapData)));

        for (size_t i = 0; i < inputs.size(); ++i)
        {
            if (!inputBindings[i].buffer)
            {
                // This input tensor isn't bound at execution (null, or owned by DML)
                continue;
            }

            DmlBufferTensorDesc bufferDesc = *inputs[i]->desc.AsPtr<DML_BUFFER_TENSOR_DESC>();

            void* dest = uploadHeapData + inputBindings[i].offset;
            const void* src = inputs[i]->data.Get();

            assert(inputs[i]->data.Size() == bufferDesc.totalTensorSizeInBytes);

            memcpy(dest, src, static_cast<size_t>(bufferDesc.totalTensorSizeInBytes));
        }

        m_uploadHeap->Unmap(0, nullptr);

        // Record the copy from the upload heap into the inputs resource.
        // The barriers are held in named locals: taking the address of the temporary
        // returned by Transition() is not valid standard C++.
        const auto toCopyDest = CD3DX12_RESOURCE_BARRIER::Transition(
            m_inputsResource->GetResource(),
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
            D3D12_RESOURCE_STATE_COPY_DEST);
        m_commandList->ResourceBarrier(1, &toCopyDest);

        m_commandList->CopyBufferRegion(m_inputsResource->GetResource(), 0, m_uploadHeap->GetResource(), 0, inputsResourceSize);

        const auto toUnorderedAccess = CD3DX12_RESOURCE_BARRIER::Transition(
            m_inputsResource->GetResource(),
            D3D12_RESOURCE_STATE_COPY_DEST,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
        m_commandList->ResourceBarrier(1, &toUnorderedAccess);
    }

    // Bind for execution
    DmlTypeConverter<1024> converter;

    DML_BINDING_TABLE_DESC bindingTableDesc = {};
    bindingTableDesc.Dispatchable = op;
    bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap->GetCPUDescriptorHandleForHeapStart();
    bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap->GetGPUDescriptorHandleForHeapStart();
    bindingTableDesc.SizeInDescriptors = bindingProps.RequiredDescriptorCount;

    ThrowIfFailed(m_bindingTable->Reset(&bindingTableDesc));

    // Bind inputs
    std::vector<DML_BINDING_DESC> inputBindingDescs(inputBindings.size());
    for (size_t i = 0; i < inputBindings.size(); ++i)
    {
        inputBindingDescs[i] = converter.ToBindingDesc(inputBindings[i]);
    }

    m_bindingTable->BindInputs(static_cast<uint32_t>(inputBindingDescs.size()), inputBindingDescs.data());

    // Bind outputs
    std::vector<DML_BINDING_DESC> outputBindingDescs(outputBindings.size());
    for (size_t i = 0; i < outputBindings.size(); ++i)
    {
        outputBindingDescs[i] = converter.ToBindingDesc(outputBindings[i]);
    }

    m_bindingTable->BindOutputs(static_cast<uint32_t>(outputBindingDescs.size()), outputBindingDescs.data());

    // Bind persistent/temporary resources (only when the operator asks for them)
    if (bindingProps.PersistentResourceSize != 0)
    {
        DML_BUFFER_BINDING persistentBinding = { m_persistentResource->GetResource(), 0, bindingProps.PersistentResourceSize };
        auto bindingDesc = DML_BINDING_DESC { DML_BINDING_TYPE_BUFFER, &persistentBinding };
        m_bindingTable->BindPersistentResource(&bindingDesc);
    }

    if (bindingProps.TemporaryResourceSize != 0)
    {
        DML_BUFFER_BINDING temporaryBinding = { m_temporaryResource->GetResource(), 0, bindingProps.TemporaryResourceSize };
        auto bindingDesc = DML_BINDING_DESC { DML_BINDING_TYPE_BUFFER, &temporaryBinding };
        m_bindingTable->BindTemporaryResource(&bindingDesc);
    }

    // Record and execute commands, and wait for completion
    m_commandList->SetDescriptorHeaps(1, m_descriptorHeap->m_Heap.GetAddressOf());
    m_commandRecorder->RecordDispatch(m_commandList.Get(), op, m_bindingTable.Get());
    RecordOutputReadBack(outputsResourceSize);
    ExecuteCommandListAndWait();

    // Read the output data back from the readback heap
    return DownloadFromReadBackHeap(outputsResourceSize, outputs, outputBindings);
}