in Python/src/device.cpp [385:531]
// Runs the DirectML operator initializer for a single compiled operator.
//
// Steps performed, in order:
//   1. Resets m_operatorInitializer to target `op`.
//   2. Lays out every input whose tensor desc carries DML_TENSOR_FLAG_OWNED_BY_DML
//      back-to-back (with alignment padding) inside one inputs resource.
//   3. Grows the upload heap, inputs/temporary/persistent buffers and the
//      descriptor heap to the sizes required, then clears the GPU buffers.
//   4. Copies the owned-by-DML input data through the upload heap into the
//      inputs resource.
//   5. Binds inputs, persistent resource (as the initializer's output) and
//      temporary resource, records the dispatch, then calls
//      ExecuteCommandListAndWait().
//
// Parameters:
//   op     - compiled operator to initialize.
//   inputs - one entry per operator input; a null entry denotes an omitted
//            optional tensor and is skipped. Entries that are not flagged
//            OWNED_BY_DML are left with an empty binding.
//
// Errors: failures reported by Reset/Map/binding-table Reset go through
// ThrowIfFailed.
void Device::InitializeOperator(
    IDMLCompiledOperator* op,
    std::vector<pydml::Binding*>& inputs
    )
{
    // Allocate resources for initialization
    ThrowIfFailed(m_operatorInitializer->Reset(1, &op));

    DmlBufferArrayBinding inputBinding;
    inputBinding.bindings.resize(inputs.size());

    // Fill in the offsets and sizes for each binding, which will also tell us how big we need to make our buffer
    uint64_t inputsResourceSize = 0;

    for (size_t i = 0; i < inputs.size(); ++i)
    {
        const auto input = inputs[i];

        if (!input)
        {
            continue; // null optional tensor
        }

        const DmlBufferTensorDesc bufferDesc = *input->desc.AsPtr<DML_BUFFER_TENSOR_DESC>();

        // If OWNED_BY_DML is set, this input must be bound at initialize
        if (bufferDesc.flags & DML_TENSOR_FLAG_OWNED_BY_DML)
        {
            const uint32_t requiredAlignment = std::max(bufferDesc.guaranteedBaseOffsetAlignment, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);

            // Bind to the end of the inputs resource (with appropriate alignment)
            auto& slot = inputBinding.bindings[i];
            slot.offset = RoundUpToMultiple(inputsResourceSize, static_cast<uint64_t>(requiredAlignment));
            slot.sizeInBytes = bufferDesc.totalTensorSizeInBytes;

            inputsResourceSize = slot.offset + bufferDesc.totalTensorSizeInBytes;
        }
    }

    // Query the initializer's binding properties once; both the temporary size
    // and the descriptor count come from the same structure.
    const auto initializerProperties = m_operatorInitializer->GetBindingProperties();
    const uint64_t temporaryResourceSize = initializerProperties.TemporaryResourceSize;
    const uint64_t persistentResourceSize = op->GetBindingProperties().PersistentResourceSize;
    const uint32_t descriptorHeapSize = initializerProperties.RequiredDescriptorCount;

    EnsureUploadHeapSize(inputsResourceSize);
    EnsureCpuOrDefaultBufferSize(inputsResourceSize, m_inputsResource);
    EnsureDefaultBufferSize(temporaryResourceSize, m_temporaryResource);
    EnsureDefaultBufferSize(persistentResourceSize, m_persistentResource);
    EnsureDescriptorHeapSize(descriptorHeapSize);

    // Set up the bindings to point to our input resource.
    // Only slots that were given a non-zero size above receive a buffer; the
    // upload loop below relies on a null buffer to mean "not bound at initialize".
    for (auto& binding : inputBinding.bindings)
    {
        if (binding.sizeInBytes != 0)
        {
            binding.buffer = m_inputsResource->GetResource();
        }
    }

    // Upload inputs for initialization
    std::vector<ID3D12Resource*> buffersToClear =
    {
        m_inputsResource->GetResource(),
        m_temporaryResource->GetResource(),
        m_persistentResource->GetResource()
    };

    ClearGpuBuffers(buffersToClear);

    if (inputsResourceSize)
    {
        // Copy the data into the upload heap
        byte* uploadHeapData = nullptr;
        ThrowIfFailed(m_uploadHeap->Map(0, nullptr, reinterpret_cast<void**>(&uploadHeapData)));

        for (size_t i = 0; i < inputs.size(); ++i)
        {
            if (!inputBinding.bindings[i].buffer)
            {
                // This input tensor doesn't need to be bound for initialize
                continue;
            }

            const DmlBufferTensorDesc bufferDesc = *inputs[i]->desc.AsPtr<DML_BUFFER_TENSOR_DESC>();

            void* dest = uploadHeapData + inputBinding.bindings[i].offset;
            const void* src = inputs[i]->data.Get();

            assert(inputs[i]->data.Size() == bufferDesc.totalTensorSizeInBytes);

            memcpy(dest, src, static_cast<size_t>(bufferDesc.totalTensorSizeInBytes));
        }

        m_uploadHeap->Unmap(0, nullptr);

        // Record the copy from the upload heap into the inputs resource.
        //
        // The barriers are held in named locals: taking the address of the
        // temporary returned by Transition() is ill-formed in standard C++
        // (it only compiles as an MSVC language extension).
        const auto toCopyDest = CD3DX12_RESOURCE_BARRIER::Transition(
            m_inputsResource->GetResource(),
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
            D3D12_RESOURCE_STATE_COPY_DEST);
        m_commandList->ResourceBarrier(1, &toCopyDest);

        m_commandList->CopyBufferRegion(m_inputsResource->GetResource(), 0, m_uploadHeap->GetResource(), 0, inputsResourceSize);

        const auto toUnorderedAccess = CD3DX12_RESOURCE_BARRIER::Transition(
            m_inputsResource->GetResource(),
            D3D12_RESOURCE_STATE_COPY_DEST,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
        m_commandList->ResourceBarrier(1, &toUnorderedAccess);
    }

    // Bind for initialization
    DmlTypeConverter<1024> converter;

    DML_BINDING_TABLE_DESC bindingTableDesc = {};
    bindingTableDesc.Dispatchable = m_operatorInitializer.Get();
    bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap->GetCPUDescriptorHandleForHeapStart();
    bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap->GetGPUDescriptorHandleForHeapStart();
    bindingTableDesc.SizeInDescriptors = descriptorHeapSize;

    ThrowIfFailed(m_bindingTable->Reset(&bindingTableDesc));

    // All per-tensor bindings are passed as one buffer-array binding.
    DML_BINDING_DESC inputBindingDesc = converter.ToBindingDesc(inputBinding);
    m_bindingTable->BindInputs(1, &inputBindingDesc);

    // The operator's persistent resource is bound as the initializer's output.
    if (persistentResourceSize != 0)
    {
        DML_BUFFER_BINDING outputBinding = { m_persistentResource->GetResource(), 0, persistentResourceSize };
        auto desc = DML_BINDING_DESC { DML_BINDING_TYPE_BUFFER, &outputBinding };
        m_bindingTable->BindOutputs(1, &desc);
    }

    if (temporaryResourceSize != 0)
    {
        DML_BUFFER_BINDING temporaryBinding = { m_temporaryResource->GetResource(), 0, temporaryResourceSize };
        auto desc = DML_BINDING_DESC { DML_BINDING_TYPE_BUFFER, &temporaryBinding };
        m_bindingTable->BindTemporaryResource(&desc);
    }

    // Record and execute commands, and wait for completion
    m_commandList->SetDescriptorHeaps(1, m_descriptorHeap->m_Heap.GetAddressOf());
    m_commandRecorder->RecordDispatch(m_commandList.Get(), m_operatorInitializer.Get(), m_bindingTable.Get());
    ExecuteCommandListAndWait();
}