in Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp [2041:3134]
FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector<Variable> &inputs, const Graph *graph,
VariableToFunctionPtr &sequenceWrapperInputToFunctionPtr, const Variable& inputPlaceholder
)
{
string onnxOpName = node->OpType();
Variable inputOperand0 = (inputPlaceholder.IsInitialized() || inputs.empty()) ? inputPlaceholder : inputs[0];
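// When a placeholder is supplied (e.g. when this op is being wrapped in a block), operate on it; otherwise fall back to the first real input.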
if (onnxOpName == "LSTM")
{
const string direction = GetNamedAttributeAsString(node, "direction");
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
std::vector<float> activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector<float>());
const std::vector<string> activations = GetNamedAttributeAsStringVec(node, "activations",
std::vector<string>({"Sigmoid", "Tanh", "Tanh"}));
return CreateLSTM(node, inputs, direction, activations, activation_alpha, activation_beta, sequenceWrapperInputToFunctionPtr);
}
else if (onnxOpName == "GRU")
{
const string direction = GetNamedAttributeAsString(node, "direction");
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
std::vector<float> activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector<float>());
const std::vector<string> activations = GetNamedAttributeAsStringVec(node, "activations",
std::vector<string>({"Sigmoid", "Tanh"}));
return CreateGRU(node, inputs, direction, activations, activation_alpha, activation_beta, sequenceWrapperInputToFunctionPtr);
}
else if (onnxOpName == "RNN")
{
const string direction = GetNamedAttributeAsString(node, "direction");
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
std::vector<float> activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector<float>());
const std::vector<string> activations = GetNamedAttributeAsStringVec(node, "activations",
std::vector<string>({"Tanh"}));
return CreateRNN(node, inputs, direction, activations, activation_alpha, activation_beta, sequenceWrapperInputToFunctionPtr);
}
if (onnxOpName == "FC")
{
return CreateCNTKFCNode(ToFixedWStringFromMultiByte(node->Name()), inputs);
}
else if (onnxOpName == "Flatten")
{
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis", 1);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkFunction = Flatten(inputs[0], axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Equal")
{
FunctionPtr cntkFunction = Equal(inputs[0], inputs[1], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Greater")
{
FunctionPtr cntkFunction = Greater(inputs[0], inputs[1], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Less")
{
FunctionPtr cntkFunction = Less(inputs[0], inputs[1], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Mean")
{
FunctionPtr cntkFunction = Mean(inputs, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Clip")
{
double minValue = GetNamedAttributeAsFloat(node, "min");
double maxValue = GetNamedAttributeAsFloat(node, "max");
Constant minVariable = Constant::Scalar(CNTK::DataType::Float, minValue);
Constant maxVariable = Constant::Scalar(CNTK::DataType::Float, maxValue);
FunctionPtr cntkFunction = Clip(inputs[0], minVariable, maxVariable, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Sum")
{
FunctionPtr cntkFunction = Sum(inputs, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "HardSigmoid")
{
float alpha = GetNamedAttributeAsFloat(node, "alpha");
float beta = GetNamedAttributeAsFloat(node, "beta");
FunctionPtr cntkFunction = HardSigmoid(inputs[0], alpha, beta, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "LRN")
{
// Guard against an even 'size'. The size > channel count case is checked on the CNTK side.
size_t size = static_cast<size_t>(GetNamedAttributeAsInt64(node, "size"));
// ONNX gives the extent to sum over along the channel axis as a diameter, while CNTK uses a radius,
// so an even diameter cannot be represented. Lotus currently also throws an error for an even diameter.
if (size % 2 != 1)
LogicError("LRN does not support an even diameter for the sum over the channel axis.");
size_t depthRadius = (size - 1)/2;
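// Example: an ONNX 'size' (diameter) of 5 maps to a CNTK depthRadius of (5 - 1) / 2 = 2.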
double bias = static_cast<double>(GetNamedAttributeAsFloat(node, "bias", 1.0f));
double alpha = static_cast<double>(GetNamedAttributeAsFloat(node, "alpha", 1e-4f));
double beta = static_cast<double>(GetNamedAttributeAsFloat(node, "beta", 0.75f));
FunctionPtr cntkFunction = LocalResponseNormalization(inputOperand0,
depthRadius, bias, alpha, beta, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "AveragePool" || onnxOpName == "MaxPool")
{
NDShape poolingWindowShape = GetNamedAttributeAsShape(node, "kernel_shape", false);
auto dim = poolingWindowShape.Rank();
NDShape strides = GetNamedAttributeAsShape(node, "strides", false, NDShape(std::vector<size_t>(poolingWindowShape.Rank(), 1u)));
bool includePad = GetNamedAttributeAsInt64(node, "count_include_pad", 0) != 0;
bool hasAutoPad = HasNamedAttribute(node, "auto_pad") && GetNamedAttributeAsString(node, "auto_pad", "SAME_UPPER") != "NOTSET";
bool hasPads = HasNamedAttribute(node, "pads");
bool ceilOutDim = false;
if (strides.Rank() != dim)
LogicError("Length of attribute 'strides' should be equal to dimensionality of the kernel.");
if (hasAutoPad && hasPads)
{
LogicError("Ambiguous Conv node specification. Both %s and %s attributes are specified. Only one of the two should be specified.",
"auto_pad", "pads");
}
strides = strides.AppendShape({ 1 }); // Because CNTK Pooling API takes strides for channel axis also.
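// e.g. a 2-D stride of {2, 2} becomes {2, 2, 1}, so the channel axis itself is not strided.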
std::vector<bool> cntkPoolingAutoPadding;
std::pair<std::vector<size_t>, std::vector<size_t>> padsPair;
FunctionPtr cntkFunction;
if (hasAutoPad)
{
ConvAutoPadType auto_pad = ConvertStrToConvAutoPadType(GetNamedAttributeAsString(node, "auto_pad", "SAME_UPPER"));
switch (auto_pad)
{
case ConvAutoPadType::SAME_LOWER:
case ConvAutoPadType::SAME_UPPER:
{
const bool isSameUpper = auto_pad == ConvAutoPadType::SAME_UPPER;
const NDShape& inputWithBatchAxisShape = inputs[0].Shape();
padsPair = CalcPaddingForSameLowerOrUpperAutoPad(inputWithBatchAxisShape, poolingWindowShape, strides, /*isSameUpper=*/isSameUpper);
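// SAME_UPPER/SAME_LOWER padding makes the output spatial size ceil(inputSize / stride);
// SAME_UPPER puts the extra padding at the end of an axis, SAME_LOWER at the beginning.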
cntkFunction = Pooling(inputOperand0, onnxOpName == "AveragePool" ? PoolingType::Average : PoolingType::Max,
poolingWindowShape, strides, padsPair.first, padsPair.second, ceilOutDim, includePad, ToFixedWStringFromMultiByte(node->Name()));
break;
}
case ConvAutoPadType::VALID:
{
cntkPoolingAutoPadding.insert(cntkPoolingAutoPadding.begin(), dim + 1, false);
cntkFunction = Pooling(inputOperand0, onnxOpName == "AveragePool" ? PoolingType::Average : PoolingType::Max,
poolingWindowShape, strides, cntkPoolingAutoPadding, ceilOutDim, includePad, ToFixedWStringFromMultiByte(node->Name()));
break;
}
}
}
else // Either 'pads' was explicitly specified, or it was omitted and we fall back to the default of all-zero padding.
{
// If 'pads' is specified, we pad the input and then do 'valid' pooling.
std::vector<int64_t> pads = GetNamedAttributeAsInt64Vec(node, "pads", std::vector<int64_t>(2*dim, 0));
auto padsPair = SplitAndReverseVec(pads);
cntkFunction = Pooling(inputOperand0, onnxOpName == "AveragePool" ? PoolingType::Average : PoolingType::Max,
poolingWindowShape, strides, padsPair.first, padsPair.second, ceilOutDim, includePad, ToFixedWStringFromMultiByte(node->Name()));
}
return cntkFunction;
}
else if (onnxOpName == "GlobalAveragePool" || onnxOpName == "GlobalMaxPool")
{
NDShape strides = {1};
std::vector<bool> autoPadding = {false};
bool ceilOutDim = false;
bool includePad = false;
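// Passing NDShape::Unknown() as the pooling window makes the window span the full spatial extent of the input, i.e. global pooling.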
FunctionPtr cntkFunction = Pooling(inputOperand0,
onnxOpName == "GlobalAveragePool" ? PoolingType::Average : PoolingType::Max,
NDShape::Unknown(), strides, autoPadding, ceilOutDim, includePad, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "MaxRoiPool")
{
// The ONNX spec declares 'pooled_shape' as a list of ints; however, the current IR spec uses AttrType::AttributeProto_AttributeType_FLOATS.
std::vector<int64_t> pooled_shape = GetNamedAttributeAsInt64Vec(node, "pooled_shape");
std::vector<size_t> dims = VecInt64ToVecSize_t(pooled_shape);
NDShape roiOutputShape(dims);
float spatialScale = GetNamedAttributeAsFloat(node, "spatial_scale");
FunctionPtr cntkFunction = ROIPooling(inputs[0], inputs[1],
PoolingType::Max, roiOutputShape, spatialScale, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Conv")
{
return CreateCNTKConvNode(node, inputs);
}
else if (onnxOpName == "ConvTranspose")
{
return CreateCNTKConvTransposeNode(node, inputs);
}
else if (onnxOpName == "BatchNormalization" || onnxOpName == "SpatialBN")
{
auto operandPlaceholder = PlaceholderVariable(inputs[0].Shape(), L"operand", {});
const Variable &operand = ToBatch(operandPlaceholder);
const Variable &scale = PlaceholderVariable(inputs[1].Shape(), inputs[1].Name(), {});
const Variable &bias = PlaceholderVariable(inputs[2].Shape(), inputs[2].Name(), {});
const Variable &runningMean = PlaceholderVariable(inputs[3].Shape(), inputs[3].Name(), {});
const Variable &runningInvStd = PlaceholderVariable(inputs[4].Shape(), inputs[4].Name(), {});
const Variable &runningCount = Constant::Scalar(0.0F);
bool spatial = onnxOpName == "SpatialBN" || GetNamedAttributeAsInt64(node, "spatial", 1) != 0;
double normalizationTimeConstant = 0.0;
float momentum = GetNamedAttributeAsFloat(node, "momentum", 0.9f);
if ((momentum > (1.0f - std::numeric_limits<float>::epsilon())) &&
(momentum < (1.0f + std::numeric_limits<float>::epsilon())))
normalizationTimeConstant = INFINITY;
else if (momentum > 0.0f)
normalizationTimeConstant = -48.0f / log1p(momentum - 1.0f);
else
normalizationTimeConstant = 0.0;
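// Worked example for the formula above: momentum = 0.9 gives -48 / ln(0.9) ≈ 455.6 samples;
// momentum == 1 maps to an infinite time constant (pure running average).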
// TODO: avoid hardcoded values
double blendTimeConstant = 0;
double epsilon = static_cast<double>(GetNamedAttributeAsFloat(node, "epsilon", 0.00001f));
bool useCuDNNEngine = true;
if ((epsilon < (0.00001f - std::numeric_limits<float>::epsilon())))
{
// REVIEW SPTIWARI: We leave some buffer when comparing with 1e-5 in the "if" condition above,
// because 1e-5 is a common value for epsilon (the ONNX default) and we do not want the model
// to run slowly for this common case because of floating point differences. But for anything
// clearly lower than 1e-5, we will not use cuDNN's batch normalization engine, because it floors
// epsilon at 1e-5, and that can produce wrong numbers. For the special case when epsilon happens
// to fall within the [1e-5 - std::numeric_limits<float>::epsilon(), 1e-5) range, the cuDNN engine will be
// used, but it will print a warning that it is flooring epsilon to 1e-5.
fprintf(stderr, "Epsilon = %0.7f, which is < 1e-5. The cuDNN engine cannot be used for BatchNormalization; this could be slow.", epsilon);
useCuDNNEngine = false;
}
bool disableRegularization = false;
FunctionPtr cntkFunctionWithBatchAxis = BatchNormalization(operand,
scale,
bias,
runningMean,
runningInvStd,
runningCount,
spatial,
normalizationTimeConstant,
blendTimeConstant,
epsilon,
useCuDNNEngine,
disableRegularization,
ToFixedWStringFromMultiByte(node->Name()));
FunctionPtr cntkFunctionWithStaticAxis = UnpackBatch(cntkFunctionWithBatchAxis, ToFixedWStringFromMultiByte(node->Name()));
vector<Variable> operands{ operandPlaceholder, scale, bias, runningMean, runningInvStd };
vector<pair<Variable, Variable>> argsMap{ pair<Variable, Variable>{operands[0], inputs[0]} };
for (int i = 1; i < 5; ++i)
{
// TODO: this does not work if mean/var inputs are not constant/parameters.
argsMap.push_back(pair<Variable, Variable>{ operands[i], inputs[0].GetDataType() == DataType::Float16 ? Utils::ConvertVariableType<float16, float>(inputs[i], true) : inputs[i]});
}
return AsBlock(std::move(cntkFunctionWithStaticAxis), argsMap,
cntkFunctionWithBatchAxis->OpName(), ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "Gemm")
{
float alpha = GetNamedAttributeAsFloat(node, "alpha", 1.0f);
float beta = GetNamedAttributeAsFloat(node, "beta", 1.0f);
bool transA = GetNamedAttributeAsInt64(node, "transA", 0) != 0;
bool transB = GetNamedAttributeAsInt64(node, "transB", 0) != 0;
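// ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() is an optional transpose controlled by transA/transB.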
// We need to swap inputs[0] and inputs[1]: CNTK's C++ tensor layout (column vs. row major) is the opposite of ONNX/Python's, so A*B is computed as B*A with the transpose flags swapped accordingly.
FunctionPtr cntkFunction = ::CNTK::Internal::Gemm(inputs[1], inputs[0], inputs[2], alpha, beta, transB, transA);
return cntkFunction;
}
else if (onnxOpName == "Dropout")
{
const Variable &operand = inputs[0];
double dropoutRate = GetNamedAttributeAsFloat(node, "ratio");
unsigned long seed = SentinelValueForAutoSelectRandomSeed;
FunctionPtr cntkFunction = Dropout(operand, dropoutRate, seed, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "RandomUniform")
{
const NDShape &shape = GetNamedAttributeAsShape(node, "shape", false);
TensorProto_DataType onnxDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(
node, "dtype", TensorProto_DataType::TensorProto_DataType_FLOAT));
CNTK::DataType dataType = ConvertDataTypeTensorProtoToCNTK(onnxDataType);
double low = GetNamedAttributeAsFloat(node, "low");
double high = GetNamedAttributeAsFloat(node, "high");
unsigned long seed = GetNamedAttributeAsInt64(node, "seed");
FunctionPtr cntkFunction = UniformRandom(shape, dataType, low, high, seed, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "RandomNormal")
{
const NDShape &shape = GetNamedAttributeAsShape(node, "shape", false);
TensorProto_DataType onnxDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(
node, "dtype", TensorProto_DataType::TensorProto_DataType_FLOAT));
CNTK::DataType dataType = ConvertDataTypeTensorProtoToCNTK(onnxDataType);
double mean = GetNamedAttributeAsFloat(node, "mean");
double scale = GetNamedAttributeAsFloat(node, "scale");
unsigned long seed = GetNamedAttributeAsInt64(node, "seed");
FunctionPtr cntkFunction = NormalRandom(shape, dataType, mean, scale, seed, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "RandomUniformLike")
{
const Variable &operand = inputs[0];
double low = GetNamedAttributeAsFloat(node, "low");
double high = GetNamedAttributeAsFloat(node, "high");
unsigned long seed = GetNamedAttributeAsInt64(node, "seed");
FunctionPtr cntkFunction = UniformRandomLike(operand, low, high, seed, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "RandomNormalLike")
{
const Variable &operand = inputs[0];
double mean = GetNamedAttributeAsFloat(node, "mean");
double scale = GetNamedAttributeAsFloat(node, "scale");
unsigned long seed = GetNamedAttributeAsInt64(node, "seed");
FunctionPtr cntkFunction = NormalRandomLike(operand, mean, scale, seed, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Add")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = Plus(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Sub")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = Minus(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Mul")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = ElementTimes(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Div")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = ElementDivide(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "And")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = ElementAnd(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Not")
{
FunctionPtr cntkFunction = ElementNot(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Or")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = ElementOr(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Xor")
{
Variable input0, input1;
std::tie<Variable, Variable>(input0, input1) = BroadcastElementWiseInput(node, inputs[0], inputs[1]);
FunctionPtr cntkFunction = ElementXor(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Neg")
{
FunctionPtr cntkFunction = Negate(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Abs")
{
FunctionPtr cntkFunction = Abs(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Reciprocal")
{
FunctionPtr cntkFunction = Reciprocal(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Floor")
{
FunctionPtr cntkFunction = Floor(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Ceil")
{
FunctionPtr cntkFunction = Ceil(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Sqrt")
{
FunctionPtr cntkFunction = Sqrt(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Relu")
{
FunctionPtr cntkFunction = ReLU(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "LeakyRelu")
{
double alpha = static_cast<double>(GetNamedAttributeAsFloat(node, "alpha", 0.01F));
FunctionPtr cntkFunction = LeakyReLU(inputs[0], alpha, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Selu")
{
double alpha = static_cast<double>(GetNamedAttributeAsFloat(node, "alpha", 1.6732F));
double gamma = static_cast<double>(GetNamedAttributeAsFloat(node, "gamma", 1.0507F));
FunctionPtr cntkFunction = SELU(inputs[0], gamma, alpha, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Elu")
{
double alpha = static_cast<double>(GetNamedAttributeAsFloat(node, "alpha", 1.0f));
FunctionPtr cntkFunction = ELU(inputs[0], alpha, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Exp")
{
FunctionPtr cntkFunction = Exp(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Log")
{
FunctionPtr cntkFunction = Log(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Tanh")
{
FunctionPtr cntkFunction = Tanh(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Pow")
{
FunctionPtr cntkFunction = Pow(inputs[0], inputs[1], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "MatMul")
{
// If an input carries both batch and sequence axes as static (free) dimensions, they need to be converted
// to dynamic axes for MatMul to work.
auto input0 = inputs[0];
auto input1 = inputs[1];
auto HasBatchAndSequenceAxes = [](Variable input) {
return input.Shape().Rank() >= 2 &&
input.Shape()[input.Shape().Rank() - 1] == NDShape::FreeDimension &&
input.Shape()[input.Shape().Rank() - 2] == NDShape::FreeDimension; };
auto HasFreeDimensionAt0Axes = [](Variable input) {
return input.Shape().Rank() >= 1 &&
input.Shape()[input.Shape().Rank() - 1] == NDShape::FreeDimension; };
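// NDShape stores dimensions in reverse order w.r.t. ONNX, so a FreeDimension in the last (or last two)
// NDShape position(s) corresponds to ONNX's leading batch (and sequence) axes.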
bool input0HasBatchAndSequenceAxes = HasBatchAndSequenceAxes(inputs[0]);
bool input1HasBatchAndSequenceAxes = HasBatchAndSequenceAxes(inputs[1]);
bool input0HasFreeDimensionAt0Axes = HasFreeDimensionAt0Axes(inputs[0]);
bool input1HasFreeDimensionAt0Axes = HasFreeDimensionAt0Axes(inputs[1]);
if (input0HasBatchAndSequenceAxes || input1HasBatchAndSequenceAxes)
{
if (input0HasBatchAndSequenceAxes)
input0 = ToBatchAndSequence(inputs[0], sequenceWrapperInputToFunctionPtr);
if (input1HasBatchAndSequenceAxes)
input1 = ToBatchAndSequence(inputs[1], sequenceWrapperInputToFunctionPtr);
FunctionPtr cntkFunction = ::CNTK::Internal::MatMul(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
cntkFunction = UnpackBatchAndSequence(cntkFunction);
return cntkFunction;
}
else if (input0HasFreeDimensionAt0Axes || input1HasFreeDimensionAt0Axes)
{
if (input0HasFreeDimensionAt0Axes)
input0 = ToBatch(inputs[0], L"");
if (input1HasFreeDimensionAt0Axes)
input1 = ToBatch(inputs[1], L"");
FunctionPtr cntkFunction = ::CNTK::Internal::MatMul(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
cntkFunction = UnpackBatch(cntkFunction, L"");
return cntkFunction;
}
else
{
FunctionPtr cntkFunction = ::CNTK::Internal::MatMul(input0, input1, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
}
else if (onnxOpName == "PRelu")
{
FunctionPtr cntkFunction = PReLU(inputs[1], inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Sigmoid")
{
FunctionPtr cntkFunction = Sigmoid(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Max")
{
if (inputs.size() > 1)
{
FunctionPtr cntkFunction = ElementMax(inputs[0], inputs[1], ToFixedWStringFromMultiByte(node->Name()));
for (int i = 2; i < inputs.size(); i++) {
cntkFunction = ElementMax(cntkFunction, inputs[i], ToFixedWStringFromMultiByte(node->Name() + "_" + std::to_string(i)));
}
return cntkFunction;
}
else
{
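// A single-input ONNX Max is effectively an identity; taking ElementMax of the input with itself preserves the values.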
return ElementMax(inputs[0], inputs[0], ToFixedWStringFromMultiByte(node->Name()));
}
}
else if (onnxOpName == "Min")
{
if (inputs.size() > 1)
{
FunctionPtr cntkFunction = ElementMin(inputs[0], inputs[1], ToFixedWStringFromMultiByte(node->Name()));
for (int i = 2; i < inputs.size(); i++) {
cntkFunction = ElementMin(cntkFunction, inputs[i], ToFixedWStringFromMultiByte(node->Name() + "_" + std::to_string(i)));
}
return cntkFunction;
}
else
{
return ElementMin(inputs[0], inputs[0], ToFixedWStringFromMultiByte(node->Name()));
}
}
else if (onnxOpName == "Sum")
{
// not specified in Operators.cpp
return nullptr;
}
else if (onnxOpName == "Softmax" || onnxOpName == "LogSoftmax" || onnxOpName == "Hardmax")
{
int64_t onnxAxis = GetNamedAttributeAsInt64(node, "axis", 1);
if (onnxAxis == static_cast<int>(inputs[0].Shape().Rank() + inputs[0].DynamicAxes().size() - 1))
{
// in case of the last axis, ONNX and CNTK are equivalent
if (onnxOpName == "Softmax")
{
return Softmax(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "LogSoftmax")
{
return LogSoftmax(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "Hardmax")
{
return Hardmax(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
}
}
auto inputOperand0Placeholder = PlaceholderVariable(inputs[0].Shape(), inputs[0].GetDataType(), L"operand", {});
Axis axis(ConvertONNXAxisToCNTKCppApi(GetNamedAttributeAsInt64(node, "axis", 1), inputOperand0Placeholder));
Variable input = Flatten(inputOperand0Placeholder, axis);
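// For a non-trailing axis, flatten the tensor at that axis, apply the op on the flattened view, and restore the original shape with the Reshape below.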
FunctionPtr cntkFunction;
if (onnxOpName == "Softmax")
{
cntkFunction = Softmax(input, ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "LogSoftmax")
{
cntkFunction = LogSoftmax(input, ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "Hardmax")
{
cntkFunction = Hardmax(input, ToFixedWStringFromMultiByte(node->Name()));
}
NDShape originalShape = inputOperand0Placeholder.Shape();
assert(originalShape.Rank() > 0);
// If the original shape has a free dimension (batch axis), we need the reshape node to infer that dimension for us.
if (originalShape[originalShape.Rank() - 1] == NDShape::FreeDimension)
originalShape[originalShape.Rank() - 1] = NDShape::InferredDimension;
cntkFunction = Reshape(cntkFunction, originalShape);
auto additionalProperties = Dictionary();
additionalProperties[L"axis"] = axis;
return AsBlock(std::move(cntkFunction), {{inputOperand0Placeholder, inputs[0]}}, std::move(additionalProperties),
ToFixedWStringFromMultiByte(onnxOpName) + L"_onnx", ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "Softplus")
{
FunctionPtr cntkFunction = Softplus(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Softsign")
{
FunctionPtr cntkFunction = Softsign(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceMax")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceMax(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceMin")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceMin(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceSum")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceSum(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceMean")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceMean(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceProd")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceProd(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceLogSumExp" || onnxOpName == "ReduceLogSum")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceLogSum(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceL1")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceL1(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceL2")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceL2(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ReduceSumSquare")
{
bool keepdims;
std::vector<Axis> axes;
std::tie<std::vector<Axis>, bool>(axes, keepdims) = GetReduceElementsAttributes(node, inputs[0]);
FunctionPtr cntkFunction = ReduceSumSquare(inputs[0], axes, keepdims, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ArgMax")
{
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis");
// Map the ONNX axis index onto the corresponding CNTK static axis.
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkfunction = Argmax(inputs[0], axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkfunction;
}
else if (onnxOpName == "ArgMin")
{
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis");
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkFunction = Argmin(inputs[0], axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Reshape")
{
if (!inputs[0].DynamicAxes().empty())
NOT_IMPLEMENTED;
std::vector<int64_t> newShape = GetShapeFromInput(node->InputDefs()[1], graph);
std::vector<int64_t> inputShape = CastVector<size_t, int64_t>(inputs[0].Shape().Dimensions());
std::reverse(inputShape.begin(), inputShape.end());
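// Reversing aligns the CNTK (right-to-left) dimension order with the left-aligned ONNX ordering used by 'newShape'.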
int inferredDimIndex = -1;
int totalInputSizeExceptFreeDim = 1, totalReshapeSizeExceptFreeAndInferredDim = 1;
// Process free and inferred dimensions. ONNX dimensions are left-aligned, typically starting with the sequence and batch axes, followed by the static axes.
// NDShape dimension order is reversed w.r.t. ONNX.
for (int index = 0; index < std::max(newShape.size(), inputShape.size()); index++)
{
if (index < inputShape.size() && (index >= newShape.size() || newShape[index] != ReshapeKeepInputDim))
totalInputSizeExceptFreeDim *= inputShape[index];
if (index < newShape.size())
{
if (newShape[index] == ReshapeInferredDim)
{
if (inferredDimIndex == -1)
{
inferredDimIndex = index;
}
else
LogicError("Reshape: 'shape' contains more than one inferred dimension.");
}
else if (newShape[index] == ReshapeKeepInputDim)
{
if (index < inputShape.size())
newShape[index] = inputShape[index];
else
LogicError("Reshape: 'shape' has a 'keep_dimension' without matching input dimension.");
}
else
{
totalReshapeSizeExceptFreeAndInferredDim *= newShape[index];
}
}
}
if (inferredDimIndex != -1)
{
if (totalInputSizeExceptFreeDim % totalReshapeSizeExceptFreeAndInferredDim != 0)
LogicError("Reshape: inferred dimension cannot be calculated from input and new shape size.");
newShape[inferredDimIndex] = totalInputSizeExceptFreeDim / totalReshapeSizeExceptFreeAndInferredDim;
}
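// Example: an input of [4, 3, 2] (24 elements) reshaped with shape [-1, 6] infers the -1 as 24 / 6 = 4.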
std::reverse(newShape.begin(), newShape.end());
NDShape newNDShape(CastVector<int64_t, size_t>(newShape));
FunctionPtr cntkFunction = Reshape(inputs[0], newNDShape, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Unsqueeze")
{
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ::CNTK::Internal::Unsqueeze(inputs[0], axes, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Concat")
{
// We allow the 'axis' attribute to be optional, and not required (as
// given in Concat's ONNX spec), to be consistent with other frameworks.
// 'axis' can be enforced as a required attribute, if needed.
int64_t onnxAxis = GetNamedAttributeAsInt64(node, "axis", 0);
// Cf. ConvertAxisToOnnxBroadcastOfOp, where the axis is computed taking into consideration
// the dynamic axes of all inputs and the possibility of broadcasting.
Axis axis = ConvertONNXAxisToCNTKCppApi(onnxAxis, inputs[0]);
std::vector<Variable> fixedInputs;
if (FixConstantShapeForConstantVariableInputPair(inputs, fixedInputs))
{
FunctionPtr cntkFunction = Splice(fixedInputs, axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else
{
FunctionPtr cntkFunction = Splice(inputs, axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
}
// { L"", "Split)
else if (onnxOpName == "Slice")
{
std::vector<int64_t> starts64 = GetNamedAttributeAsInt64Vec(node, "starts");
std::vector<int64_t> ends64 = GetNamedAttributeAsInt64Vec(node, "ends");
if (starts64.size() != ends64.size())
{
LogicError("starts (of size %d) and ends (of size %d) attributes of Slice operation must be the same size.",
(int) starts64.size(), (int) ends64.size());
}
std::vector<int> starts = VecInt64ToVecInt(starts64);
std::vector<int> ends = VecInt64ToVecInt(ends64);
for (auto &e : ends)
{
// CNTK treats an endIndex of 0 as extending to (and including) the last element.
if (e == INT_MAX)
e = 0;
}
std::vector<Axis> axes;
if (HasNamedAttribute(node, "axes"))
axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
// axes is optional so provide a default
if (axes.empty())
{
for (int i = starts.size() - 1; i >= 0; i--)
{
Axis axis(i);
axes.push_back(axis);
}
}
if (axes.size() == 1 && axes[0].IsSequenceAxis())
{
FunctionPtr cntkFunction = Sequence::Slice(inputs[0], starts[0], ends[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else
{
FunctionPtr cntkFunction = Slice(inputs[0], axes, starts, ends, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
}
else if (onnxOpName == "Transpose")
{
std::vector<int64_t> permutation = GetNamedAttributeAsInt64Vec(node, "perm");
Variable input = inputs[0];
// Transpose takes a permutation over static axes only. ConvertPermutationONNXToCNTK assumes that batch and sequence axes,
// if they exist, are not involved in the transpose, i.e. the permutation is always of the form [batch_perm = 0, sequence_perm = 1, perm0, perm1, ..., permN].
// ConvertPermutationONNXToCNTK fails if the above does not hold; that is the case where unpacking the batch/sequence axes is needed.
bool needToUnpack = (permutation.size() - inputs[0].DynamicAxes().size()) < 2;
for (int i = 0; i < inputs[0].DynamicAxes().size(); i++)
{
if (permutation[i] != i)
{
needToUnpack = true;
}
}
if (needToUnpack)
{
if (inputs[0].DynamicAxes().size() == 2)
{
input = Sequence::Unpack(input, 0, L"");
input = UnpackBatch(input, L"");
}
else
input = UnpackBatch(input, L"");
}
std::vector<Axis> argsortedPermutation = ConvertPermutationONNXToCNTK(permutation, input.HasBatchAxis(), input.HasSequenceAxis());
FunctionPtr cntkFunction = Transpose(input, argsortedPermutation, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Pad")
{
std::vector<int64_t> pads = GetNamedAttributeAsInt64Vec(node, "pads");
if (inputs[0].HasBatchAxis())
{
pads.erase(pads.begin() + pads.size() / 2);
pads.erase(pads.begin());
}
if (pads.size() != 2 * inputs[0].Shape().Rank())
LogicError("Pad: Incorrect length of 'pads' attribute in Pad op. Length of 'pads' attribute should be twice the number of dimensions in input tensor.");
auto padsPair = SplitAndReverseVec(pads);
CNTK::PaddingMode cntkPaddingMode;
double cntkConstantValue = 0.0;
auto mode = GetNamedAttributeAsString(node, "mode", "constant");
std::transform(mode.begin(), mode.end(), mode.begin(), [](char v) { return (char) ::tolower(v); });
if (mode == "constant")
cntkPaddingMode = CNTK::PaddingMode::CONSTANTPAD;
else if (mode == "reflect")
cntkPaddingMode = CNTK::PaddingMode::REFLECTPAD;
else if (mode == "edge")
NOT_IMPLEMENTED
else
LogicError("Pad: Invalid 'mode' attribute value, %s, specified for Pad node.", mode.c_str());
if (cntkPaddingMode == CNTK::PaddingMode::CONSTANTPAD)
cntkConstantValue = static_cast<double>(GetNamedAttributeAsFloat(node, "value", 0.0));
FunctionPtr cntkPadFunction = Pad(inputs[0],
cntkPaddingMode,
padsPair.first,
padsPair.second,
cntkConstantValue,
ToFixedWStringFromMultiByte(node->Name()));
return cntkPadFunction;
}
else if (onnxOpName == "Gather")
{
FunctionPtr indices = [&](DataType referenceDataType, DataType indicesDataType) -> FunctionPtr {
if (referenceDataType == indicesDataType)
return inputs[1];
return Cast(inputs[1], referenceDataType, inputs[1].Name() + L"_cast");
}(inputs[0].GetDataType(), inputs[1].GetDataType());
if (HasNamedAttribute(node, "axis"))
{
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis", 0);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkFunction = GatherOp(indices, inputs[0], axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else
{
FunctionPtr cntkFunction = GatherOp(indices, inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
}
else if (onnxOpName == "DepthToSpace")
{
auto blockSize = GetNamedAttributeAsInt64(node, "blocksize", 1);
return DepthToSpace(inputOperand0, static_cast<size_t>(blockSize), ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "SpaceToDepth")
{
auto blockSize = GetNamedAttributeAsInt64(node, "blocksize", 1);
return SpaceToDepth(inputOperand0, static_cast<size_t>(blockSize), ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "Squeeze")
{
std::vector<Axis> axes = GetNamedAttributeAsAxes(node, "axes");
return Squeeze(inputs[0], axes, ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "ImageScaler")
{
float scale = GetNamedAttributeAsFloat(node, "scale", 1);
std::vector<float> bias = GetNamedAttributeAsFloatVec(node, "bias", std::vector<float>());
return ImageScaler(inputOperand0, scale, bias, ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "MeanVarianceNormalization")
{
// REVIEW: The ONNX MeanVarianceNormalization spec does not have an 'epsilon' attribute,
// but the corresponding CNTK node does. We construct the CNTK node with the default value of epsilon
// when loading the ONNX MeanVarianceNormalization node into CNTK.
std::vector<int64_t> axes = GetNamedAttributeAsInt64Vec(node, "axes");
auto rank = inputOperand0.Shape().Rank();
bool acrossChannels = true;
bool supported = true;
for (size_t i = 0; i < axes.size(); ++i)
{
if (i == 1 && axes[i] == 2) acrossChannels = false;
if (static_cast<int64_t>(i) != (!acrossChannels ? axes[i] - 1 : axes[i]))
{
supported = false;
break;
}
}
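// Example: for a rank-3 input, axes [0, 1, 2, 3] selects normalization across all axes (acrossChannels == true),
// while [0, 2, 3] selects per-channel normalization (acrossChannels == false).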
if (!(axes.size() == rank || axes.size() == rank + 1) || !supported)
LogicError("MeanVarianceNormalization: cntk supports only computing mean/variance over all tensor, or over channel axis. Other axes combinations are not supported");
return MeanVarianceNormalization(inputOperand0, acrossChannels, /*normalizeVariance=*/ true, ToFixedWStringFromMultiByte(node->Name()));
}
else if (onnxOpName == "Identity")
{
FunctionPtr cntkFunction = Alias(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Sin")
{
FunctionPtr cntkFunction = Sin(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Asin")
{
FunctionPtr cntkFunction = Asin(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Cos")
{
FunctionPtr cntkFunction = Cos(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Acos")
{
FunctionPtr cntkFunction = Acos(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Cast")
{
TensorProto_DataType newDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(node, "to"));
if (newDataType != TensorProto_DataType::TensorProto_DataType_FLOAT &&
newDataType != TensorProto_DataType::TensorProto_DataType_DOUBLE &&
newDataType != TensorProto_DataType::TensorProto_DataType_FLOAT16)
{
// For casts to types not supported by CNTK, we simply pass the input through. CNTK data types
// are more adaptive: for example, an ONNX Gather op requires int64_t or int indices, whereas
// a CNTK node accepts float16, float, or double inputs even though ONNX would not.
return Alias(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
}
DataType cntkNewDataType = ConvertDataTypeTensorProtoToCNTK(newDataType);
FunctionPtr cntkFunction = Cast(inputs[0], cntkNewDataType, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Tan")
{
FunctionPtr cntkFunction = Tan(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Atan")
{
FunctionPtr cntkFunction = Atan(inputs[0], ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "TopK")
{
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis", (size_t)-1);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
auto k = GetNamedAttributeAsInt64(node, "k", 1);
FunctionPtr cntkFunction = TopK(inputs[0], k, axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "EyeLike")
{
// Only limited import support is provided.
FunctionPtr cntkFunction = EyeLike(inputs[0], false, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ConstantOfShape")
{
LogicError("Importing ONNX (ConstantOfShape) is not yet supported in CNTK");
return nullptr;
}
else if (onnxOpName == "Crop")
{
// inputShape: [W, H, C] x [N]
const NDShape& inputShape = inputOperand0.Shape();
if (inputShape.Rank() != 3)
RuntimeError("Crop input tensor must have shape [N,C,H,W]. ");
std::vector<int64_t> border = GetNamedAttributeAsInt64Vec(node, "border");
if (border.size() != 4)
RuntimeError("Crop: attribute 'border' must be a 1-D list of 4 values: (leftBorder, topBorder, rightBorder, bottomBorder).");
const size_t leftBorder = border[0];
const size_t topBorder = border[1];
const size_t rightBorder = border[2];
const size_t bottomBorder = border[3];
NDShape targetShape = [&](){
const size_t channelSize = inputShape[inputShape.Rank() - 1];
if (HasNamedAttribute(node, "scale"))
{
// targetShape: [W, H]
NDShape targetShape = GetNamedAttributeAsShape(node, "scale", false);
if (targetShape.Rank() != 2)
RuntimeError("Crop: attribute 'scale' must be a 1-D list of 2 values: (height, width).");
// targetShape: [W, H, C]
targetShape = targetShape.AppendShape({ channelSize });
return targetShape;
}
else
{
assert((inputShape[0] > (leftBorder + rightBorder)) && (inputShape[1] > (topBorder + bottomBorder)));
size_t targetWidth = inputShape[0] - leftBorder - rightBorder;
size_t targetHeight = inputShape[1] - topBorder - bottomBorder;
return NDShape({ targetWidth, targetHeight, channelSize });
}
}();
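// Example: a [32, 32, 3] input with border (2, 2, 2, 2) and no 'scale' attribute yields a [28, 28, 3] target shape.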
auto referent = Constant(targetShape, inputOperand0.GetDataType(), 0.0);
FunctionPtr cntkFunction = Crop(inputOperand0, referent, leftBorder, topBorder, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "OneHotEncoder")
{
// TODO: this only works in this specific case.
std::vector<int64_t> cats = GetNamedAttributeAsInt64Vec(node, "cats_int64s");
int numClass = cats.size();
Axis axis = ConvertONNXAxisToCNTKCppApi(2, inputs[0]);
FunctionPtr cntkFunction = OneHotOp(inputs[0], numClass, false, axis);
return cntkFunction;
}
else
{
LogicError("ONNX (%s) is not supported in CNTK", onnxOpName.c_str());
return nullptr;
}
}