void Device::InitializeOperator()

in Python/src/device.cpp [385:531]


// Runs the DirectML operator initializer for a single compiled operator and
// waits for it to finish on the GPU.
//
// Inputs whose buffer tensor desc carries DML_TENSOR_FLAG_OWNED_BY_DML are
// packed back-to-back (each at its required alignment) into m_inputsResource,
// copied there through m_uploadHeap, and bound as the initializer's
// buffer-array input. All other entries, including null optional inputs, are
// left unbound. The operator's persistent resource is bound as the
// initializer's output, and the temporary resource is bound when the
// initializer reports a non-zero temporary size.
//
// @param op      Compiled operator to initialize (reset into m_operatorInitializer).
// @param inputs  One entry per operator input; nullptr marks an absent optional tensor.
// @throws        Whatever ThrowIfFailed throws when a D3D12/DML call fails, or when an
//                OWNED_BY_DML input supplies fewer bytes than its tensor desc requires.
void Device::InitializeOperator(
    IDMLCompiledOperator* op,
    std::vector<pydml::Binding*>& inputs
    )
{
    // Allocate resources for initialization
    ThrowIfFailed(m_operatorInitializer->Reset(1, &op));

    DmlBufferArrayBinding inputBinding;
    inputBinding.bindings.resize(inputs.size());

    // Fill in the offsets and sizes for each binding, which will also tell us how big we need to make our buffer
    uint64_t inputsResourceSize = 0;

    for (size_t i = 0; i < inputs.size(); ++i)
    {
        auto input = inputs[i];

        if (!input)
        {
            continue; // null optional tensor
        }

        DmlBufferTensorDesc bufferDesc = *input->desc.AsPtr<DML_BUFFER_TENSOR_DESC>();

        // If OWNED_BY_DML is set, this input must be bound at initialize
        if (bufferDesc.flags & DML_TENSOR_FLAG_OWNED_BY_DML)
        {
            // The upload below copies totalTensorSizeInBytes from the input's data.
            // Validate here, before anything is allocated or mapped, so a short buffer
            // is rejected in release builds too instead of being over-read.
            assert(input->data.Size() == bufferDesc.totalTensorSizeInBytes);
            if (input->data.Size() < bufferDesc.totalTensorSizeInBytes)
            {
                ThrowIfFailed(E_INVALIDARG);
            }

            const uint32_t requiredAlignment = std::max<uint32_t>(bufferDesc.guaranteedBaseOffsetAlignment, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);

            // Bind to the end of the inputs resource (with appropriate alignment)
            auto& binding = inputBinding.bindings[i];
            binding.offset = RoundUpToMultiple(inputsResourceSize, static_cast<uint64_t>(requiredAlignment));
            binding.sizeInBytes = bufferDesc.totalTensorSizeInBytes;

            inputsResourceSize = binding.offset + binding.sizeInBytes;
        }
    }

    const DML_BINDING_PROPERTIES initializerProperties = m_operatorInitializer->GetBindingProperties();
    const uint64_t temporaryResourceSize = initializerProperties.TemporaryResourceSize;
    const uint64_t persistentResourceSize = op->GetBindingProperties().PersistentResourceSize;
    const uint32_t descriptorHeapSize = initializerProperties.RequiredDescriptorCount;

    EnsureUploadHeapSize(inputsResourceSize);
    EnsureCpuOrDefaultBufferSize(inputsResourceSize, m_inputsResource);
    EnsureDefaultBufferSize(temporaryResourceSize, m_temporaryResource);
    EnsureDefaultBufferSize(persistentResourceSize, m_persistentResource);
    EnsureDescriptorHeapSize(descriptorHeapSize);

    // Set up the bindings to point to our input resource; entries left at size 0
    // (not OWNED_BY_DML, or null) keep no buffer and stay unbound.
    for (auto& binding : inputBinding.bindings)
    {
        if (binding.sizeInBytes != 0)
        {
            binding.buffer = m_inputsResource->GetResource();
        }
    }

    // Clear the inputs, temporary and persistent buffers before use
    std::vector<ID3D12Resource*> buffersToClear =
    {
        m_inputsResource->GetResource(),
        m_temporaryResource->GetResource(),
        m_persistentResource->GetResource()
    };

    ClearGpuBuffers(buffersToClear);

    if (inputsResourceSize)
    {
        // Copy the data into the upload heap
        byte* uploadHeapData = nullptr;

        ThrowIfFailed(m_uploadHeap->Map(0, nullptr, reinterpret_cast<void**>(&uploadHeapData)));

        for (size_t i = 0; i < inputs.size(); ++i)
        {
            const auto& binding = inputBinding.bindings[i];

            if (!binding.buffer)
            {
                // This input tensor doesn't need to be bound for initialize
                continue;
            }

            // binding.sizeInBytes equals the tensor's totalTensorSizeInBytes and was
            // checked against the input data size in the first pass.
            void* dest = uploadHeapData + binding.offset;
            const void* src = inputs[i]->data.Get();

            memcpy(dest, src, static_cast<size_t>(binding.sizeInBytes));
        }

        m_uploadHeap->Unmap(0, nullptr);

        // Record the copy from the upload heap into the inputs resource.
        // Barriers are named objects: taking the address of the temporary returned by
        // Transition() is a non-standard MSVC extension rejected under /permissive-.
        ID3D12Resource* inputsResource = m_inputsResource->GetResource();

        const CD3DX12_RESOURCE_BARRIER toCopyDest = CD3DX12_RESOURCE_BARRIER::Transition(
            inputsResource,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
            D3D12_RESOURCE_STATE_COPY_DEST);
        m_commandList->ResourceBarrier(1, &toCopyDest);

        m_commandList->CopyBufferRegion(inputsResource, 0, m_uploadHeap->GetResource(), 0, inputsResourceSize);

        const CD3DX12_RESOURCE_BARRIER toUnorderedAccess = CD3DX12_RESOURCE_BARRIER::Transition(
            inputsResource,
            D3D12_RESOURCE_STATE_COPY_DEST,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
        m_commandList->ResourceBarrier(1, &toUnorderedAccess);
    }

    // Bind for initialization
    DmlTypeConverter<1024> converter;

    DML_BINDING_TABLE_DESC bindingTableDesc = {};
    bindingTableDesc.Dispatchable = m_operatorInitializer.Get();
    bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap->GetCPUDescriptorHandleForHeapStart();
    bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap->GetGPUDescriptorHandleForHeapStart();
    bindingTableDesc.SizeInDescriptors = descriptorHeapSize;

    ThrowIfFailed(m_bindingTable->Reset(&bindingTableDesc));

    DML_BINDING_DESC inputBindingDesc = converter.ToBindingDesc(inputBinding);
    m_bindingTable->BindInputs(1, &inputBindingDesc);

    // The initializer's output is the operator's persistent resource
    if (persistentResourceSize != 0)
    {
        DML_BUFFER_BINDING outputBinding = { m_persistentResource->GetResource(), 0, persistentResourceSize };
        auto desc = DML_BINDING_DESC { DML_BINDING_TYPE_BUFFER, &outputBinding };
        m_bindingTable->BindOutputs(1, &desc);
    }

    if (temporaryResourceSize != 0)
    {
        DML_BUFFER_BINDING temporaryBinding = { m_temporaryResource->GetResource(), 0, temporaryResourceSize };
        auto desc = DML_BINDING_DESC { DML_BINDING_TYPE_BUFFER, &temporaryBinding };
        m_bindingTable->BindTemporaryResource(&desc);
    }

    // Record and execute commands, and wait for completion
    m_commandList->SetDescriptorHeaps(1, m_descriptorHeap->m_Heap.GetAddressOf());
    m_commandRecorder->RecordDispatch(m_commandList.Get(), m_operatorInitializer.Get(), m_bindingTable.Get());
    ExecuteCommandListAndWait();
}