inline GraphDesc GraphBuilder::GetGraphDesc()

in Libraries/DirectMLX.h [3911:4010]


        // Builds a GraphDesc describing the operator graph recorded by this builder, with the graph's outputs
        // bound to the expressions in `outputs`.
        //
        // - Every recorded operator node becomes one entry in desc.nodes, in recording order, so a node's index
        //   in desc.nodes equals its index in m_operatorNodes.
        // - Each non-null input of an operator node becomes an input edge (when it originates from a graph
        //   input) or an intermediate edge (when it originates from an operator node's output).
        // - Each non-null entry of `outputs` becomes an output edge whose GraphOutputIndex is the entry's
        //   position in `outputs`. Null entries are skipped, so desc.outputEdges may hold fewer entries than
        //   desc.outputCount.
        // - Reinterpret nodes are never emitted; edges are traced back through them to the underlying node.
        //
        // Throws (via DMLX_THROW):
        //   E_INVALIDARG - an output resolves directly to a graph input (use e.g. an identity op instead).
        //   E_UNEXPECTED - an edge resolves to a node type other than Input or Operator.
        inline GraphDesc GraphBuilder::GetGraphDesc(Span<const Expression> outputs) const
        {
            GraphDesc desc = {};
            desc.inputCount = static_cast<uint32_t>(m_inputNodes.size());
            desc.outputCount = static_cast<uint32_t>(outputs.size());

            // Reinterpret nodes aren't "real" nodes; they only modify TensorDescs across edges. Follow a chain
            // of them backwards until a non-Reinterpret node is reached. `output` is updated in place to the
            // NodeOutput that actually produces the value, and the resolved node is returned.
            auto resolveReinterprets = [this](NodeOutput*& output) -> NodeID
            {
                NodeID node = output->GetNode();
                while (node.type == NodeType::Reinterpret)
                {
                    output = m_reinterpretNodes[node.index].input;
                    node = output->GetNode();
                }
                return node;
            };

            for (const OperatorNode& node : m_operatorNodes)
            {
                uint32_t nodeIndex = static_cast<uint32_t>(desc.nodes.size());
                desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get() });

                // Walk through each of this node's inputs and add it as an edge
                const uint32_t inputCount = static_cast<uint32_t>(node.inputs.size());
                for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
                {
                    NodeOutput* input = node.inputs[inputIndex];
                    if (input == nullptr)
                    {
                        // Unconnected (optional) input: no edge to emit.
                        continue;
                    }

                    const NodeID inputNode = resolveReinterprets(input);

                    if (inputNode.type == NodeType::Input)
                    {
                        DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
                        inputEdge.GraphInputIndex = m_inputNodes[inputNode.index].inputIndex;
                        inputEdge.ToNodeIndex = nodeIndex;
                        inputEdge.ToNodeInputIndex = inputIndex;

                        desc.inputEdges.push_back(inputEdge);
                    }
                    else if (inputNode.type == NodeType::Operator)
                    {
                        DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
                        intermediateEdge.FromNodeIndex = inputNode.index;
                        intermediateEdge.FromNodeOutputIndex = input->GetOutputIndex();
                        intermediateEdge.ToNodeIndex = nodeIndex;
                        intermediateEdge.ToNodeInputIndex = inputIndex;

                        desc.intermediateEdges.push_back(intermediateEdge);
                    }
                    else
                    {
                        assert(false); // Invalid node type
                        DMLX_THROW(E_UNEXPECTED);
                    }
                }
            }

            // Add output edges
            for (uint32_t outputIndex = 0; outputIndex < desc.outputCount; ++outputIndex)
            {
                NodeOutput* output = outputs[outputIndex].Impl();
                if (output == nullptr)
                {
                    // Unbound (optional) graph output: no edge to emit.
                    continue;
                }

                // Reinterpret nodes are meaningless on outputs (they're no-ops), so resolve to the real node.
                const NodeID outputNode = resolveReinterprets(output);

                if (outputNode.type == NodeType::Input)
                {
                    // It's not valid to connect an output of the graph directly to an input without an intervening
                    // node. If this behavior is desired, it should instead be accomplished with a copy e.g. using
                    // the elementwise identity operator.
                    DMLX_THROW(E_INVALIDARG);
                }

                if (outputNode.type != NodeType::Operator)
                {
                    // Mirror the input-edge handling: never emit an edge from an unrecognized node type, even in
                    // release builds where the assert is compiled out.
                    assert(false); // Invalid node type
                    DMLX_THROW(E_UNEXPECTED);
                }

                DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
                outputEdge.FromNodeIndex = outputNode.index;
                outputEdge.FromNodeOutputIndex = output->GetOutputIndex();
                outputEdge.GraphOutputIndex = outputIndex;

                desc.outputEdges.push_back(outputEdge);
            }

            // Sanity. Null entries in `outputs` are skipped above, so there may be fewer output edges than
            // declared graph outputs.
            assert(desc.nodes.size() == m_operatorNodes.size());
            assert(desc.outputEdges.size() <= desc.outputCount);
            assert(desc.outputCount == outputs.size());

            return desc;
        }