void CNTKToONNXHelper::ProcessInputs()

in Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp [5968:6316]


void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
    onnxruntime::Graph* graph,
    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
    std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
    std::vector<onnxruntime::NodeArg *>& inputs,
    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
    std::string cntkOpName = ToLegacyString(ToUTF8(src->OpName()));
    std::string onnxOpName = ToOPName(src);

    std::vector<FunctionPtr> fs;
    for (size_t inputIndex = 0; inputIndex < src->Inputs().size(); ++inputIndex)
    {
        auto input = src->Inputs()[inputIndex];

        while (input.IsPlaceholder())
        {
            input = input.BlockFunctionVariableMapping();
            if (!input.IsInitialized())
                LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
        }

        // of the pattern so that complex patterns become not skipped as a whole.
        // retry SkipBatchAndSequenceAxisInput shall be good enough.
        input = SkipBatchAndSequenceAxisInput(input);
        //// UnpackBatchAxis and ToBatchAxis is a noop in ONNX
        //bool dynamicAxisPackUnpackSkipped = false;

        //// TODO: to skip a batch/sequence pack/uppack, we need
        //// to ensure src only sees its direct inputs to maintain dynamic axis semantic of CNTK ops.
        //// However, if batch size is not FreeBatchSize, we need to keep the batch size, not the #.
        //// For example (in c++ shape order):
        //// (1987, 600) -> ToBatchAxis -> (1987, 600)        // because 1987 != FreeBatchSize
        //// ElementTimes with [#][600] -> (1987, 600)
        //// if we keep CNTK dynamic semantics:
        //// (1987, 600) -> ToBatchAxis -> [#](600, )
        //// ElementTimes with [#][600] -> (#, 600) which is (1, 600)
        //if (dynamicAxisPackUnpackSkipped)
        //    input = SkipBatchAndSequenceAxisInput(input);

        // Input might be a placeholder after skipping.
        while (input.IsPlaceholder())
        {
            input = input.BlockFunctionVariableMapping();
            if (!input.IsInitialized())
                LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
        }

        // Special case handling of LayerNormalization layer because it changes
        // ops dynamically based on value of inputs. If more such cases ops are seen,
        // this should be abstracted out from here.
        if (ToLegacyString(ToUTF8(src->OpName())) == "LayerNormalization")
        {
            // If non-zero epsilon was specified, a fourth input is included
            // which must be ignored because we cannot export epsilon to ONNX.
            // See LayerNormalization branch in AddNode() below.
            if (src->Inputs().size() == 4 && inputIndex == 0 && input.IsConstant())
                continue;
        }
        else if (ToLegacyString(ToUTF8(src->OpName())) == "Crop")
        {
            // Export only the first input. In ONNX Crop accepts only 1 input, and there is no notion of referent input.
            if (inputIndex > 0)
                continue;
        }

        if ((src->OpName() == L"Sequence::Slice" || src->OpName() == L"Sequence::IsFirst" || src->OpName() == L"Sequence::IsLast") && inputIndex != src->Inputs().size() - 1)
        {
            // for these sequence ops, only the last input is the real valid input.
            continue;
        }
        else if (FilterInput(src, input, inputIndex))
            continue;

        //
        // Get unique name based on user-defined name if available, otherwise use our internal unique name ID.
        //
        std::string inputName = [&]() {
            auto inputItr = compositeOutputsMap.find(input);
            if (inputItr != compositeOutputsMap.end())
                return UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
            else
                return UniqueNodeNameStorage::GetUniqueInputNodeName(input);
        }();

        bool isConstant = (input.IsParameter() || input.IsConstant());
        // sequence convolution has different indexing which cannot be handled by IgnoreConstantAndParameter
        if (src->OpName() != L"Convolution" || !src->Outputs()[0].HasSequenceAxis())
            isConstant &= !Operators::IgnoreConstantAndParameter(src->OpName(), inputIndex);

        bool isInSubGraph = createLoopIndex >= 0 && createLoopIndex < scanLoops.size();

        bool isScanInputInSubgraph = createLoopIndex != -1 &&
            std::find_if(scanLoops[createLoopIndex].m_scanInputs.begin(), scanLoops[createLoopIndex].m_scanInputs.end(),
                [inputName](Variable v) { return inputName == UniqueNodeNameStorage::GetUniqueInputNodeName(v); }) != scanLoops[createLoopIndex].m_scanInputs.end();

        bool isOutputOfStepFunction = input.Owner() &&
            (input.Owner()->OpName() == L"PastValue" || input.Owner()->OpName() == L"FutureValue");

        onnx::TypeProto inputArgType;

        if (isOutputOfStepFunction)
        {
            if (isInSubGraph)
            {
                // need to take input from step function's initial state (second input to the step function)
                // if initial state is a scalar, it will be created with correct shape later in this method.

                ScanLoop &scanLoop = scanLoops[createLoopIndex];
                // one intial state may map to multiple final states.
                // to make one to one mapping from initial to final states,
                // we have to split the inital state.
                inputName = MakeInitialStateNodeArgName(input);
                inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());
            }
        }
        else if (input.Owner() && ONNX::Operators::IsRNNOp(ToLegacyString(ToUTF8(input.Owner()->OpName()))) &&
            isInSubGraph)
        {
            // we are processing subgraph and hit LSTM block.
            // Because LSTM is constructed as a whole compositeOutputsMap does not have map for LSTM block.
            // Now LSTM is in the loop. The LSTM block is decomposed in scan loop.
            // So we need to use its internal names (instead of block names).
            BlockFunction* block = dynamic_cast<BlockFunction *>(input.Owner().get());

            // from block to underlying
            std::unordered_map<Variable, Variable> bm = block->CompositeOutputsMap();
            if (bm.find(input) == bm.end())
                LogicError("cannot map PastValue/Future's input to LSTM underlying output");

            inputName = UniqueNodeNameStorage::GetUniqueInputNodeName(bm[input]);
        }

        //
        // If this input is output, then it is the ouput of an up stream node. Recursively add all upstream nodes.
        // Pretty much, we are doing DFS.
        //
        if (input.IsOutput())
            // fs.push_back(input.Owner());
            CreateNode(input.Owner(), graph, functionNodes, variableNodes,
                scanLoops, createLoopIndex);

        if (cntkOpName == "Splice")
        {
            // for ops like Concat, batch axis may exist in one of the operand
            // CNTK allows the other operand(s) not having batch axis. But ONNX
            // requires operands to have the same rank
            inputArgType = ToTypeProto(input.Shape(), OpInputsHasBatchAxis(src), input.HasSequenceAxis());
        }
        else if (cntkOpName == "ImageScaler")
        {
            // TODO: verify - ONNX specifies that ImageScaler always need a batch axis
            inputArgType = ToTypeProto(input.Shape(), true);
        }
        else if (cntkOpName == "Convolution")
        {
            const size_t ConvWeightIndex = 0u;
            const size_t ConvOperandIndex = 1u;
            NDShape inputShape = input.Shape();
            if (inputIndex == ConvWeightIndex)
            {
                // CNTK kernel shape can omit the out channel axis if its value equals to 1.
                // On the other hand, ONNX spec requires out channel axis to be explicitly set.
                // w: [O x C x W x H], operand: [N] x [C x W x H].
                // Thus insert the emulated out channel axis if needed.
                const NDShape& operandShape = src->Inputs()[ConvOperandIndex].Shape();
                if (operandShape.Rank() >= inputShape.Rank())
                    inputShape = inputShape.AppendShape({ 1 });
                assert(inputShape.Rank() == (operandShape.Rank() + 1));
            }
            inputArgType = ToTypeProto(inputShape, input.HasBatchAxis(), input.HasSequenceAxis());
        }
        else
        {
            inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());

            if (isConstant && cntkOpName == "BatchNormalization" && (inputIndex > 0 && inputIndex <= 4))
            {
                // In case of BatchNormalization, if data (input[0]) is of type FP16, then all BN stats(inputs[1:4])
                // need to be converted from FP32 to FP16 prior to getting exported to ONNX
                if (src->Inputs()[0].GetDataType() == DataType::Float16)
                    input = Utils::ConvertVariableType<float, float16>(input, true);

                //// This is a workaround allowing CNTK V1 pretrained models to continue running after removal of sequence axis from input
                if ((src->Attributes()[L"spatial"].Value<bool>() ? 1 : 0) && input.Shape().Rank() > 1)
                    inputArgType = ToTypeProto(input.Shape().SubShape(0, 1), input.HasBatchAxis(), input.HasSequenceAxis());
            }
        }

        // TODO: if it is an identity op, we shall peek its input node to find the correct tensor element type.

        if (onnxOpName == "Identity")
        {
            // shall match the type of the same name NodeArg from upstream.
            string inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
            if (!TryMatchNodeArgType(inputArgType, graph, inputNodeArgName))
                UpdateONNXType(src->Inputs()[0].GetDataType(), inputArgType);
        }
        else if (OpNeedONNXTypeMap(cntkOpName))
        {
            if (!input.IsOutput())
            {
                MapAndUpdateONNXType(onnxOpName, true, inputIndex, input.GetDataType(), &inputArgType);
            }
            else
            {
                // input NodeArg has already been created as an output NodeArg of the previous function node.
                // a Cast op needs to be inserted to get the desired type in ONNX.
                TensorProto_DataType onnx_type = MapAndUpdateONNXType(onnxOpName, true, inputIndex, input.GetDataType(), nullptr);
                if (ConvertDataTypeCNTKToTensorProto(input.GetDataType()) != onnx_type)
                {
                    UpdateONNXType(input.GetDataType(), inputArgType);
                    onnxruntime::NodeArg &castInputArg = graph->GetOrCreateNodeArg(inputName, &inputArgType);
                    onnxruntime::Node* castNode = AddCastNode(castInputArg, graph, onnx_type, ToLegacyString(ToUTF8(src->Uid())));
                    inputs.push_back(const_cast<NodeArg *>(castNode->OutputDefs()[0]));

                    // we already completed preparation of this input and can proceed to the next input.
                    continue;
                }
                else if (isInSubGraph)
                {
                    //
                    UpdateONNXType(input.GetDataType(), inputArgType);
                }
            }
        }
        else
        {
            UpdateONNXType(input.GetDataType(), inputArgType);
        }

        bool addedInitializer = false;
        //
        // Leaf nodes are data entry to the graph and need their own node with only output arg.
        //
        if (isConstant)
        {
            if (variableNodes.find(input) == variableNodes.end())
            {
                if (input.IsParameter() || input.IsConstant())
                {
                    auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();

                    onnx::TensorProto dstTensor;
                    dstTensor.set_name(inputName);

                    CopyTensor(srcTensor, dstTensor, &inputArgType);
                    if (CNTKToONNXHelper::globalGraph && createLoopIndex != -1)
                    {
                        scanLoops[createLoopIndex].initializerAsInput.push_back(inputName);

                        // With Bing.Malta50.proto1_128_gru_normv3_ep3_z.model, I can only got ONNX runtime
                        // to produce matching results by putting initializers in the subgraphs
                        // (calling graph->AddInitializedTensor instead).
                        CNTKToONNXHelper::globalGraph->AddInitializedTensor(dstTensor);
                        // graph->AddInitializedTensor(dstTensor);
                        addedInitializer = true;
                    }
                    else
                        graph->AddInitializedTensor(dstTensor);
                }
            }
        }

        onnxruntime::NodeArg *adjusted = nullptr;
        if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
        {
            inputName = MakeScanInputOutputNodeArgName(inputName);
            
            // in case of broadcast, we want the input name unchanged. 
            // The inserted reshape op is treated as being inside of the scan subgraph.
            adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType, inputName);
        }
        else
        {
            adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
        }

        onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;
        if (addedInitializer)
        {
            graph->AddOuterScopeNodeArg(inputArg.Name());
        }

        inputs.push_back(&inputArg);

        if (cntkOpName == "Reshape")
        {
            // ONNX1.2 reshape node take shape as input instead of attribute.

            // We can construct the shape input for onnx by two ways: 1. cntk node output shape, or 2. cntk node attribute "newShape".
            // If there attribute "newShape" is missing, or attributes "beginAxis" and "endAxis" exists, we use cntk node output shape.
            // such that we don't need to duplicate the shape inference logic here.
            // Otherwise we use the cntk node attribute "newShape".
            bool useOutputShape = [&]() {
                if (!src->Attributes().Contains(L"newShape") || ((NDShape)src->Attributes()[L"newShape"].Value<NDShape>()).Rank() == 0)
                    return true;
                if (src->Attributes().Contains(L"beginAxis") && ((Axis)src->Attributes()[L"beginAxis"].Value<Axis>()).StaticAxisIndex() != 0)
                    return true;
                if (src->Attributes().Contains(L"endAxis") && ((Axis)src->Attributes()[L"endAxis"].Value<Axis>()).StaticAxisIndex() != src->Inputs()[0].Shape().Rank())
                    return true;
                return false;
            }();
            const NDShape shape = useOutputShape ? src->Output().Shape() : (NDShape)src->Attributes()[L"newShape"].Value<NDShape>();
            const NDShape inputShape = src->Inputs()[0].Shape();
            std::vector<int64_t> newShapeVec;
            size_t numInferredDimensions(0);
            // If output has batch axis, then create an output shape (which goes in as input to the
            // ONNX node) with an additional axis for batch axis (1).
            // ONNX dimensions are left aligned
            if (src->Output().HasSequenceAxis() && !isInSubGraph)
                newShapeVec.push_back(NDShape::FreeDimension);
            if (src->Output().HasBatchAxis())
                newShapeVec.push_back(BatchSizeProcessor::FreeBatchSize());
            for (int i = 0; i < shape.Rank(); i++)
            {
                int indexToOutputShape = shape.Rank() - i - 1;
                int indexToInputShape = inputShape.Rank() - i - 1;
                const auto& axisSize = shape.Dimensions()[indexToOutputShape];
                if (axisSize == NDShape::InferredDimension)
                {
                    numInferredDimensions++;
                    if (numInferredDimensions > 1)
                        LogicError("Reshape: Multiple InferredDimension not supported by ONNX.");
                    else
                        newShapeVec.push_back(ReshapeInferredDim);
                }
                else if (axisSize == NDShape::FreeDimension &&
                    indexToInputShape >= 0 && inputShape[indexToInputShape] != NDShape::FreeDimension)
                {
                    numInferredDimensions++;
                    if (numInferredDimensions > 1)
                        LogicError("Reshape: Multiple InferredDimension not supported by ONNX.");
                    newShapeVec.push_back(ReshapeInferredDim);
                }
                else // REVIEW SPTIWARI: Should we fill 0 for FreeDimension here?
                    newShapeVec.push_back(static_cast<int64_t>(axisSize));
            }

            // std::reverse(newShapeVec.begin(), newShapeVec.end());
            onnx::TypeProto shapeInputArgType = ToTypeProto(std::vector<int64_t>({ (int64_t)newShapeVec.size() }));
            shapeInputArgType.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_INT64);

            onnxruntime::NodeArg &shapeInputArg = graph->GetOrCreateNodeArg(ToLegacyString(ToUTF8(src->Output().Uid())) + "_shape", &shapeInputArgType);
            inputs.push_back(&shapeInputArg);
            AddShapeInitializer(shapeInputArg.Name(), newShapeVec, graph);
        }
    }
}