in Libraries/DirectMLX.h [3911:4010]
// Flattens the operator nodes recorded on this builder, plus the requested output expressions,
// into a GraphDesc (node list and input / intermediate / output edge lists).
//
// - desc.nodes is filled in the same order as m_operatorNodes, so an operator's position in that
//   list is also its node index in the returned desc. Intermediate and output edges use an
//   operator NodeID's index directly as FromNodeIndex, i.e. this relies on that index referring
//   to m_operatorNodes.
// - Reinterpret nodes only modify TensorDescs across edges and never become nodes in the desc;
//   every edge is traced back through them to the real Input or Operator node producing the value.
// - Null operator inputs and null entries in `outputs` are skipped: those slots get no edge.
//
// outputs: expressions whose values become graph outputs; outputs[i] is bound to graph output i.
// Returns the populated GraphDesc. inputCount is the number of input nodes on this builder and
// outputCount is outputs.size().
// Errors (via DMLX_THROW): E_INVALIDARG if a graph output resolves directly to a graph input;
// E_UNEXPECTED if an edge resolves to a node type other than Input/Operator.
inline GraphDesc GraphBuilder::GetGraphDesc(Span<const Expression> outputs) const
{
    // Walks back through any chain of reinterpret nodes and returns the first non-reinterpret
    // producer. Used for both operator inputs and graph outputs.
    auto skipReinterprets = [this](NodeOutput* value) -> NodeOutput*
    {
        NodeID node = value->GetNode();
        while (node.type == NodeType::Reinterpret)
        {
            value = m_reinterpretNodes[node.index].input;
            node = value->GetNode();
        }
        return value;
    };

    GraphDesc desc = {};
    desc.inputCount = static_cast<uint32_t>(m_inputNodes.size());
    desc.outputCount = static_cast<uint32_t>(outputs.size());

    for (const OperatorNode& node : m_operatorNodes)
    {
        const uint32_t nodeIndex = static_cast<uint32_t>(desc.nodes.size());
        desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get() });

        // Add one edge for every connected input of this node.
        const uint32_t inputCount = static_cast<uint32_t>(node.inputs.size());
        for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
        {
            NodeOutput* input = node.inputs[inputIndex];
            if (input == nullptr)
            {
                // Unconnected input slot (presumably an optional operator input): no edge.
                continue;
            }

            input = skipReinterprets(input);
            const NodeID inputNode = input->GetNode();

            if (inputNode.type == NodeType::Input)
            {
                DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
                inputEdge.GraphInputIndex = m_inputNodes[inputNode.index].inputIndex;
                inputEdge.ToNodeIndex = nodeIndex;
                inputEdge.ToNodeInputIndex = inputIndex;
                desc.inputEdges.push_back(inputEdge);
            }
            else if (inputNode.type == NodeType::Operator)
            {
                DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
                intermediateEdge.FromNodeIndex = inputNode.index;
                intermediateEdge.FromNodeOutputIndex = input->GetOutputIndex();
                intermediateEdge.ToNodeIndex = nodeIndex;
                intermediateEdge.ToNodeInputIndex = inputIndex;
                desc.intermediateEdges.push_back(intermediateEdge);
            }
            else
            {
                assert(false); // Invalid node type
                DMLX_THROW(E_UNEXPECTED);
            }
        }
    }

    // Add output edges
    for (uint32_t outputIndex = 0; outputIndex < desc.outputCount; ++outputIndex)
    {
        NodeOutput* output = outputs[outputIndex].Impl();
        if (output == nullptr)
        {
            // Null expression: no output edge is emitted for this graph output index.
            continue;
        }

        // Reinterpret nodes are no-ops on outputs, so bind the graph output to the real producer.
        output = skipReinterprets(output);
        const NodeID outputNode = output->GetNode();

        if (outputNode.type == NodeType::Input)
        {
            // It's not valid to connect an output of the graph directly to an input without an intervening
            // node. If this behavior is desired, it should instead be accomplished with a copy e.g. using
            // the elementwise identity operator.
            DMLX_THROW(E_INVALIDARG);
        }
        else if (outputNode.type != NodeType::Operator)
        {
            // Same handling as the input-edge loop above; previously this was only a debug assert,
            // which let release builds emit an edge from an invalid node.
            assert(false); // Invalid node type
            DMLX_THROW(E_UNEXPECTED);
        }

        DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
        outputEdge.FromNodeIndex = outputNode.index;
        outputEdge.FromNodeOutputIndex = output->GetOutputIndex();
        outputEdge.GraphOutputIndex = outputIndex;
        desc.outputEdges.push_back(outputEdge);
    }

    // Sanity. Null entries in `outputs` are skipped above and produce no edge, so the edge count
    // may be smaller than outputCount; an equality check here would fire on any null output.
    assert(desc.nodes.size() == m_operatorNodes.size());
    assert(desc.outputEdges.size() <= desc.outputCount);
    assert(desc.outputCount == outputs.size());

    return desc;
}