in Source/CNTKv2LibraryDll/CompositeFunction.cpp [668:1436]
/*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
Function* function,
const std::vector<std::shared_ptr<ComputationNodeBase>>& inputNodes,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
bool useMangledNamesForComputationNodes)
{
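// A NoOp primitive gets no ComputationNode of its own; reuse the node already created for the variable it forwards.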
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
return variableToNodeMap[GetMappingVariable(variable)];
ComputationNodeBasePtr computationNodePtr;
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name(), useMangledNamesForComputationNodes);
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputNode : inputNodes)
inputNodesBasePtrs.push_back(inputNode);
// Take the data type from the first input; if unspecified (e.g., a placeholder), use the default.
// Nodes like BatchNormalization may have inputs with differing precisions;
// that validation is done in the specific node's constructor.
DataType inputNodeType = AsDataType<ElementType>();
if (inputNodes.size() > 0)
{
if (std::dynamic_pointer_cast<ComputationNode<float>, ComputationNodeBase>(inputNodes[0]))
inputNodeType = DataType::Float;
else if (std::dynamic_pointer_cast<ComputationNode<double>, ComputationNodeBase>(inputNodes[0]))
inputNodeType = DataType::Double;
else if (std::dynamic_pointer_cast<ComputationNode<half>, ComputationNodeBase>(inputNodes[0]))
inputNodeType = DataType::Float16;
}
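// Helpers to instantiate nodeClass<T> (or nodeClass<dtype, T>) with T matching the element type detected above.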
#define ASSIGN_NEW_NODE(nodeClass, ...) \
do { \
if (inputNodeType == DataType::Float) \
computationNodePtr = New<nodeClass<float>>(__VA_ARGS__); \
else if (inputNodeType == DataType::Double) \
computationNodePtr = New<nodeClass<double>>(__VA_ARGS__); \
else if (inputNodeType == DataType::Float16) \
computationNodePtr = New<nodeClass<half>>(__VA_ARGS__); \
} while(0)
#define ASSIGN_NEW_NODE2(nodeClass, dtype, ...) \
do { \
if (inputNodeType == DataType::Float) \
computationNodePtr = New<nodeClass<dtype, float>>(__VA_ARGS__); \
else if (inputNodeType == DataType::Double) \
computationNodePtr = New<nodeClass<dtype, double>>(__VA_ARGS__); \
else if (inputNodeType == DataType::Float16) \
computationNodePtr = New<nodeClass<dtype, half>>(__VA_ARGS__); \
} while(0)
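// Three cases follow: (1) 'variable' is the primary output of a primitive function, so create the op-specific node;
// (2) it is the primary output of a user-defined function, so wrap the function in a UserDefinedV2FunctionNode;
// (3) it is a secondary output, so create an OutputMultiplexerNode over the node of the primary output.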
auto outputs = function->RawOutputs();
if (variable == outputs[0])
{
if (primitiveFunction)
{
auto functionInputs = function->Inputs();
auto& functionConfig = function->Attributes();
PrimitiveOpType op = primitiveFunction->OpType();
switch (op)
{
case PrimitiveOpType::Negate:
ASSIGN_NEW_NODE(NegateNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sigmoid:
ASSIGN_NEW_NODE(SigmoidNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Atanh:
ASSIGN_NEW_NODE(AtanhNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Tanh:
ASSIGN_NEW_NODE(TanhNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Acos:
ASSIGN_NEW_NODE(AcosNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Cos:
ASSIGN_NEW_NODE(CosineNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Asin:
ASSIGN_NEW_NODE(AsinNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sin:
ASSIGN_NEW_NODE(SinNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Atan:
ASSIGN_NEW_NODE(AtanNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Tan:
ASSIGN_NEW_NODE(TanNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Cosh:
ASSIGN_NEW_NODE(CoshNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Asinh:
ASSIGN_NEW_NODE(AsinhNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sinh:
ASSIGN_NEW_NODE(SinhNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ReLU:
ASSIGN_NEW_NODE(RectifiedLinearNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Exp:
ASSIGN_NEW_NODE(ExpNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Log:
ASSIGN_NEW_NODE(LogNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sqrt:
ASSIGN_NEW_NODE(SqrtNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ELU:
ASSIGN_NEW_NODE(ExponentialLinearUnitNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Floor:
ASSIGN_NEW_NODE(FloorNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Abs:
ASSIGN_NEW_NODE(AbsNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Reciprocal:
ASSIGN_NEW_NODE(ReciprocalNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Softmax:
ASSIGN_NEW_NODE(SoftmaxNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Hardmax:
ASSIGN_NEW_NODE(HardmaxNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::StraightThrough:
ASSIGN_NEW_NODE(StraightThroughNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::TopK:
{
auto k = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumItems].Value<size_t>();
ASSIGN_NEW_NODE(TopKNode, network->GetDeviceId(), internalNodeName, k);
break;
}
case PrimitiveOpType::StableSigmoid:
ASSIGN_NEW_NODE(StableSigmoidNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::TransposeAxes:
{
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxisVec))
{
auto perm = AsVector<Axis>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
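// Axes may be given as negative indices (counting from the back); normalize them to non-negative static indices before converting to CNTK's 1-based axis ids.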
for (auto& p : perm)
p = NormalizeStaticAxis(p, perm.size());
ASSIGN_NEW_NODE(TransposeDimensionsNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(perm));
}
else
{
auto axis1 = functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis1].Value<Axis>();
auto axis2 = functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis2].Value<Axis>();
// The axis indices passed to the internal CNTK TransposeDimensionsNode are 1-based instead of 0-based
ASSIGN_NEW_NODE(TransposeDimensionsNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
}
break;
}
case PrimitiveOpType::Where:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
ASSIGN_NEW_NODE(WhereNode, network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
break;
}
case PrimitiveOpType::ToSequence:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
ASSIGN_NEW_NODE(ToSequenceNode, network->GetDeviceId(), internalNodeName, internalCNTKDynamicAxisName);
break;
}
case PrimitiveOpType::ToSequenceLike:
ASSIGN_NEW_NODE(ToSequenceLikeNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::UnpackSequence:
{
auto paddingValue = functionConfig[PrimitiveFunctionAttribute::AttributeNameSequenceUnpackPaddingValue].Value<double>();
auto suppressMaskOutput = functionConfig[PrimitiveFunctionAttribute::AttributeNameSequenceUnpackSuppressMaskOutput].Value<bool>();
ASSIGN_NEW_NODE(UnpackSequenceNode, network->GetDeviceId(), internalNodeName, paddingValue, suppressMaskOutput);
break;
}
case PrimitiveOpType::Slice:
{
std::vector<Axis> axis;
std::vector<int> beginIndex, endIndex, strides;
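// Slice attributes come in two forms: vector-valued (multi-axis slice) or scalar-valued (single-axis slice); strides default to 1 when absent.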
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxisVec) &&
functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameBeginIndexVec) &&
functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameEndIndexVec))
{
axis = AsVector<Axis>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
beginIndex = AsVector<int>(functionConfig[PrimitiveFunctionAttribute::AttributeNameBeginIndexVec].Value<std::vector<DictionaryValue>>());
endIndex = AsVector<int>(functionConfig[PrimitiveFunctionAttribute::AttributeNameEndIndexVec].Value<std::vector<DictionaryValue>>());
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameSliceStridesVec))
strides = AsVector<int>(functionConfig[PrimitiveFunctionAttribute::AttributeNameSliceStridesVec].Value<std::vector<DictionaryValue>>());
else
strides.resize(axis.size(), 1);
}
else if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxis) &&
functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameBeginIndex) &&
functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameEndIndex))
{
axis.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis].Value<Axis>());
beginIndex.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameBeginIndex].Value<int>());
endIndex.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameEndIndex].Value<int>());
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameSliceStrides))
strides.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameSliceStrides].Value<int>());
else
strides.push_back(1);
}
else
{
RuntimeError("Failed to create computation node: Slice operation with inconsistent attributes");
}
// The internal CNTK SliceNode takes 1-based axis indices instead of 0-based
ASSIGN_NEW_NODE(SliceNode, network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis), strides);
break;
}
case PrimitiveOpType::RandomSample:
{
auto numSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunctionAttribute::AttributeNameAllowDuplicates].Value<bool>();
ASSIGN_NEW_NODE(RandomSampleNode, network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::RandomSampleInclusionFrequency:
{
auto numSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunctionAttribute::AttributeNameAllowDuplicates].Value<bool>();
ASSIGN_NEW_NODE(RandomSampleInclusionFrequencyNode, network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::Dropout:
{
auto dropoutRate = functionConfig[PrimitiveFunctionAttribute::AttributeNameDropoutRate].Value<double>();
ASSIGN_NEW_NODE(DropoutNode, network->GetDeviceId(), internalNodeName);
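// SMART_NODE_INVOKE dispatches on the node's actual element type, downcasting to DropoutNode<T> before calling SetDropoutRate.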
SMART_NODE_INVOKE(DropoutNode, computationNodePtr, SetDropoutRate, dropoutRate);
break;
}
case PrimitiveOpType::RandomDistribution:
{
auto seed = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngOffset].Value<size_t>();
auto rvtype = functionConfig[PrimitiveFunctionAttribute::AttributeNameRandomDistributionType].Value<std::wstring>();
std::vector<double> randomDistributionArgs;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameRandomDistributionArgs))
randomDistributionArgs = AsVector<double>(functionConfig[PrimitiveFunctionAttribute::AttributeNameRandomDistributionArgs].Value<std::vector<DictionaryValue>>());
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameNewShape))
{
auto shape = functionConfig[PrimitiveFunctionAttribute::AttributeNameNewShape].Value<NDShape>();
ASSIGN_NEW_NODE(RandomDistributionNode, network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs, AsTensorShape(shape));
}
else
ASSIGN_NEW_NODE(RandomDistributionNode, network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs);
SMART_NODE_INVOKE(RandomDistributionNode, computationNodePtr, SetRngState, seed, offset);
break;
}
case PrimitiveOpType::Reshape:
{
auto beginAxis = Axis(0);
auto endAxis = Axis((int)functionInputs[0].Shape().Rank());
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameBeginAxis))
beginAxis = functionConfig[PrimitiveFunctionAttribute::AttributeNameBeginAxis].Value<Axis>();
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameEndAxis))
endAxis = functionConfig[PrimitiveFunctionAttribute::AttributeNameEndAxis].Value<Axis>();
auto replacementShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameNewShape].Value<NDShape>();
for (size_t i = 0; i < replacementShape.Rank(); ++i)
{
if (replacementShape[i] == NDShape::InferredDimension)
replacementShape[i] = 0;
else if (replacementShape[i] == NDShape::FreeDimension)
// ReshapingNodes::Validate() uses the input sample (with the free dimension set) to calculate the sampleLayout.
// Set the free dimension to 0 as well so that it can be inferred from the sample.
// The drawback is that we cannot support shapes with more than one combined inferred/free dimension;
// handling more would require further work on ReshapingNodes.
replacementShape[i] = 0;
}
ASSIGN_NEW_NODE(ReshapeNode, network->GetDeviceId(), internalNodeName, AsTensorShape(replacementShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
break;
}
case PrimitiveOpType::Squeeze:
{
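// Squeeze is lowered to a Reshape over the full axis range [0, rank) with the squeezed output shape.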
auto beginAxis = Axis(0);
auto inputShape = functionInputs[0].Shape();
auto endAxis = Axis((int)inputShape.Rank());
auto outputShape = GetSqueezedShape(inputShape, functionConfig);
computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(outputShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
break;
}
case PrimitiveOpType::ConstantOp:
{
double fillValue = functionConfig[PrimitiveFunctionAttribute::AttributeNameFillValue].Value<double>();
computationNodePtr = New<ConstantNode<ElementType>>(network->GetDeviceId(), internalNodeName, fillValue);
break;
}
case PrimitiveOpType::EyeLikeOp:
{
bool outputSparse = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputSparse].Value<bool>();
ASSIGN_NEW_NODE(EyeLikeNode, network->GetDeviceId(), internalNodeName, outputSparse);
break;
}
case PrimitiveOpType::ROIPooling:
{
PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunctionAttribute::AttributeNamePoolingType].Value<size_t>());
auto roiOutputShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameROIOutputShape].Value<NDShape>();
auto spatialScale = functionConfig[PrimitiveFunctionAttribute::AttributeNameSpatialScale].Value<double>();
ASSIGN_NEW_NODE(ROIPoolingNode, network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(roiOutputShape), spatialScale);
break;
}
case PrimitiveOpType::Pooling:
{
PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunctionAttribute::AttributeNamePoolingType].Value<size_t>());
auto poolingWindowsShape = functionConfig[PrimitiveFunctionAttribute::AttributeNamePoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto ceilOutDim = false;
auto includePad = false;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameCeilOutDim))
{
ceilOutDim = functionConfig[PrimitiveFunctionAttribute::AttributeNameCeilOutDim].Value<bool>();
}
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameIncludePad))
{
includePad = functionConfig[PrimitiveFunctionAttribute::AttributeNameIncludePad].Value<bool>();
}
ASSIGN_NEW_NODE(PoolingNode, network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutDim, includePad, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Unpooling:
{
auto unpoolingWindowShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameUnpoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
// We only get here after validation, so it is safe to assume the unpooling kind is max.
ASSIGN_NEW_NODE(MaxUnpoolingNode, network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::SumAll:
ASSIGN_NEW_NODE(SumElementsNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::OneHot:
{
auto numClass = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumClass].Value<size_t>();
auto is_sparse = functionConfig[PrimitiveFunctionAttribute::AttributeNameOneHotOutputSparse].Value<bool>();
auto axis = functionConfig[PrimitiveFunctionAttribute::AttributeNameOneHotAxis].Value<Axis>();
ASSIGN_NEW_NODE(OneHotNode, network->GetDeviceId(), numClass, is_sparse, axis.StaticAxisIndex(), internalNodeName);
break;
}
case PrimitiveOpType::Gather:
ASSIGN_NEW_NODE(GatherNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ToBatch:
{
ASSIGN_NEW_NODE(ToBatchAxisNode, network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::UnpackBatch:
{
ASSIGN_NEW_NODE(UnpackBatchAxisNode, network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::Plus:
ASSIGN_NEW_NODE(PlusNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LogPlus:
ASSIGN_NEW_NODE(LogPlusNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Pow:
ASSIGN_NEW_NODE(PowNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Minus:
ASSIGN_NEW_NODE(MinusNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ElementTimes:
ASSIGN_NEW_NODE(ElementTimesNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Equal:
ASSIGN_NEW_NODE(EqualNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NotEqual:
ASSIGN_NEW_NODE(NotEqualNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Less:
ASSIGN_NEW_NODE(LessNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LessEqual:
ASSIGN_NEW_NODE(LessEqualNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Greater:
ASSIGN_NEW_NODE(GreaterNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GreaterEqual:
ASSIGN_NEW_NODE(GreaterEqualNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Times:
{
size_t outputRank = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputRank].Value<size_t>();
auto inferInputRankToMap = functionConfig[PrimitiveFunctionAttribute::AttributeNameInferInputRankToMap].Value<int>();
ASSIGN_NEW_NODE(TimesNode, network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
break;
}
case PrimitiveOpType::TransposeTimes:
{
size_t outputRank = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputRank].Value<size_t>();
ASSIGN_NEW_NODE(TransposeTimesNode, network->GetDeviceId(), internalNodeName, outputRank);
break;
}
case PrimitiveOpType::Convolution:
{
auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
NDShape dilation = { 1 };
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameDilation))
dilation = functionConfig[PrimitiveFunctionAttribute::AttributeNameDilation].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto transpose = functionConfig[PrimitiveFunctionAttribute::AttributeNameTranspose].Value<bool>();
NDShape outputMapCount, kernelShape;
std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
NDShape outputShape = NDShape::Unknown();
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameOutputShape))
outputShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputShape].Value<NDShape>();
auto groups = PrimitiveFunction::convolutionOpDefaultValueForGroups;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameGroups))
groups = functionConfig[PrimitiveFunctionAttribute::AttributeNameGroups].Value<size_t>();
auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
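// An unknown output shape is passed as TensorShape(0), signaling the node to infer the shape itself.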
ASSIGN_NEW_NODE(ConvolutionNode, network->GetDeviceId(), internalNodeName,
AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation), groups);
break;
}
case PrimitiveOpType::ConvolutionSequenceShape:
{
auto strides = functionConfig[PrimitiveFunctionAttribute::AttributeNameStrides].Value<NDShape>();
NDShape dilation = { 1 };
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameDilation))
dilation = functionConfig[PrimitiveFunctionAttribute::AttributeNameDilation].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunctionAttribute::AttributeNameUpperPad].Value<NDShape>();
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto transpose = functionConfig[PrimitiveFunctionAttribute::AttributeNameTranspose].Value<bool>();
NDShape outputMapCount, kernelShape;
std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
NDShape outputShape = NDShape::Unknown();
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameOutputShape))
outputShape = functionConfig[PrimitiveFunctionAttribute::AttributeNameOutputShape].Value<NDShape>();
auto groups = PrimitiveFunction::convolutionOpDefaultValueForGroups;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameGroups))
groups = functionConfig[PrimitiveFunctionAttribute::AttributeNameGroups].Value<size_t>();
auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunctionAttribute::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
ASSIGN_NEW_NODE(ConvolutionSequenceShapeNode, network->GetDeviceId(), internalNodeName,
AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation), groups);
break;
}
case PrimitiveOpType::CosDistance:
ASSIGN_NEW_NODE(CosDistanceNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CosDistanceWithNegativeSamples:
ASSIGN_NEW_NODE(CosDistanceWithNegativeSamplesNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Logistic:
ASSIGN_NEW_NODE(LogisticNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::SquaredError:
ASSIGN_NEW_NODE(SquareErrorNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CrossEntropyWithSoftmax:
ASSIGN_NEW_NODE(CrossEntropyWithSoftmaxNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ClassificationError:
ASSIGN_NEW_NODE(ClassificationErrorNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::EditDistanceError:
{
auto subPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameSubstitutionPenalty].Value<float>();
auto delPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameDeletionPenalty].Value<float>();
auto insPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameInsertionPenalty].Value<float>();
auto squashInputs = functionConfig[PrimitiveFunctionAttribute::AttributeNameSquashInputs].Value<bool>();
auto tokensToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNameTokensToIgnore].Value<std::vector<DictionaryValue>>());
ASSIGN_NEW_NODE(EditDistanceErrorNode, network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore);
break;
}
case PrimitiveOpType::LatticeSequenceWithSoftmax:
{
auto symListPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameSymListPath].Value<wstring>();
auto phonePath = functionConfig[PrimitiveFunctionAttribute::AttributeNamePhonePath].Value<wstring>();
auto stateListPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameStateListPath].Value<wstring>();
auto transProbPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameTransProbPath].Value<wstring>();
auto latticeConfigPath = functionConfig[PrimitiveFunctionAttribute::AttributeNameLatticeConfigPath].Value<wstring>();
auto frameDropThresh = functionConfig[PrimitiveFunctionAttribute::AttributeNameFrameDropThresh].Value<float>();
auto doReferenceAlign = functionConfig[PrimitiveFunctionAttribute::AttributeNameDoReferenceAlign].Value<bool>();
auto seqGammarUsesMBR = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarUsesMBR].Value<bool>();
auto seqGammarAMF = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarAMF].Value<float>();
auto seqGammarLMF = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarLMF].Value<float>();
auto seqGammarBMMIFactor = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarBMMIFactor].Value<float>();
auto seqGammarWordPen = functionConfig[PrimitiveFunctionAttribute::AttributeNameSeqGammarWordPen].Value<float>();
auto hSmoothingWeight = functionConfig[PrimitiveFunctionAttribute::AttributeNameHSmoothingWeight].Value<float>();
computationNodePtr = New<LatticeSequenceWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName, symListPath, phonePath, stateListPath, transProbPath, latticeConfigPath,
hSmoothingWeight, frameDropThresh, doReferenceAlign, seqGammarUsesMBR, seqGammarAMF, seqGammarLMF, seqGammarBMMIFactor, seqGammarWordPen);
break;
}
case PrimitiveOpType::ForwardBackward:
{
auto delayConstraint = functionConfig[PrimitiveFunctionAttribute::AttributeNameDelayConstraint].Value<int>();
auto blankTokenId = functionConfig[PrimitiveFunctionAttribute::AttributeNameBlankTokenId].Value<size_t>();
ASSIGN_NEW_NODE(ForwardBackwardNode, network->GetDeviceId(), internalNodeName, blankTokenId, delayConstraint);
break;
}
case PrimitiveOpType::LambdaRank:
ASSIGN_NEW_NODE(LambdaRankNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NDCG:
ASSIGN_NEW_NODE(NDCG1EvalNode, network->GetDeviceId(), internalNodeName);
break;
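// PastValue and FutureValue differ only in direction; both need the operand's shape and the step offset.
// The initial state (functionInputs[1]) is attached as a regular input when the inputs are wired up below.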
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
{
Variable inputOperandVar = functionInputs[0];
Variable initialStateVar = functionInputs[1];
size_t offset = primitiveFunction->Attributes()[PrimitiveFunctionAttribute::AttributeNameOffset].Value<size_t>();
if (op == PrimitiveOpType::PastValue)
ASSIGN_NEW_NODE(PastValueNode, network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
else
ASSIGN_NEW_NODE(FutureValueNode, network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
break;
}
case PrimitiveOpType::ReduceElements:
{
bool keepDimensions = true;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameReductionKeepDimensions))
keepDimensions = functionConfig[PrimitiveFunctionAttribute::AttributeNameReductionKeepDimensions].Value<bool>();
auto reductionOpName = functionConfig[PrimitiveFunctionAttribute::AttributeNameReductionOpName].Value<std::wstring>();
std::vector<Axis> reductionAxis;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxisVec))
{
reductionAxis = AsVector<Axis>(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
}
else if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameAxis))
{
reductionAxis.push_back(functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis].Value<Axis>());
}
else
{
RuntimeError("Failed to create computation node': Reduce operation %ls with no '%ls' or '%ls' attributes",
PrimitiveOpTypeName(op).c_str(),
PrimitiveFunctionAttribute::AttributeNameAxis.c_str(),
PrimitiveFunctionAttribute::AttributeNameAxisVec.c_str()
);
}
ASSIGN_NEW_NODE(ReduceElementsNode, network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis), keepDimensions);
break;
}
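// Note: the node constructor takes a flag with the opposite sense (use the CNTK-native engine rather than cuDNN), hence the negation of useCuDNNEngine below.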
case PrimitiveOpType::BatchNormalization:
{
auto spatial = functionConfig[PrimitiveFunctionAttribute::AttributeNameSpatial].Value<bool>();
auto normalizationTimeConstant = functionConfig[PrimitiveFunctionAttribute::AttributeNameNormalizationTimeConstant].Value<double>();
auto blendTimeConstant = functionConfig[PrimitiveFunctionAttribute::AttributeNameBlendTimeConstant].Value<double>();
auto epsilon = functionConfig[PrimitiveFunctionAttribute::AttributeNameEpsilon].Value<double>();
auto useCuDNNEngine = functionConfig[PrimitiveFunctionAttribute::AttributeNameUseCuDNNEngine].Value<bool>();
bool disableRegularization = false;
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameDisableRegularization))
{
disableRegularization = functionConfig[PrimitiveFunctionAttribute::AttributeNameDisableRegularization].Value<bool>();
}
ASSIGN_NEW_NODE(BatchNormalizationNode, network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, disableRegularization, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Combine:
// This operation is just a no-op and is a means to combine multiple Functions into a single Function
// whose outputs are the union of the outputs of the Functions being combined.
computationNodePtr = variableToNodeMap[variable];
break;
case PrimitiveOpType::PackedIndex:
ASSIGN_NEW_NODE(PackedIndexNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GatherPacked:
ASSIGN_NEW_NODE(GatherPackedNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ScatterPacked:
ASSIGN_NEW_NODE(ScatterPackedNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Clip:
ASSIGN_NEW_NODE(ClipNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Select:
ASSIGN_NEW_NODE(IfNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Splice:
{
Axis spliceAxis = functionConfig[PrimitiveFunctionAttribute::AttributeNameAxis].Value<Axis>();
ASSIGN_NEW_NODE(RowStackNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
break;
}
case PrimitiveOpType::Pad:
{
auto head = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingHead].Value<std::vector<DictionaryValue>>());
auto foot = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingFoot].Value<std::vector<DictionaryValue>>());
auto mode = functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingMode].Value<size_t>();
auto constantValue = functionConfig[PrimitiveFunctionAttribute::AttributeNamePaddingConstantValue].Value<double>();
ASSIGN_NEW_NODE(PaddingNode, network->GetDeviceId(), internalNodeName, head, foot, (PaddingType)mode, constantValue);
break;
}
case PrimitiveOpType::OptimizedRNNStack:
{
auto bidirectional = functionConfig[PrimitiveFunctionAttribute::AttributeNameBidirectional].Value<bool>();
auto numLayers = functionConfig[PrimitiveFunctionAttribute::AttributeNameNumLayers].Value<size_t>();
auto hiddenSize = functionConfig[PrimitiveFunctionAttribute::AttributeNameHiddenSize].Value<size_t>();
auto recurrentOp = functionConfig[PrimitiveFunctionAttribute::AttributeNameRecurrentOp].Value<std::wstring>();
ASSIGN_NEW_NODE(OptimizedRNNStackNode, network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
break;
}
case PrimitiveOpType::ReconcileDynamicAxis:
{
ASSIGN_NEW_NODE(ReconcileDynamicAxisNode, network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::LogSoftmax:
{
// This could also be implemented as x => x - ReduceLogSum(x); here we use the dedicated LogSoftmaxNode instead.
ASSIGN_NEW_NODE(LogSoftmaxNode, network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::Pass:
ASSIGN_NEW_NODE(PassNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LabelsToGraph:
ASSIGN_NEW_NODE(LabelsToGraphNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::StopGradient:
ASSIGN_NEW_NODE(StopGradientNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Assign:
ASSIGN_NEW_NODE(AssignNode, network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Crop:
if (functionInputs.size() == 2)
{
if (functionConfig.Contains(PrimitiveFunctionAttribute::AttributeNameOffset))
{
// Crop with given offsets.
const auto& offsets = AsVector<size_t>(functionConfig[PrimitiveFunctionAttribute::AttributeNameOffset].Value<std::vector<DictionaryValue>>());
if (offsets.size() != 2)
{
CNTK::LogicError("Vector of crop offsets must have size 2.");
}
ASSIGN_NEW_NODE(CropNode, offsets[0], offsets[1], network->GetDeviceId(), internalNodeName);
}
else
{
// Crop with two inputs and automatic offset computation.
ASSIGN_NEW_NODE(CropNode, network->GetDeviceId(), internalNodeName);
}
}
else if (functionInputs.size() == 4)
{
// Crop with four inputs and automatic offset computation.
ASSIGN_NEW_NODE(CropNode, network->GetDeviceId(), internalNodeName);
}
else
{
CNTK::LogicError("Crop node must have 2 or 4 node inputs.");
}
break;
case PrimitiveOpType::Cast:
{
DataType outputType = (DataType)functionConfig[PrimitiveFunctionAttribute::AttributeNameNewDataType].Value<int>();
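// CastNode is templated on both the destination type (from the attribute) and the source element type detected from the input; ASSIGN_NEW_NODE2 supplies the destination type as the extra template argument.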
switch (outputType)
{
case DataType::Float:
ASSIGN_NEW_NODE2(CastNode, float, network->GetDeviceId(), internalNodeName);
break;
case DataType::Double:
ASSIGN_NEW_NODE2(CastNode, double, network->GetDeviceId(), internalNodeName);
break;
case DataType::Float16:
ASSIGN_NEW_NODE2(CastNode, half, network->GetDeviceId(), internalNodeName);
break;
}
break;
}
case PrimitiveOpType::CustomProxyOp:
{
ASSIGN_NEW_NODE(CustomProxyOpNode, network->GetDeviceId(), internalNodeName);
break;
}
default:
CNTK::LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
break;
}
// Reorder inputNodesBasePtrs, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);
if (computationNodePtr->Is<INumInputs>())
{
auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
CNTK::LogicError("The Primitive Function '%S' has %d inputs while the corresponding ComputationNode expects %d inputs.",
function->AsString().c_str(),
(int)inputNodesBasePtrs.size(),
(int)computationNodeExpectedInputCount);
}
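// Nodes that consume random numbers (e.g., Dropout) carry their RNG seed and offset in the function attributes; restore that state here so random behavior is reproducible.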
if (computationNodePtr->Is<RngUser>())
{
auto seed = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunctionAttribute::AttributeNameRngOffset].Value<size_t>();
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
}
}
else
{
ASSIGN_NEW_NODE(UserDefinedV2FunctionNode, network->GetDeviceId(), internalNodeName, function->shared_from_this());
// For user-defined functions, we attach only unique inputs in the internal computation network, since the UDF
// backward implementations directly compute aggregate gradient values for unique inputs
std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
for (auto inputNodeBasePtr : inputNodesBasePtrs)
{
if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
}
inputNodesBasePtrs = uniqueInputNodesBasePtrs;
}
}
else
{
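// 'variable' is a secondary output of a multi-output function: create an OutputMultiplexerNode selecting the i-th output of the node created for the primary output (outputs[0]).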
size_t i = 1;
while (outputs[i] != variable) i++;
assert(i < outputs.size());
ASSIGN_NEW_NODE(OutputMultiplexerNode, network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes), i);
inputNodesBasePtrs = { variableToNodeMap[outputs[0]] };
}
network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
return computationNodePtr;
}