ComputationNodeBasePtr CompositeFunction::CreateComputationNode()

in Source/CNTKv2LibraryDll/CompositeFunction.cpp [668:1436]


    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
                                                                               Function* function,
                                                                               const std::vector<std::shared_ptr<ComputationNodeBase>>& inputNodes,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                               bool useMangledNamesForComputationNodes)
    {
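        // A NoOp does not get a ComputationNode of its own; reuse the node already created for the variable it aliases.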
        PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
        if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
            return variableToNodeMap[GetMappingVariable(variable)];

        ComputationNodeBasePtr computationNodePtr;

        auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name(), useMangledNamesForComputationNodes);

        std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
        for (const auto& inputNode : inputNodes)
            inputNodesBasePtrs.push_back(inputNode);

        // Take the data type from the first input; if it is not specified (e.g. for a placeholder), use the default.
        // Nodes like BatchNormalization may have inputs with different precisions;
        // that validation is done in the specific node's constructor.
        DataType inputNodeType = AsDataType<ElementType>();
        if (inputNodes.size() > 0)
        {
            if (std::dynamic_pointer_cast<ComputationNode<float>, ComputationNodeBase>(inputNodes[0]))
                inputNodeType = DataType::Float;
            else if (std::dynamic_pointer_cast<ComputationNode<double>, ComputationNodeBase>(inputNodes[0]))
                inputNodeType = DataType::Double;
            else if (std::dynamic_pointer_cast<ComputationNode<half>, ComputationNodeBase>(inputNodes[0]))
                inputNodeType = DataType::Float16;
        }

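// ASSIGN_NEW_NODE dispatches on the element type inferred above, instantiating the node template
// that matches the runtime input data type. For example, with float inputs,
// ASSIGN_NEW_NODE(PlusNode, network->GetDeviceId(), internalNodeName) expands to
// computationNodePtr = New<PlusNode<float>>(network->GetDeviceId(), internalNodeName);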
#define ASSIGN_NEW_NODE(nodeClass, ...)                               \
    do {                                                              \
        if (inputNodeType == DataType::Float)                         \
            computationNodePtr = New<nodeClass<float>>(__VA_ARGS__);  \
        else if (inputNodeType == DataType::Double)                   \
            computationNodePtr = New<nodeClass<double>>(__VA_ARGS__); \
        else if (inputNodeType == DataType::Float16)                  \
            computationNodePtr = New<nodeClass<half>>(__VA_ARGS__);   \
    } while(0)

#define ASSIGN_NEW_NODE2(nodeClass, dtype, ...)                              \
    do {                                                                     \
        if (inputNodeType == DataType::Float)                                \
            computationNodePtr = New<nodeClass<dtype, float>>(__VA_ARGS__);  \
        else if (inputNodeType == DataType::Double)                          \
            computationNodePtr = New<nodeClass<dtype, double>>(__VA_ARGS__); \
        else if (inputNodeType == DataType::Float16)                         \
            computationNodePtr = New<nodeClass<dtype, half>>(__VA_ARGS__);   \
    } while(0)
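// ASSIGN_NEW_NODE2 is the two-template-argument variant: 'dtype' is fixed by the caller, while the
// second argument still tracks the input element type (the Cast case below uses this to build
// CastNode<targetType, inputType>).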

        auto outputs = function->RawOutputs();
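        // The function's primary output maps to the op's own ComputationNode; any secondary output is
        // realized in the else branch below via an OutputMultiplexerNode that selects it by index.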
        if (variable == outputs[0])
        {
            if (primitiveFunction)
            {
                auto functionInputs = function->Inputs();
                auto& functionConfig = function->Attributes();
                PrimitiveOpType op = primitiveFunction->OpType();

                switch (op)
                {
                case PrimitiveOpType::Negate:
                    ASSIGN_NEW_NODE(NegateNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sigmoid:
                    ASSIGN_NEW_NODE(SigmoidNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Atanh:
                    ASSIGN_NEW_NODE(AtanhNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Tanh:
                    ASSIGN_NEW_NODE(TanhNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Acos:
                    ASSIGN_NEW_NODE(AcosNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Cos:
                    ASSIGN_NEW_NODE(CosineNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Asin:
                    ASSIGN_NEW_NODE(AsinNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sin:
                    ASSIGN_NEW_NODE(SinNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Atan:
                    ASSIGN_NEW_NODE(AtanNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Tan:
                    ASSIGN_NEW_NODE(TanNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Cosh:
                    ASSIGN_NEW_NODE(CoshNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Asinh:
                    ASSIGN_NEW_NODE(AsinhNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sinh:
                    ASSIGN_NEW_NODE(SinhNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ReLU:
                    ASSIGN_NEW_NODE(RectifiedLinearNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Exp:
                    ASSIGN_NEW_NODE(ExpNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Log:
                    ASSIGN_NEW_NODE(LogNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sqrt:
                    ASSIGN_NEW_NODE(SqrtNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ELU:
                    ASSIGN_NEW_NODE(ExponentialLinearUnitNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Floor:
                    ASSIGN_NEW_NODE(FloorNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Abs:
                    ASSIGN_NEW_NODE(AbsNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Reciprocal:
                    ASSIGN_NEW_NODE(ReciprocalNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Softmax:
                    ASSIGN_NEW_NODE(SoftmaxNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Hardmax:
                    ASSIGN_NEW_NODE(HardmaxNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::StraightThrough:
                    ASSIGN_NEW_NODE(StraightThroughNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::TopK:
                {
                    auto k = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumItems].Value<size_t>();
                    ASSIGN_NEW_NODE(TopKNode, network->GetDeviceId(), internalNodeName, k);
                    break;
                }
                case PrimitiveOpType::StableSigmoid:
                    ASSIGN_NEW_NODE(StableSigmoidNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::TransposeAxes:
                {
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxisVec))
                    {
                        auto perm = AsVector<Axis>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
                        for (auto& p : perm)
                            p = NormalizeStaticAxis(p, perm.size());
                        ASSIGN_NEW_NODE(TransposeDimensionsNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(perm));
                    }
                    else
                    {
                        auto axis1 = functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis1].Value<Axis>();
                        auto axis2 = functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis2].Value<Axis>();

                        // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
                        ASSIGN_NEW_NODE(TransposeDimensionsNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
                    }
                    break;
                }
                case PrimitiveOpType::Where:
                {
                    auto dynamicAxes = variable.DynamicAxes();
                    auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
                    ASSIGN_NEW_NODE(WhereNode, network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
                    break;
                }
                case PrimitiveOpType::ToSequence:
                {
                    auto dynamicAxes = variable.DynamicAxes();
                    auto internalCNTKDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
                    ASSIGN_NEW_NODE(ToSequenceNode, network->GetDeviceId(), internalNodeName, internalCNTKDynamicAxisName);
                    break;
                }
                case PrimitiveOpType::ToSequenceLike:
                    ASSIGN_NEW_NODE(ToSequenceLikeNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::UnpackSequence:
                {
                    auto paddingValue = functionConfig[PrimitiveFunctionAttribute::AttributeNameSequenceUnpackPaddingValue].Value<double>();
                    auto suppressMaskOutput = functionConfig[PrimitiveFunctionAttribute::AttributeNameSequenceUnpackSuppressMaskOutput].Value<bool>();
                    ASSIGN_NEW_NODE(UnpackSequenceNode, network->GetDeviceId(), internalNodeName, paddingValue, suppressMaskOutput);
                    break;
                }
                case PrimitiveOpType::Slice:
                {
                    std::vector<Axis> axis;
                    std::vector<int> beginIndex, endIndex, strides;
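                    // Slice attributes come in two forms: vectors for a multi-axis slice, or scalars for a single-axis slice.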
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxisVec) &&
                        functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameBeginIndexVec) &&
                        functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameEndIndexVec))
                    {
                        axis = AsVector<Axis>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
                        beginIndex = AsVector<int>(functionConfig[PrimitiveFunctionAttribute::AttributeNameBeginIndexVec].Value<std::vector<DictionaryValue>>());
                        endIndex = AsVector<int>(functionConfig[PrimitiveFunctionAttribute::AttributeNameEndIndexVec].Value<std::vector<DictionaryValue>>());
                        if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameSliceStridesVec))
                            strides = AsVector<int>(functionConfig[PrimitiveFunctionAttribute::AttributeNameSliceStridesVec].Value<std::vector<DictionaryValue>>());
                        else
                            strides.resize(axis.size(), 1);
                    }
                    else if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxis) &&
                        functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameBeginIndex) &&
                        functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameEndIndex))
                    {
                        axis.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis].Value<Axis>());
                        beginIndex.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameBeginIndex].Value<int>());
                        endIndex.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameEndIndex].Value<int>());
                        if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameSliceStrides))
                            strides.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameSliceStrides].Value<int>());
                        else
                            strides.push_back(1);
                    }
                    else
                    {
                        RuntimeError("Failed to create computation node: Slice operation with inconsistent attributes");
                    }
                    // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based
                    ASSIGN_NEW_NODE(SliceNode, network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis), strides);
                    break;
                }
                case PrimitiveOpType::RandomSample:
                {
                    auto numSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumSamples].Value<size_t>();
                    auto allowDuplicates = functionConfig[PrimitiveFunctionAttribute::AttributeNameAllowDuplicates].Value<bool>();
                    ASSIGN_NEW_NODE(RandomSampleNode, network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                    break;
                }
                case PrimitiveOpType::RandomSampleInclusionFrequency:
                {
                    auto numSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumSamples].Value<size_t>();
                    auto allowDuplicates = functionConfig[PrimitiveFunctionAttribute::AttributeNameAllowDuplicates].Value<bool>();
                    ASSIGN_NEW_NODE(RandomSampleInclusionFrequencyNode, network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                    break;
                }
                case PrimitiveOpType::Dropout:
                {
                    auto dropoutRate = functionConfig[PrimitiveFunctionAttribute::AttributeNameDropoutRate].Value<double>();
                    ASSIGN_NEW_NODE(DropoutNode, network->GetDeviceId(), internalNodeName);
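                    // SMART_NODE_INVOKE (a helper macro defined elsewhere in this file) casts the node to its
                    // concrete element type before invoking the named member function.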
                    SMART_NODE_INVOKE(DropoutNode, computationNodePtr, SetDropoutRate, dropoutRate);
                    break;
                }
                case PrimitiveOpType::RandomDistribution:
                {
                    auto seed = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngSeed].Value<size_t>();
                    auto offset = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngOffset].Value<size_t>();
                    auto rvtype = functionConfig[PrimitiveFunctionAttribute::AttributeNameRandomDistributionType].Value<std::wstring>();

                    std::vector<double> randomDistributionArgs;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameRandomDistributionArgs))
                        randomDistributionArgs = AsVector<double>(functionConfig[PrimitiveFunctionAttribute::AttributeNameRandomDistributionArgs].Value<std::vector<DictionaryValue>>());

                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameNewShape))
                    {
                        auto shape = functionConfig[PrimitiveFunctionAttribute::AttributeNameNewShape].Value<NDShape>();
                        ASSIGN_NEW_NODE(RandomDistributionNode, network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs, AsTensorShape(shape));
                    }
                    else
                        ASSIGN_NEW_NODE(RandomDistributionNode, network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs);
                    SMART_NODE_INVOKE(RandomDistributionNode, computationNodePtr, SetRngState, seed, offset);
                    break;
                }
                case PrimitiveOpType::Reshape:
                {
                    auto beginAxis = Axis(0);
                    auto endAxis = Axis((int)functionInputs[0].Shape().Rank());
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameBeginAxis))
                        beginAxis = functionConfig[PrimitiveFunctionAttribute::AttributeNameBeginAxis].Value<Axis>();

                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameEndAxis))
                        endAxis = functionConfig[PrimitiveFunctionAttribute::AttributeNameEndAxis].Value<Axis>();

                    auto replacementShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameNewShape].Value<NDShape>();
                    for (size_t i = 0; i < replacementShape.Rank(); ++i)
                    {
                        if (replacementShape[i] == NDShape::InferredDimension)
                            replacementShape[i] = 0;
                        else if (replacementShape[i] == NDShape::FreeDimension)
                            // ReshapingNodes::Validate() uses the input sample (with the free dimension set) to calculate the sampleLayout.
                            // Set the free dimension to 0 as well so that it can be inferred from the sample.
                            // The drawback is that we cannot support shapes with more than one inferred/free combined dimension.
                            // More work on ReshapingNodes is needed to handle more than one inferred/free combined dimension.
                            replacementShape[i] = 0;
                    }

                    ASSIGN_NEW_NODE(ReshapeNode, network->GetDeviceId(), internalNodeName, AsTensorShape(replacementShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
                    break;
                }
                case PrimitiveOpType::Squeeze:
                {
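                    // Squeeze is lowered to a Reshape whose target shape drops the singleton dimensions.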
                    auto beginAxis = Axis(0);
                    auto inputShape = functionInputs[0].Shape();
                    auto endAxis = Axis((int)inputShape.Rank());
                    auto outputShape = GetSqueezedShape(inputShape, functionConfig);

                    computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(outputShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
                    break;
                }
                case PrimitiveOpType::ConstantOp:
                {
                    double fillValue = functionConfig[PrimitiveFunctionAttribute::AttributeNameFillValue].Value<double>();
                    computationNodePtr = New<ConstantNode<ElementType>>(network->GetDeviceId(), internalNodeName, fillValue);
                    break;
                }
                case PrimitiveOpType::EyeLikeOp:
                {
                    bool outputSparse = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputSparse].Value<bool>();
                    ASSIGN_NEW_NODE(EyeLikeNode, network->GetDeviceId(), internalNodeName, outputSparse);
                    break;
                }
                case PrimitiveOpType::ROIPooling:
                {
                    PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunctionAttribute::AttributeNamePoolingType].Value<size_t>());
                    auto roiOutputShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameROIOutputShape].Value<NDShape>();
                    auto spatialScale = functionConfig[PrimitiveFunctionAttribute::AttributeNameSpatialScale].Value<double>();
                    ASSIGN_NEW_NODE(ROIPoolingNode, network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(roiOutputShape), spatialScale);
                    break;
                }
                case PrimitiveOpType::Pooling:
                {
                    PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunctionAttribute::AttributeNamePoolingType].Value<size_t>());
                    auto poolingWindowsShape = functionConfig[PrimitiveFunctionAttribute::AttributeNamePoolingWindowShape].Value<NDShape>();
                    auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    auto ceilOutDim = false;
                    auto includePad = false;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameCeilOutDim))
                    {
                        ceilOutDim = functionConfig[PrimitiveFunctionAttribute::AttributeNameCeilOutDim].Value<bool>();
                    }
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameIncludePad))
                    {
                        includePad = functionConfig[PrimitiveFunctionAttribute::AttributeNameIncludePad].Value<bool>();
                    }
                    ASSIGN_NEW_NODE(PoolingNode, network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutDim, includePad, ImageLayoutKind::CHW);
                    break;
                }
                case PrimitiveOpType::Unpooling:
                {
                    auto unpoolingWindowShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameUnpoolingWindowShape].Value<NDShape>();
                    auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    // We only get here after validation, so it is safe to assume the unpooling type is max.
                    ASSIGN_NEW_NODE(MaxUnpoolingNode, network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
                    break;
                }
                case PrimitiveOpType::SumAll:
                    ASSIGN_NEW_NODE(SumElementsNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::OneHot:
                {
                    auto numClass = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumClass].Value<size_t>();
                    auto is_sparse = functionConfig[PrimitiveFunctionAttribute::AttributeNameOneHotOutputSparse].Value<bool>();
                    auto axis = functionConfig[PrimitiveFunctionAttribute::AttributeNameOneHotAxis].Value<Axis>();
                    ASSIGN_NEW_NODE(OneHotNode, network->GetDeviceId(), numClass, is_sparse, axis.StaticAxisIndex(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::Gather:
                    ASSIGN_NEW_NODE(GatherNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ToBatch:
                {
                    ASSIGN_NEW_NODE(ToBatchAxisNode, network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::UnpackBatch:
                {
                    ASSIGN_NEW_NODE(UnpackBatchAxisNode, network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::Plus:
                    ASSIGN_NEW_NODE(PlusNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::LogPlus:
                    ASSIGN_NEW_NODE(LogPlusNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Pow:
                    ASSIGN_NEW_NODE(PowNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Minus:
                    ASSIGN_NEW_NODE(MinusNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ElementTimes:
                    ASSIGN_NEW_NODE(ElementTimesNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Equal:
                    ASSIGN_NEW_NODE(EqualNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::NotEqual:
                    ASSIGN_NEW_NODE(NotEqualNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Less:
                    ASSIGN_NEW_NODE(LessNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::LessEqual:
                    ASSIGN_NEW_NODE(LessEqualNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Greater:
                    ASSIGN_NEW_NODE(GreaterNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::GreaterEqual:
                    ASSIGN_NEW_NODE(GreaterEqualNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Times:
                {
                    size_t outputRank = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputRank].Value<size_t>();
                    auto inferInputRankToMap = functionConfig[PrimitiveFunctionAttribute::AttributeNameInferInputRankToMap].Value<int>();
                    ASSIGN_NEW_NODE(TimesNode, network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
                    break;
                }
                case PrimitiveOpType::TransposeTimes:
                {
                    size_t outputRank = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputRank].Value<size_t>();
                    ASSIGN_NEW_NODE(TransposeTimesNode, network->GetDeviceId(), internalNodeName, outputRank);
                    break;
                }
                case PrimitiveOpType::Convolution:
                {
                    auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
                    NDShape dilation = { 1 };
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameDilation))
                        dilation = functionConfig[PrimitiveFunctionAttribute::AttributeNameDilation].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
                    auto sharing = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    auto transpose = functionConfig[PrimitiveFunctionAttribute::AttributeNameTranspose].Value<bool>();
                    NDShape outputMapCount, kernelShape;
                    std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
                    NDShape outputShape = NDShape::Unknown();
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameOutputShape))
                        outputShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputShape].Value<NDShape>();
                    auto groups = PrimitiveFunction::convolutionOpDefaultValueForGroups;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameGroups))
                        groups = functionConfig[PrimitiveFunctionAttribute::AttributeNameGroups].Value<size_t>();
                    auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
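                    // An unknown output shape is signaled to the node as TensorShape(0).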
                    ASSIGN_NEW_NODE(ConvolutionNode, network->GetDeviceId(), internalNodeName,
                                    AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
                                    sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
                                    outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
                                    ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation), groups);
                    break;
                }
                case PrimitiveOpType::ConvolutionSequenceShape:
                {
                    auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
                    NDShape dilation = { 1 };
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameDilation))
                        dilation = functionConfig[PrimitiveFunctionAttribute::AttributeNameDilation].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
                    auto sharing = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    auto transpose = functionConfig[PrimitiveFunctionAttribute::AttributeNameTranspose].Value<bool>();
                    NDShape outputMapCount, kernelShape;
                    std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
                    NDShape outputShape = NDShape::Unknown();
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameOutputShape))
                        outputShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputShape].Value<NDShape>();
                    auto groups = PrimitiveFunction::convolutionOpDefaultValueForGroups;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameGroups))
                        groups = functionConfig[PrimitiveFunctionAttribute::AttributeNameGroups].Value<size_t>();
                    auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
                    ASSIGN_NEW_NODE(ConvolutionSequenceShapeNode, network->GetDeviceId(), internalNodeName,
                        AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
                        sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
                        outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
                        ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation), groups);
                    break;
                }
                case PrimitiveOpType::CosDistance:
                    ASSIGN_NEW_NODE(CosDistanceNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::CosDistanceWithNegativeSamples:
                    ASSIGN_NEW_NODE(CosDistanceWithNegativeSamplesNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Logistic:
                    ASSIGN_NEW_NODE(LogisticNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::SquaredError:
                    ASSIGN_NEW_NODE(SquareErrorNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::CrossEntropyWithSoftmax:
                    ASSIGN_NEW_NODE(CrossEntropyWithSoftmaxNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ClassificationError:
                    ASSIGN_NEW_NODE(ClassificationErrorNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::EditDistanceError:
                {
                    auto subPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameSubstitutionPenalty].Value<float>();
                    auto delPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameDeletionPenalty].Value<float>();
                    auto insPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameInsertionPenalty].Value<float>();
                    auto squashInputs = functionConfig[PrimitiveFunctionAttribute::AttributeNameSquashInputs].Value<bool>();
                    auto tokensToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNameTokensToIgnore].Value<std::vector<DictionaryValue>>());
                    ASSIGN_NEW_NODE(EditDistanceErrorNode, network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore);
                    break;
                }
                case PrimitiveOpType::LatticeSequenceWithSoftmax:
                {
                    auto symListPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameSymListPath].Value<wstring>();
                    auto phonePath = functionConfig[PrimitiveFunctionAttribute::AttributeNamePhonePath].Value<wstring>();
                    auto stateListPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameStateListPath].Value<wstring>();
                    auto transProbPath =  functionConfig[PrimitiveFunctionAttribute::AttributeNameTransProbPath].Value<wstring>();
                    auto latticeConfigPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameLatticeConfigPath].Value<wstring>();
                    auto frameDropThresh = functionConfig[PrimitiveFunctionAttribute::AttributeNameFrameDropThresh].Value<float>();
                    auto doReferenceAlign = functionConfig[PrimitiveFunctionAttribute::AttributeNameDoReferenceAlign].Value<bool>();
                    auto seqGammarUsesMBR = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarUsesMBR].Value<bool>();
                    auto seqGammarAMF = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarAMF].Value<float>();
                    auto seqGammarLMF = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarLMF].Value<float>();
                    auto seqGammarBMMIFactor = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarBMMIFactor].Value<float>();
                    auto seqGammarWordPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarWordPen].Value<float>();
                    auto hSmoothingWeight = functionConfig[PrimitiveFunctionAttribute::AttributeNameHSmoothingWeight].Value<float>();

                    computationNodePtr = New<LatticeSequenceWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName, symListPath, phonePath, stateListPath, transProbPath, latticeConfigPath,
                        hSmoothingWeight, frameDropThresh, doReferenceAlign, seqGammarUsesMBR, seqGammarAMF, seqGammarLMF, seqGammarBMMIFactor, seqGammarWordPen);
                    break;
                }
                case PrimitiveOpType::ForwardBackward:
                {
                    auto delayConstraint = functionConfig[PrimitiveFunctionAttribute::AttributeNameDelayConstraint].Value<int>();
                    auto blankTokenId = functionConfig[PrimitiveFunctionAttribute::AttributeNameBlankTokenId].Value<size_t>();
                    ASSIGN_NEW_NODE(ForwardBackwardNode, network->GetDeviceId(), internalNodeName, blankTokenId, delayConstraint);
                    break;
                }
                case PrimitiveOpType::LambdaRank:
                    ASSIGN_NEW_NODE(LambdaRankNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::NDCG:
                    ASSIGN_NEW_NODE(NDCG1EvalNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::PastValue:
                case PrimitiveOpType::FutureValue:
                {
                    Variable inputOperandVar = functionInputs[0];
                    Variable initialStateVar = functionInputs[1];

                    size_t offset = primitiveFunction->Attributes()[PrimitiveFunctionAttribute::AttributeNameOffset].Value<size_t>();
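                    // 'offset' is the number of time steps the node reaches into the past (PastValue) or future (FutureValue).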
                    if (op == PrimitiveOpType::PastValue)
                        ASSIGN_NEW_NODE(PastValueNode, network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
                    else
                        ASSIGN_NEW_NODE(FutureValueNode, network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);

                    break;
                }
                case PrimitiveOpType::ReduceElements:
                {
                    bool keepDimensions = true;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameReductionKeepDimensions))
                        keepDimensions = functionConfig[PrimitiveFunctionAttribute::AttributeNameReductionKeepDimensions].Value<bool>();
                    auto reductionOpName = functionConfig[PrimitiveFunctionAttribute::AttributeNameReductionOpName].Value<std::wstring>();
                    std::vector<Axis> reductionAxis;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxisVec))
                    {
                        reductionAxis = AsVector<Axis>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
                    }
                    else if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxis))
                    {
                        reductionAxis.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis].Value<Axis>());
                    }
                    else
                    {
                        RuntimeError("Failed to create computation node': Reduce operation %ls with no '%ls' or  '%ls' attributes",
                            PrimitiveOpTypeName(op).c_str(),
                            PrimitiveFunctionAttribute::AttributeNameAxis.c_str(),
                            PrimitiveFunctionAttribute::AttributeNameAxisVec.c_str()
                        );

                    } 
                    ASSIGN_NEW_NODE(ReduceElementsNode, network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis), keepDimensions);
                    break;
                }
                case PrimitiveOpType::BatchNormalization:
                {
                    auto spatial = functionConfig[PrimitiveFunctionAttribute::AttributeNameSpatial].Value<bool>();
                    auto normalizationTimeConstant = functionConfig[PrimitiveFunctionAttribute::AttributeNameNormalizationTimeConstant].Value<double>();
                    auto blendTimeConstant = functionConfig[PrimitiveFunctionAttribute::AttributeNameBlendTimeConstant].Value<double>();
                    auto epsilon = functionConfig[PrimitiveFunctionAttribute::AttributeNameEpsilon].Value<double>();
                    auto useCuDNNEngine = functionConfig[PrimitiveFunctionAttribute::AttributeNameUseCuDNNEngine].Value<bool>();
                    
                    bool disableRegularization = false;
                    if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameDisableRegularization))
                    {
                        disableRegularization = functionConfig[PrimitiveFunctionAttribute::AttributeNameDisableRegularization].Value<bool>();
                    }
                    
                    ASSIGN_NEW_NODE(BatchNormalizationNode, network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, disableRegularization, ImageLayoutKind::CHW);
                    break;
                }
                case PrimitiveOpType::Combine:
                    // This operation is just a no-op and is a means to combine multiple functions to create a single Function
                    // whose outputs are a union of the outputs of the Functions being combined.
                    computationNodePtr = variableToNodeMap[variable];
                    break;
                case PrimitiveOpType::PackedIndex:
                    ASSIGN_NEW_NODE(PackedIndexNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::GatherPacked:
                    ASSIGN_NEW_NODE(GatherPackedNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ScatterPacked:
                    ASSIGN_NEW_NODE(ScatterPackedNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Clip:
                    ASSIGN_NEW_NODE(ClipNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Select:
                    ASSIGN_NEW_NODE(IfNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Splice:
                {
                    Axis spliceAxis = functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis].Value<Axis>();
                    ASSIGN_NEW_NODE(RowStackNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
                    break;
                }
                case PrimitiveOpType::Pad:
                {
                    auto head = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingHead].Value<std::vector<DictionaryValue>>());
                    auto foot = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingFoot].Value<std::vector<DictionaryValue>>());
                    auto mode = functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingMode].Value<size_t>();
                    auto constantValue = functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingConstantValue].Value<double>();
                    ASSIGN_NEW_NODE(PaddingNode, network->GetDeviceId(), internalNodeName, head, foot, (PaddingType)mode, constantValue);
                    break;
                }
                case PrimitiveOpType::OptimizedRNNStack:
                {
                    auto bidirectional = functionConfig[PrimitiveFunctionAttribute::AttributeNameBidirectional].Value<bool>();
                    auto numLayers = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumLayers].Value<size_t>();
                    auto hiddenSize = functionConfig[PrimitiveFunctionAttribute::AttributeNameHiddenSize].Value<size_t>();
                    auto recurrentOp = functionConfig[PrimitiveFunctionAttribute::AttributeNameRecurrentOp].Value<std::wstring>();

                    ASSIGN_NEW_NODE(OptimizedRNNStackNode, network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
                    break;
                }
                case PrimitiveOpType::ReconcileDynamicAxis:
                {
                    ASSIGN_NEW_NODE(ReconcileDynamicAxisNode, network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::LogSoftmax:
                {
                    // This can be implemented as x => x - ReduceLogSum(x). How to do this here?
                    ASSIGN_NEW_NODE(LogSoftmaxNode, network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::Pass:
                    ASSIGN_NEW_NODE(PassNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::LabelsToGraph:
                    ASSIGN_NEW_NODE(LabelsToGraphNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::StopGradient:
                    ASSIGN_NEW_NODE(StopGradientNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Assign:
                    ASSIGN_NEW_NODE(AssignNode, network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Crop:
                    if (functionInputs.size() == 2)
                    {
                        if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameOffset))
                        {
                            // Crop with given offsets.
                            const auto& offsets = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNameOffset].Value<std::vector<DictionaryValue>>());
                            if (offsets.size() != 2)
                            {
                                CNTK::LogicError("Vector of crop offsets must have size 2.");
                            }
                            ASSIGN_NEW_NODE(CropNode, offsets[0], offsets[1], network->GetDeviceId(), internalNodeName);
                        }
                        else
                        {
                            // Crop with two inputs and automatic offset computation.
                            ASSIGN_NEW_NODE(CropNode, network->GetDeviceId(), internalNodeName);
                        }
                    }
                    else if (functionInputs.size() == 4)
                    {
                        // Crop with four inputs and automatic offset computation.
                        ASSIGN_NEW_NODE(CropNode, network->GetDeviceId(), internalNodeName);
                    }
                    else
                    {
                        CNTK::LogicError("Crop node must have 2 or 4 node inputs.");
                    }
                    break;
                case PrimitiveOpType::Cast:
                {
                    DataType outputType = (DataType)functionConfig[PrimitiveFunctionAttribute::AttributeNameNewDataType].Value<int>();
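                    // Dispatch on the requested target type; ASSIGN_NEW_NODE2 supplies the input's element type
                    // as the second template argument.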
                    switch (outputType)
                    {
                    case DataType::Float:
                        ASSIGN_NEW_NODE2(CastNode, float, network->GetDeviceId(), internalNodeName);
                        break;
                    case DataType::Double:
                        ASSIGN_NEW_NODE2(CastNode, double, network->GetDeviceId(), internalNodeName);
                        break;
                    case DataType::Float16:
                        ASSIGN_NEW_NODE2(CastNode, half, network->GetDeviceId(), internalNodeName);
                        break;
                    }
                    break;
                }
                case PrimitiveOpType::CustomProxyOp:
                {
                    ASSIGN_NEW_NODE(CustomProxyOpNode, network->GetDeviceId(), internalNodeName);
                    break;
                }
                default:
                    CNTK::LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
                    break;
                }

                // Reorder inputNodesBasePtrs, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering.
                ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);

                if (computationNodePtr->Is<INumInputs>())
                {
                    auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
                    if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
                        CNTK::LogicError("The Primitive Function '%S' has %d inputs while the corresponding ComputationNode expects %d inputs.",
                            function->AsString().c_str(),
                            (int)inputNodesBasePtrs.size(),
                            (int)computationNodeExpectedInputCount);
                }

                if (computationNodePtr->Is<RngUser>())
                {
                    auto seed = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngSeed].Value<size_t>();
                    auto offset = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngOffset].Value<size_t>();
                    computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
                }
            }
            else
            {
                ASSIGN_NEW_NODE(UserDefinedV2FunctionNode, network->GetDeviceId(), internalNodeName, function->shared_from_this());

                // For user-defined functions, we only attach unique inputs in the internal computation network, since the UDF
                // backward implementations directly compute aggregate gradient values for unique inputs.
                std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
                for (auto inputNodeBasePtr : inputNodesBasePtrs)
                {
                    if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
                        uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
                }

                inputNodesBasePtrs = uniqueInputNodesBasePtrs;
            }
        }
        else
        {
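            // 'variable' is a secondary output: find its index among the function's outputs and wire up an
            // OutputMultiplexerNode that selects the i-th output of the primary node.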
            size_t i = 1;
            while ((i < outputs.size()) && (outputs[i] != variable))
                i++;
            assert(i < outputs.size());

            ASSIGN_NEW_NODE(OutputMultiplexerNode, network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes), i);
            inputNodesBasePtrs = { variableToNodeMap[outputs[0]] };
        }

        network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
        return computationNodePtr;
    }