in lib/Importer/Caffe2ModelLoader.cpp [812:2265]
Error Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
ArgumentDictionaryTy dict = loadArgumentMap(op);
const std::string &typeName = op.type();
mod_.registerOriginalName(op.name());
// Check if operator is supported in parent class, CommonOperatorLoader.
bool loadCommonOperatorSuccess;
ASSIGN_VALUE_OR_RETURN_ERR(loadCommonOperatorSuccess,
tryLoadCommonOperator(typeName, op, dict));
if (loadCommonOperatorSuccess) {
return Error::success();
}
const std::string &opName = loadOperatorName(op);
if (typeName == "Gelu") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
Node *node = G_->createGelu(opName, in);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Conv" || typeName == "ConvRelu") {
return loadConv(op, dict);
}
if (typeName == "Softmax") {
return loadSoftmax(op, dict);
}
if (typeName == "PRelu") {
return loadPRelu(op, dict);
}
if (typeName == "ConvTranspose") {
return loadConvTranspose(op, dict);
}
if (typeName == "Int8Conv" || typeName == "Int8ConvRelu") {
return loadConvQuantized(op, dict);
}
if (typeName == "LayerNorm") {
return loadLayerNorm(op, dict);
}
if (typeName == "Int8SumRelu") {
RETURN_ERR_IF_NOT(op.input_size() == 2,
opErrMsg(op, "Only Sum of 2 inputs is supported."));
RETURN_ERR_IF_NOT(
dict.count("Y_zero_point"),
opErrMsg(op, "missing zero point for quantized outout type"));
RETURN_ERR_IF_NOT(
dict.count("Y_scale"),
opErrMsg(op, "missing Y_scale for quantized output type"));
NodeValue in0;
ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
NodeValue in1;
ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
auto outDims = in0.getType()->dims();
TypeRef outTy;
ASSIGN_VALUE_OR_RETURN_ERR(
outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
auto *add = G_->createAdd(opName + ".sum", outTy, in0, in1);
auto *relu = G_->createRELU(opName + ".relu", add);
RETURN_IF_ERR(addNodeAsOutput(op, relu));
return Error::success();
}
if (typeName == "Int8Relu") {
RETURN_ERR_IF_NOT(op.input_size() == 1,
opErrMsg(op, "Only one input is supported."));
RETURN_ERR_IF_NOT(
dict.count("Y_zero_point"),
opErrMsg(op, "missing zero point for quantized outout type"));
RETURN_ERR_IF_NOT(
dict.count("Y_scale"),
opErrMsg(op, "missing Y_scale for quantized output type"));
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto outDims = in.getType()->dims();
TypeRef outTy;
ASSIGN_VALUE_OR_RETURN_ERR(
outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
auto *relu = G_->createRELU(opName, in, outTy);
RETURN_IF_ERR(addNodeAsOutput(op, relu));
return Error::success();
}
if (typeName == "Int8Quantize") {
RETURN_ERR_IF_NOT(
op.input_size() == 1,
opErrMsg(op, "Glow only supports Int8Quantize with 1 input"));
RETURN_ERR_IF_NOT(
dict.count("Y_zero_point"),
opErrMsg(op, "missing zero point for quantized output type"));
RETURN_ERR_IF_NOT(
dict.count("Y_scale"),
opErrMsg(op, "missing Y_scale for quantized output type"));
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto outDims = in.getType()->dims();
TypeRef outTy;
ASSIGN_VALUE_OR_RETURN_ERR(
outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
Node *N = G_->createQuantize(opName, in, outTy);
RETURN_IF_ERR(addNodeAsOutput(op, N));
return Error::success();
}
if (typeName == "Int8Dequantize") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto *node = G_->createDequantize(opName, in, ElemKind::FloatTy);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "MaxPool" || typeName == "AveragePool" ||
typeName == "Int8MaxPool" || typeName == "Int8AveragePool") {
// Load the inputs:
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
std::vector<unsigned_t> strides;
ASSIGN_VALUE_OR_RETURN_ERR(strides, getSizeHW(dict, "stride", 1));
std::vector<unsigned_t> kernels;
ASSIGN_VALUE_OR_RETURN_ERR(kernels, getSizeHW(dict, "kernel", 0));
std::vector<unsigned_t> pads;
ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict));
bool countIncludePads;
ASSIGN_VALUE_OR_RETURN_ERR(
countIncludePads, getCountIncludePads(dict, /* defaultValue */ true));
std::string order = "NCHW";
if (dict.count("order")) {
ASSIGN_VALUE_OR_RETURN_ERR(order, loadStr(dict["order"]));
}
// We expect the input to be NHWC.
NodeValue finalIn;
if (order == "NCHW") {
finalIn = G_->createTranspose(opName, in, NCHW2NHWC)->getResult();
} else {
finalIn = in;
}
// If 'global_pooling' is set then the operation will pool over the size
// of the input by doing: kernels = {height, width}.
if (dict.count("global_pooling")) {
auto Ty = in.getType();
kernels[0] = Ty->dims()[2];
kernels[1] = Ty->dims()[3];
}
// Check the padding style.
if (dict.count("legacy_pad")) {
int mode;
ASSIGN_VALUE_OR_RETURN_ERR(mode, loadInt(dict["legacy_pad"]));
// Caffe1 (legacy) rounds up whereas Caffe2 rounds down.
// The legacy style is deprecated according to caffe2's
// caffe2_legacy.proto definition.
if (static_cast<LegacyPaddingMode>(mode) ==
LegacyPaddingMode::CAFFE_LEGACY_POOLING) {
return MAKE_ERR(opErrMsg(op,
"MaxPool nodes with legacy caffe padding are "
"deprecated and not supported."));
}
}
Node *node = nullptr;
if (typeName == "Int8MaxPool" || typeName == "Int8AveragePool") {
// Create the node with quantized type.
RETURN_ERR_IF_NOT(
dict.count("Y_zero_point"),
opErrMsg(op, "missing zero point for quantized output type"));
RETURN_ERR_IF_NOT(
dict.count("Y_scale"),
opErrMsg(op, "missing Y_scale for quantized output type"));
TypeRef finalInType = finalIn.getType();
ShapeNHWC idim = ShapeNHWC(finalInType->dims());
auto outSz =
calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
std::array<dim_t, 4> outDims = {
{idim.n, outSz.first, outSz.second, idim.c}};
if (typeName == "Int8MaxPool") {
// Int8MaxPool output quantization should be the same as the input's, so
// just ignore the given params.
node = G_->createMaxPool(opName, finalIn, kernels, strides, pads);
} else {
TypeRef outTy;
ASSIGN_VALUE_OR_RETURN_ERR(
outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
node = G_->createAvgPool(opName, finalIn, outTy, kernels, strides, pads,
NHWC, countIncludePads);
}
} else if (typeName == "MaxPool") {
node = G_->createMaxPool(opName, finalIn, kernels, strides, pads);
} else {
node = G_->createAvgPool(opName, finalIn, kernels, strides, pads, NHWC,
countIncludePads);
}
if (order == "NCHW") {
unsigned resIdx = 0;
if (llvm::isa<MaxPoolNode>(node)) {
resIdx = MaxPoolNode::ResultIdx;
} else if (llvm::isa<AvgPoolNode>(node)) {
resIdx = AvgPoolNode::ResultIdx;
} else {
return MAKE_ERR("Expected either Max or Avg Pool.");
}
// Transpose the output back.
node = G_->createTranspose(opName, node->getNthResult(resIdx), NHWC2NCHW);
}
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "SpatialBN") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
Constant *scale;
ASSIGN_VALUE_OR_RETURN_ERR(scale, getConstantByName(op.input(1)));
Constant *bias;
ASSIGN_VALUE_OR_RETURN_ERR(bias, getConstantByName(op.input(2)));
Constant *mean;
ASSIGN_VALUE_OR_RETURN_ERR(mean, getConstantByName(op.input(3)));
Constant *var;
ASSIGN_VALUE_OR_RETURN_ERR(var, getConstantByName(op.input(4)));
float epsilon = 1e-5f; // default
auto epsilonIt = dict.find("epsilon");
if (epsilonIt != dict.end()) {
ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(epsilonIt->second));
}
unsigned_t channel;
ASSIGN_VALUE_OR_RETURN_ERR(channel, getChannel(dict));
auto *node = G_->createBatchNormalization(
opName, in.getType(), in, bias, scale, mean, var, channel, epsilon);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Bucketize") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
RETURN_ERR_IF_NOT(
dict.count("boundaries"),
opErrMsg(op, "Bucketize: Expected a boundaries member vector"));
std::vector<float> boundaries;
ASSIGN_VALUE_OR_RETURN_ERR(boundaries, getFloats(dict["boundaries"]));
auto *node = G_->createBucketizeNode(opName, in, boundaries);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "ResizeNearest") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
std::string order = "NCHW";
if (dict.count("order")) {
ASSIGN_VALUE_OR_RETURN_ERR(order, loadStr(dict["order"]));
}
// We expect the input to be NHWC.
NodeValue finalIn;
if (order == "NCHW") {
finalIn = G_->createTranspose(opName, in, NCHW2NHWC)->getResult();
} else {
finalIn = in;
}
float heightScale;
ASSIGN_VALUE_OR_RETURN_ERR(heightScale, loadFloat(dict["height_scale"]));
float widthScale;
ASSIGN_VALUE_OR_RETURN_ERR(widthScale, loadFloat(dict["width_scale"]));
std::vector<float> scales;
scales.push_back(1.0f);
scales.push_back(heightScale);
scales.push_back(widthScale);
scales.push_back(1.0f);
auto *node = G_->createResizeNearest(opName, finalIn, scales);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Concat") {
const unsigned numInputs = op.input_size();
llvm::SmallVector<NodeValue, 4> inputs;
inputs.reserve(numInputs);
for (unsigned i = 0; i < numInputs; i++) {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
inputs.push_back(std::move(in));
}
// If axis exists it takes priority over channel.
unsigned_t channel;
if (dict.count("axis")) {
ASSIGN_VALUE_OR_RETURN_ERR(channel, loadInt(dict["axis"]));
} else {
ASSIGN_VALUE_OR_RETURN_ERR(channel, getChannel(dict));
}
unsigned_t addAxis = 0;
if (dict.count("add_axis")) {
ASSIGN_VALUE_OR_RETURN_ERR(addAxis, loadInt(dict["add_axis"]));
}
Node *node{nullptr};
if (addAxis) {
// When add axis is used, this means we have to add a new dimension
// before the axis, instead of merging on the axis.
std::vector<dim_t> outputDims = inputs[0].dims();
if (channel < outputDims.size()) {
unsigned i = 0;
for (const auto &input : inputs) {
RETURN_ERR_IF_NOT(
outputDims[channel] == input.dims()[channel],
opErrMsg(op,
strFormat("inputs need all to have the same dims for "
"concat with add_axis: input 0 (%s) vs "
"input %u (%s), %u vs %u, channel = %u",
op.input(0).c_str(), i, op.input(i).c_str(),
static_cast<unsigned>(outputDims[channel]),
static_cast<unsigned>(input.dims()[channel]),
channel)));
++i;
}
outputDims.insert(outputDims.begin() + channel, numInputs);
node = G_->createConcat(opName, inputs, channel);
node = G_->createReshape(opName, node, outputDims);
} else if (channel == outputDims.size()) {
// We convert inputs into 2D arrays with single columns, thus the
// number of rows will be equal to the product of all original dims.
// Every converted input will look like a vertical line of numbers.
const auto flatVerticalShape = flattenCdr(inputs[0].dims(), channel);
llvm::SmallVector<NodeValue, 4> verticalInputs;
for (auto &input : inputs) {
verticalInputs.push_back(G_->createReshape(
opName, input,
{flatVerticalShape.first, flatVerticalShape.second}));
}
// We glue together the vertical lines, so, the number of columns
// becomes equal to the number of original inputs.
node = G_->createConcat(opName, verticalInputs, 1);
// Reshape to convert to desired shape.
outputDims.push_back(numInputs);
node = G_->createReshape(opName, node, outputDims);
} else {
return MAKE_ERR(opErrMsg(
op, strFormat("Invalid input: channel (=%u) > number of dims (=%u)",
channel, static_cast<unsigned>(outputDims.size()))));
}
} else {
// In normal case (i.e. when we are not adding a new dimension)
// plain createConcat() would suffice.
node = G_->createConcat(opName, inputs, channel);
}
// If we add the axis then node is a Reshape, otherwise it should be
// Concat.
RETURN_ERR_IF_NOT(
llvm::isa<ConcatNode>(node) || llvm::isa<ReshapeNode>(node),
opErrMsg(op,
"Internal error: Node should either be a Concat or Reshape."));
NodeValue finalNode = llvm::isa<ConcatNode>(node)
? NodeValue(node, ConcatNode::ResultIdx)
: NodeValue(node, ReshapeNode::ResultIdx);
nodeValueByName_[op.output(0)] = finalNode;
// Concat may have a second output in Caffe2 (split_info), but we don't
// use it for inference
return Error::success();
}
if (typeName == "FC" || typeName == "FCTransposed" || typeName == "Int8FC" ||
typeName == "FbFCPacked") {
RETURN_ERR_IF_NOT(op.input_size() == 3,
"Glow only suports FC with 3 inputs");
// Load the inputs:
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto originalInputDims = in.getType()->dims();
size_t axis = 1;
if (dict.count("axis")) {
ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
}
// Load weights.
unsigned_t axis_w = 1;
if (dict.count("axis_w")) {
ASSIGN_VALUE_OR_RETURN_ERR(axis_w, loadInt(dict["axis_w"]));
}
NodeValue W;
if (hasConstantByName(op.input(1))) {
ASSIGN_VALUE_OR_RETURN_ERR(W, getConstantByName(op.input(1)));
} else {
ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));
}
// Caffe2 stores the transposed W matrix. Here we first coerce W to a 2D
// matrix if necessary and then transpose it back.
auto wDims = flattenCdr(W.dims(), axis_w);
if (W.dims().size() > 2) {
W = G_->createReshape(W.getNode()->getName(), W,
{wDims.first, wDims.second});
}
if (typeName == "FC" || typeName == "Int8FC" || typeName == "FbFCPacked") {
W = G_->createTranspose(W.getNode()->getName(), W, {1, 0});
}
NodeValue B;
if (hasConstantByName(op.input(2))) {
ASSIGN_VALUE_OR_RETURN_ERR(B, getConstantByName(op.input(2)));
} else {
ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(2)));
}
Node *node = nullptr;
if (typeName == "Int8FC") {
// Create a node with quantized type.
auto outputDims = flattenCdr(in.dims(), axis);
TypeRef outTy;
ASSIGN_VALUE_OR_RETURN_ERR(
outTy, loadQuantTy(opName, ElemKind::Int8QTy,
{outputDims.first, B.dims()[0]}, dict));
int dequantizeOutput = 0;
if (dict.count("dequantize_output")) {
ASSIGN_VALUE_OR_RETURN_ERR(dequantizeOutput,
loadInt(dict["dequantize_output"]));
}
if (dequantizeOutput == 1) {
node = G_->createDynamicQuantizedFullyConnected(opName, in, W, B);
} else {
node = G_->createFullyConnected(opName, in, W, B, outTy, axis);
}
} else if (typeName == "FbFCPacked") {
RETURN_ERR_IF_NOT(W.getElementType() == ElemKind::Float16Ty,
opErrMsg(op, "Expected float16 weights."));
auto fp16InputType =
mod_.uniqueType(ElemKind::Float16Ty, in.getType()->dims());
in = G_->createConvertTo(opName + ".ConvertInput", in, fp16InputType);
auto fp16BiasType = mod_.uniqueType(ElemKind::Float16Ty, B.dims());
auto *fp16Bias =
G_->createConvertTo(opName + ".ConvertBias", B, fp16BiasType);
auto outputDims = flattenCdr(in.dims(), axis);
TypeRef OT =
mod_.uniqueType(ElemKind::Float16Ty, {outputDims.first, B.dims()[0]});
auto fc = G_->createFullyConnected(opName, in, W, fp16Bias, OT, axis);
auto outputType =
mod_.uniqueType(ElemKind::FloatTy, fc->getResult().dims());
node = G_->createConvertTo(opName + ".ConvertOutput", fc, outputType);
} else {
auto outputDims = flattenCdr(in.dims(), axis);
TypeRef outputType =
mod_.uniqueType(ElemKind::FloatTy, {outputDims.first, B.dims()[0]});
node = G_->createFullyConnected(opName, in, W, B, outputType, axis);
}
// If the FC was applied along an axis other than 1, reshape the output so
// that its leading dims are the first `axis` dims of the original input.
if (axis != 1) {
llvm::SmallVector<dim_t, max_tensor_dimensions> reshapeDims;
size_t totalReshapeSize = 1;
for (size_t i = 0; i < axis; ++i) {
auto d = originalInputDims[i];
reshapeDims.push_back(d);
totalReshapeSize *= static_cast<dim_t>(d);
}
size_t finalDim = typeName == "FCTransposed" ? wDims.second : wDims.first;
reshapeDims.push_back(finalDim);
totalReshapeSize *= finalDim;
size_t totalOriginalOutputSize = node->getNthResult(0).getType()->size();
RETURN_ERR_IF_NOT(
totalReshapeSize == totalOriginalOutputSize,
opErrMsg(op, strFormat("Cannot reshape from size %lu to size %lu",
totalOriginalOutputSize, totalReshapeSize)));
node = G_->createReshape(opName + ".fc.out", node, reshapeDims);
}
// Save the outputs:
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "ChannelShuffle") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
size_t group;
ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict["group"]));
size_t kernel;
ASSIGN_VALUE_OR_RETURN_ERR(kernel, loadInt(dict["kernel"]));
Node *node = G_->createChannelShuffle(opName, in, group, kernel);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Squeeze") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
std::vector<dim_t> dims;
ASSIGN_VALUE_OR_RETURN_ERR(dims, getShape<dim_t>(dict["dims"]));
Node *node = G_->createSqueeze(opName, in, dims);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Log") {
// Load the inputs:
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
// Create the log:
auto *R = G_->createLog(opName, in);
RETURN_IF_ERR(addNodeAsOutput(op, R));
return Error::success();
}
if (typeName == "Swish") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto *S = G_->createSwish(opName, in);
RETURN_IF_ERR(addNodeAsOutput(op, S));
return Error::success();
}
if (typeName == "Logit") {
// Load the input and (optional) epsilon clamping value:
NodeValue input;
ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
auto epsIt = dict.find("eps");
// default: 1e-6 (as in Caffe2)
float eps = 1E-6f;
if (epsIt != dict.end()) {
ASSIGN_VALUE_OR_RETURN_ERR(eps, loadFloat(epsIt->second));
}
auto *node = G_->createLogit(opName, input, eps);
// Save the outputs:
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "EQ") {
NodeValue in0;
ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
NodeValue in1;
ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
auto *node = G_->createCmpEQ(opName, in0, in1);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Tile") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
unsigned_t tiles;
ASSIGN_VALUE_OR_RETURN_ERR(tiles, loadInt(dict["tiles"]));
unsigned_t axis;
ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
auto *node = G_->createTile(opName, in, tiles, axis);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Free") {
// Glow frees memory automatically.
return Error::success();
}
if (typeName == "StopGradient" || typeName == "ScaleGradient") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
// Currently Caffe2 importer only supports inference.
RETURN_IF_ERR(addNodeAsOutput(op, in));
return Error::success();
}
if (typeName == "Transpose") {
RETURN_IF_ERR(loadTranspose(op, dict, "axes"));
return Error::success();
}
if (typeName == "NCHW2NHWC") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto *node = G_->createTranspose(opName, in, NCHW2NHWC);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "CopyCPUToMKL" || typeName == "CopyMKLToCPU" ||
typeName == "Copy" || typeName == "EnsureCPUOutput" ||
typeName == "EnsureDense" || typeName == "Dropout") {
// Glow does not support any of these ops now, so implement them as
// no-ops. Note: Implement this as a no-op reshape because these ops may
// carry partition information, and we need a node to preserve the parent
// Function partition that the op specified. This reshape will get
// eliminated later on during graph optimizations.
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
ReshapeNode *RN = G_->createReshape(in.getNode()->getName(), in, in.dims());
RETURN_IF_ERR(addNodeAsOutput(op, RN));
return Error::success();
}
if (typeName == "Slice") {
NodeValue data;
ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
std::vector<ssize_t> starts;
ASSIGN_VALUE_OR_RETURN_ERR(starts, getShape<ssize_t>(dict["starts"]));
std::vector<ssize_t> ends;
ASSIGN_VALUE_OR_RETURN_ERR(ends, getShape<ssize_t>(dict["ends"]));
std::vector<dim_t> newStarts, newEnds;
RETURN_ERR_IF_NOT(
starts.size() == ends.size(),
opErrMsg(op, strFormat(
"Slice starts %lu and %lu ends must be the same size.",
starts.size(), ends.size())));
for (size_t i = 0; i < starts.size(); i++) {
ssize_t newStart = starts[i];
if (newStart == -1) {
newStart = data.dims()[i];
}
RETURN_ERR_IF_NOT(
newStart >= 0,
opErrMsg(op,
strFormat("Indices should never be negative, but found %lu ",
newStart)));
newStarts.push_back(newStart);
ssize_t newEnd = ends[i];
if (newEnd == -1) {
newEnd = data.dims()[i];
}
RETURN_ERR_IF_NOT(
newEnd >= 0,
opErrMsg(op,
strFormat("Indices should never be negative, but found %lu ",
newEnd)));
newEnds.push_back(newEnd);
}
Node *SN = G_->createSlice(opName, data, newStarts, newEnds);
RETURN_IF_ERR(addNodeAsOutput(op, SN));
return Error::success();
}
if (typeName == "Clip") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
float cmin = std::numeric_limits<float>::lowest();
if (dict.count("min")) {
ASSIGN_VALUE_OR_RETURN_ERR(cmin, loadFloat(dict.find("min")->second));
}
float cmax = std::numeric_limits<float>::max();
if (dict.count("max")) {
ASSIGN_VALUE_OR_RETURN_ERR(cmax, loadFloat(dict.find("max")->second));
}
auto *node = G_->createClip(loadOperatorName(op), in, cmin, cmax);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "MatMul") {
RETURN_IF_ERR(loadMatMul(op, dict));
return Error::success();
}
if (typeName == "Cast") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
int to;
ASSIGN_VALUE_OR_RETURN_ERR(to, loadInt(dict["to"]));
switch (to) {
case caffe2::TensorProto_DataType_FLOAT: {
RETURN_ERR_IF_NOT(in.getElementType() == ElemKind::FloatTy,
opErrMsg(op, "Can only cast float to float."));
break;
}
case caffe2::TensorProto_DataType_INT32: {
RETURN_ERR_IF_NOT(in.getElementType() == ElemKind::Int32ITy,
opErrMsg(op, "Can only cast int32 to int32."));
break;
}
case caffe2::TensorProto_DataType_INT64: {
RETURN_ERR_IF_NOT(in.getElementType() == ElemKind::Int64ITy,
opErrMsg(op, "Can only cast int64 to int64."));
break;
}
default:
return MAKE_ERR(opErrMsg(op, "Unsupported Cast type."));
}
RETURN_IF_ERR(addNodeAsOutput(op, in));
return Error::success();
}
if (typeName == "HalfToFloat") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto convertedType =
mod_.uniqueType(ElemKind::FloatTy, in.getType()->dims());
auto *R = G_->createConvertTo(opName + ".ConvertInput", in, convertedType);
RETURN_IF_ERR(addNodeAsOutput(op, R));
return Error::success();
}
if (typeName == "ScatterAssign") {
NodeValue data;
ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
NodeValue indices;
ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
NodeValue slices;
ASSIGN_VALUE_OR_RETURN_ERR(slices, getNodeValueByName(op.input(2)));
assert(indices.dims().size() == 1 && "Indices should be 1-dimensional!");
NodeValue indices2D = G_->createReshape(opName + ".indices.2d", indices,
{indices.dims()[0], 1});
Node *SAN = G_->createScatterData(opName, data, indices2D, slices);
RETURN_IF_ERR(addNodeAsOutput(op, SAN));
return Error::success();
}
if (typeName == "ConstantFill" || typeName == "GivenTensorIntFill" ||
typeName == "GivenTensorInt64Fill" || typeName == "GaussianFill" ||
typeName == "UniformFill") {
RETURN_IF_ERR(loadWeight(op));
return Error::success();
}
if (typeName == "SigmoidCrossEntropyWithLogits") {
NodeValue logits;
ASSIGN_VALUE_OR_RETURN_ERR(logits, getNodeValueByName(op.input(0)));
NodeValue targets;
ASSIGN_VALUE_OR_RETURN_ERR(targets, getNodeValueByName(op.input(1)));
Node *SCEL =
G_->createSigmoidCrossEntropyWithLogits(opName, logits, targets);
RETURN_IF_ERR(addNodeAsOutput(op, SCEL));
return Error::success();
}
if (typeName == "ElementwiseLinear") {
NodeValue X, w, b;
// If the axis argument does not exist in the protobuf, the default
// value should be 1.
unsigned axis = 1;
ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));
ASSIGN_VALUE_OR_RETURN_ERR(w, getNodeValueByName(op.input(1)));
ASSIGN_VALUE_OR_RETURN_ERR(b, getNodeValueByName(op.input(2)));
if (dict.count("axis")) {
ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
}
Node *EL = G_->createElementwiseLinear(opName, X, w, b, axis);
RETURN_IF_ERR(addNodeAsOutput(op, EL));
return Error::success();
}
if (typeName == "AveragedLoss") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
auto *node = G_->createBatchedReduceMean(opName, in, 0);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Mod") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
int64_t divisor;
ASSIGN_VALUE_OR_RETURN_ERR(divisor, loadInt(dict["divisor"]));
RETURN_ERR_IF_NOT(
divisor >= 1,
opErrMsg(op,
strFormat("Divisor must not be less than 1, but found %ld ",
divisor)));
bool signFollowDivisor = false;
if (dict.count("sign_follow_divisor")) {
ASSIGN_VALUE_OR_RETURN_ERR(signFollowDivisor,
loadInt(dict["sign_follow_divisor"]));
}
auto *node = G_->createModulo(opName, in, divisor, signFollowDivisor);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "Scale") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
float scale = 1.0;
if (dict.count("scale")) {
ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict["scale"]));
}
auto scaleType = mod_.uniqueType(ElemKind::FloatTy, {in.dims()});
auto scales = G_->createSplat(opName + ".scales", scaleType, scale);
Node *node = G_->createMul(opName, in, scales);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "SparseLengthsWeightedSum8BitsRowwise" ||
typeName == "SparseLengthsSum8BitsRowwise" ||
typeName == "SparseLengthsWeightedSumFused8BitRowwise" ||
typeName == "SparseLengthsSumFused8BitRowwise" ||
typeName == "SparseLengthsWeightedSumFused4BitRowwise" ||
typeName == "SparseLengthsSumFused4BitRowwise") {
const bool isWeighted =
typeName == "SparseLengthsWeightedSum8BitsRowwise" ||
typeName == "SparseLengthsWeightedSumFused8BitRowwise" ||
typeName == "SparseLengthsWeightedSumFused4BitRowwise";
const bool isFused =
typeName == "SparseLengthsWeightedSumFused8BitRowwise" ||
typeName == "SparseLengthsSumFused8BitRowwise" ||
typeName == "SparseLengthsWeightedSumFused4BitRowwise" ||
typeName == "SparseLengthsSumFused4BitRowwise";
const bool is4Bit =
typeName == "SparseLengthsWeightedSumFused4BitRowwise" ||
typeName == "SparseLengthsSumFused4BitRowwise";
// If weighted, then the weights are the second input and so we need to
// shift indices/lengths/scalesBiases.
size_t indicesIdx = 1;
size_t lengthsIdx = 2;
size_t scalesBiasesIdx = 3;
if (isWeighted) {
indicesIdx++;
lengthsIdx++;
scalesBiasesIdx++;
}
NodeValue data;
ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
NodeValue weights;
if (isWeighted) {
ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(1)));
}
NodeValue indices;
ASSIGN_VALUE_OR_RETURN_ERR(indices,
getNodeValueByName(op.input(indicesIdx)));
NodeValue lengths;
ASSIGN_VALUE_OR_RETURN_ERR(lengths,
getNodeValueByName(op.input(lengthsIdx)));
Storage *dataS = llvm::dyn_cast<Storage>(data);
const dim_t numRows = data.dims()[0];
// Make sure all the shapes make sense.
RETURN_ERR_IF_NOT(lengths.dims().size() == 1,
opErrMsg(op, "lengths must be a vector."));
RETURN_ERR_IF_NOT(indices.dims().size() == 1,
opErrMsg(op, "indices must be a vector."));
LengthsMode lengthsMode;
ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));
float avgLength;
ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict));
Node *node;
if (isFused) {
RETURN_IF_ERR(setFusedTy(dataS, is4Bit ? ElemKind::UInt4FusedFP16QTy
: ElemKind::UInt8FusedQTy));
// No other work to do, since the data is already loaded fused, so just
// create the new node with its inputs.
if (isWeighted) {
node = G_->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
opName, dataS, weights, indices, lengths,
/* useFP16Accumulation */ false, lengthsMode, avgLength);
} else {
node = G_->createFusedRowwiseQuantizedSparseLengthsSum(
opName, dataS, indices, lengths, /* useFP16Accumulation */ false,
lengthsMode, avgLength);
}
if (is4Bit) {
node = G_->createConvertTo(opName, node, ElemKind::FloatTy);
}
} else {
NodeValue scalesBiases;
ASSIGN_VALUE_OR_RETURN_ERR(scalesBiases,
getNodeValueByName(op.input(scalesBiasesIdx)));
Constant *scalesBiasesC = llvm::dyn_cast<Constant>(scalesBiases);
RETURN_ERR_IF_NOT(scalesBiasesC,
opErrMsg(op, "scales_biases must be Constant."));
RETURN_ERR_IF_NOT(scalesBiases.dims().size() == 2,
opErrMsg(op, "scale_bias has to be a matrix."));
RETURN_ERR_IF_NOT(
scalesBiases.dims()[0] == numRows,
opErrMsg(
op,
strFormat("scale_bias must have the same number of rows as data, "
"but found scale_bias %d and rows %d ",
int(scalesBiases.dims()[0]), int(numRows))));
RETURN_ERR_IF_NOT(
scalesBiases.dims()[1] == 2,
opErrMsg(op,
strFormat("Second dim of scale_bias has to be equal to 2 "
"but found %d ",
int(scalesBiases.dims()[1]))));
// Now strip out the scales and biases into their own tensors.
NodeValue sliceScales =
G_->createSlice(scalesBiasesC->getName().str() + "_scale",
scalesBiasesC, {0, 0}, {numRows, 1});
NodeValue sliceBiases =
G_->createSlice(scalesBiasesC->getName().str() + "_bias",
scalesBiasesC, {0, 1}, {numRows, 2});
sliceScales =
G_->createReshape(sliceScales.getNode()->getName().str() + "_1D",
sliceScales, {numRows});
sliceBiases =
G_->createReshape(sliceBiases.getNode()->getName().str() + "_1D",
sliceBiases, {numRows});
// Now create the actual node.
if (isWeighted) {
node = G_->createRowwiseQuantizedSparseLengthsWeightedSum(
opName, dataS, sliceScales, sliceBiases, weights, indices, lengths,
/* precision */ ElemKind::FloatTy,
/* useFP16Accumulation */ false, lengthsMode, avgLength);
} else {
node = G_->createRowwiseQuantizedSparseLengthsSum(
opName, dataS, sliceScales, sliceBiases, indices, lengths,
/* precision */ ElemKind::FloatTy,
/* useFP16Accumulation */ false, lengthsMode, avgLength);
}
}
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
if (typeName == "LengthsRangeFill") {
NodeValue lengths;
ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0)));
RETURN_ERR_IF_NOT(lengths.dims().size() == 1,
opErrMsg(op, "lengths must be a 1D vector."));
auto maxOutputSizeIt = dict.find("maxOutputSize");
RETURN_ERR_IF_NOT(
maxOutputSizeIt != dict.end(),
opErrMsg(op, "Require maxOutputSize when loading LengthsRangeFill."));
unsigned_t maxOutputSize;
ASSIGN_VALUE_OR_RETURN_ERR(maxOutputSize, loadInt(maxOutputSizeIt->second));
auto *LRF = G_->createLengthsRangeFill(opName, lengths, maxOutputSize);
RETURN_IF_ERR(addNodeAsOutput(op, LRF));
return Error::success();
}
// TODO: add checks for number of inputs and argument values
if (typeName == "ReduceBackSum") {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
RETURN_ERR_IF_NOT(in.dims().size() >= 2,
opErrMsg(op, "Input should be at least 2D."));
Node *node = G_->createBatchedReduceAdd(opName, in, in.dims().size() - 1);
RETURN_IF_ERR(addNodeAsOutput(op, node));
return Error::success();
}
// RMSNorm: inputs are (X: 2D, gamma: 1D, beta: 1D) with an optional "eps"
// argument; the two results of createRMSNorm are bound to the op's two
// outputs.
if (typeName == "RMSNorm") {
  // Validate the arity first: inputs 0..2 and outputs 0..1 are indexed
  // unconditionally below.
  RETURN_ERR_IF_NOT(op.input_size() == 3,
                    opErrMsg(op, "Expected 3 inputs: X, gamma and beta."));
  RETURN_ERR_IF_NOT(op.output_size() == 2,
                    opErrMsg(op, "Expected 2 outputs."));
  NodeValue X, gamma, beta;
  ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));
  RETURN_ERR_IF_NOT(X.dims().size() == 2,
                    opErrMsg(op, "X should be a 2D tensor."));
  ASSIGN_VALUE_OR_RETURN_ERR(gamma, getNodeValueByName(op.input(1)));
  RETURN_ERR_IF_NOT(gamma.dims().size() == 1,
                    opErrMsg(op, "gamma should be a 1D tensor."));
  ASSIGN_VALUE_OR_RETURN_ERR(beta, getNodeValueByName(op.input(2)));
  RETURN_ERR_IF_NOT(beta.dims().size() == 1,
                    opErrMsg(op, "beta should be a 1D tensor."));
  // Epsilon stays 0 unless the "eps" argument is given.
  float epsilon = .0f;
  if (dict.count("eps")) {
    ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(dict["eps"]));
  }
  auto nodes = G_->createRMSNorm(opName, X, gamma, beta, epsilon);
  // Results are registered by name directly, one per op output.
  nodeValueByName_[op.output(0)] = nodes[0];
  nodeValueByName_[op.output(1)] = nodes[1];
  return Error::success();
}
// Mean: element-wise mean over N same-shaped inputs. A single input is passed
// through unchanged; otherwise the inputs are concatenated on dim 0, reshaped
// to {N, shape...} and mean-reduced over the new leading dimension.
if (typeName == "Mean") {
  const unsigned inputCount = op.input_size();
  RETURN_ERR_IF_NOT(inputCount > 0,
                    opErrMsg(op, "Expect at least one input."));

  std::vector<NodeValue> operands;
  operands.reserve(inputCount);
  for (unsigned idx = 0; idx < inputCount; ++idx) {
    NodeValue operand;
    ASSIGN_VALUE_OR_RETURN_ERR(operand, getNodeValueByName(op.input(idx)));
    operands.push_back(std::move(operand));
  }

  // Every operand has to match the shape of the first one.
  const auto refDims = operands[0].dims();
  for (unsigned idx = 1; idx < inputCount; ++idx) {
    RETURN_ERR_IF_NOT(
        refDims == operands[idx].dims(),
        opErrMsg(op,
                 "All inputs should have the same shape, violating input " +
                     op.input(idx)));
  }

  // Nothing to average with a single operand.
  if (inputCount == 1) {
    RETURN_IF_ERR(addNodeAsOutput(op, operands[0]));
    return Error::success();
  }

  // Target shape after stacking: {inputCount, refDims...}.
  std::vector<dim_t> stackedDims;
  stackedDims.reserve(refDims.size() + 1);
  stackedDims.push_back(inputCount);
  stackedDims.insert(stackedDims.end(), refDims.begin(), refDims.end());

  Node *result = G_->createConcat(opName + ".concat", operands, 0);
  result = G_->createReshape(opName + ".reshape", result, stackedDims);
  result = G_->createBatchedReduceMean(opName + ".reduceMean", result, 0);
  RETURN_IF_ERR(addNodeAsOutput(op, result));
  return Error::success();
}
// Negative: handled entirely by loadNeg; any error it reports is propagated.
if (typeName == "Negative") {
RETURN_IF_ERR(loadNeg(op, dict));
return Error::success();
}
// LpNorm: reduces the whole input to a {1}-shaped tensor holding the sum of
// |x| (p == 1) or of x^2 (p == 2). The "average" mode is rejected.
if (typeName == "LpNorm") {
  NodeValue data;
  ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));

  // Norm order; 2 when the argument is absent.
  int order = 2;
  if (dict.count("p")) {
    ASSIGN_VALUE_OR_RETURN_ERR(order, loadInt(dict["p"]));
    RETURN_ERR_IF_NOT(order == 1 || order == 2,
                      opErrMsg(op, "p should be either 1 or 2."));
  }

  bool useAverage = false;
  if (dict.count("average")) {
    ASSIGN_VALUE_OR_RETURN_ERR(useAverage, loadInt(dict["average"]));
  }
  RETURN_ERR_IF_NOT(!useAverage, opErrMsg(op, "average is not supported."));

  // Per-element term of the norm.
  Node *acc = nullptr;
  if (order == 1) {
    acc = G_->createAbs(opName, data);
  } else {
    acc = G_->createPow(opName, data, 2);
  }

  // Flatten to one dimension, then add everything up along it.
  const auto flatDims = flattenCdr(data.dims(), data.dims().size());
  acc = G_->createReshape(opName + ".reshape1D", acc, flatDims.first);
  auto sumTy = mod_.uniqueType(data.getElementType(), {1});
  acc = G_->createBatchedReduceAdd(opName + ".sum", sumTy, acc, 0);
  RETURN_IF_ERR(addNodeAsOutput(op, acc));
  return Error::success();
}
// ArgMin: optional "axis" (0 when absent) and "keepdims" (true when absent)
// arguments are forwarded to createArgMin.
// NOTE(review): Caffe2's operator documentation appears to list -1 as the
// default axis, whereas this loader uses 0 — confirm which is intended.
if (typeName == "ArgMin") {
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
  int axis = 0;
  if (dict.count("axis")) {
    // Load the axis relative to the input rank, the same way the TopK handler
    // does, so that negative or out-of-range values are not handed to
    // createArgMin as-is.
    ASSIGN_VALUE_OR_RETURN_ERR(
        axis, loadAxis<int>(dict["axis"], input.dims().size()));
  }
  bool keepDims = true;
  if (dict.count("keepdims")) {
    ASSIGN_VALUE_OR_RETURN_ERR(keepDims, loadInt(dict.at("keepdims")));
  }
  auto node = G_->createArgMin(opName, input, axis, keepDims);
  RETURN_IF_ERR(addNodeAsOutput(op, node));
  return Error::success();
}
// Sign: built from splats, two less-than compares and two selects:
// 1 where 0 < x, -1 where x < 0, and 0 everywhere else.
if (typeName == "Sign") {
  NodeValue x;
  ASSIGN_VALUE_OR_RETURN_ERR(x, getNodeValueByName(op.input(0)));

  Node *zeroSplat = G_->createSplat(opName + ".zeroes", x.getType(), 0.f);
  Node *positiveMask = G_->createCmpLT(opName + ".isPos", zeroSplat, x);
  Node *negativeMask = G_->createCmpLT(opName + ".isNeg", x, zeroSplat);
  Node *plusOne = G_->createSplat(opName + ".posOnes", x.getType(), 1);
  Node *minusOne = G_->createSplat(opName + ".negOnes", x.getType(), -1);

  // First write +1 over the zero background, then overlay -1.
  Node *result =
      G_->createSelect(opName + ".fillPos", positiveMask, plusOne, zeroSplat);
  result = G_->createSelect(opName + ".fillNeg", negativeMask, minusOne, result);
  RETURN_IF_ERR(addNodeAsOutput(op, result));
  return Error::success();
}
// Softplus: one input mapped directly onto a SoftPlus node.
if (typeName == "Softplus") {
  NodeValue x;
  ASSIGN_VALUE_OR_RETURN_ERR(x, getNodeValueByName(op.input(0)));
  Node *softPlus = G_->createSoftPlus(opName, x);
  RETURN_IF_ERR(addNodeAsOutput(op, softPlus));
  return Error::success();
}
// TopK: k comes either from a constant Int64 second input or from the "k"
// argument. Only the last dimension is supported as the axis. Indices are
// produced as Int32.
if (typeName == "TopK") {
  NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
  RETURN_ERR_IF_NOT(
      op.input_size() <= 2,
      opErrMsg(
          op,
          strFormat(
              "TopK: Maximum number of inputs is 2, but found input size %d ",
              op.input_size())));
  unsigned_t k = 0;
  if (op.input_size() > 1) {
    // k supplied as an input: it has to be a compile-time constant.
    Constant *kConst = getConstantByNameOrNull(op.input(1));
    RETURN_ERR_IF_NOT(
        kConst,
        opErrMsg(op, "TopK: Non-constant k is not supported by Glow."));
    RETURN_ERR_IF_NOT(
        kConst->getElementType() == ElemKind::Int64ITy,
        opErrMsg(op, strFormat(
                         "TopK: k input must be of type Int64, but found "
                         "input type '%s' ",
                         kConst->getType()->getElementName().str().c_str())));
    auto constH = kConst->getPayload().getHandle<int64_t>();
    k = constH.at({0});
  } else {
    // k supplied as an argument. Check that it exists before touching it:
    // dict["k"] on a missing key would insert a default entry and hand that
    // to loadInt.
    RETURN_ERR_IF_NOT(
        dict.count("k"),
        opErrMsg(op, "TopK: k must be provided either as the second input or "
                     "as the 'k' argument."));
    ASSIGN_VALUE_OR_RETURN_ERR(k, loadInt(dict["k"]));
  }
  int lastDim = in.dims().size() - 1;
  int axis = lastDim;
  if (dict.count("axis")) {
    ASSIGN_VALUE_OR_RETURN_ERR(axis,
                               loadAxis<int>(dict["axis"], in.dims().size()));
  }
  RETURN_ERR_IF_NOT(
      axis == lastDim,
      opErrMsg(
          op,
          strFormat(
              "TopK: Currently only support axis %d being last dimension %d ",
              axis, lastDim)));
  TopKNode *R = G_->createTopK(opName, in, k, ElemKind::Int32ITy);
  RETURN_IF_ERR(addNodeAsOutput(op, R));
  return Error::success();
}
// FillExamplesWithIndicator: scatters the rows of `data` into a zero tensor
// whose leading dimension equals the indicator length, at the positions
// where the indicator is non-zero.
if (typeName == "FillExamplesWithIndicator") {
  NodeValue data;
  ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
  NodeValue indicator;
  ASSIGN_VALUE_OR_RETURN_ERR(indicator, getNodeValueByName(op.input(1)));

  // Input validation: integer, 1D indicator.
  RETURN_ERR_IF_NOT(
      indicator.getElementType() == ElemKind::Int32ITy ||
          indicator.getElementType() == ElemKind::Int64ITy,
      opErrMsg(op, "Indicator should be of int32 or int64 type."));
  RETURN_ERR_IF_NOT(indicator.dims().size() == 1,
                    opErrMsg(op, "Indicator should be 1D tensor."));

  // Work on a 2D view of data: {rows, everything-else flattened}.
  const dim_t rowSize = flattenCdr(data.dims()).second;
  const dim_t numSlots = indicator.dims()[0];
  ShapeVector resultDims{numSlots};
  resultDims.insert(resultDims.end(), data.dims().begin() + 1,
                    data.dims().end());
  auto scatterTy =
      mod_.uniqueTypeWithNewShape(data.getType(), {numSlots, rowSize});
  const dim_t numRows = data.dims()[0];
  auto dataAs2D =
      G_->createReshape(opName + ".data2D", data, {numRows, rowSize});

  // NonZero is fed a boolean tensor here, and the conversion chain used is
  // int -> float -> bool. Per the original rationale, only int32 -> fp16
  // conversions are available (fp16 clipping), hence int64 indicators are
  // narrowed to int32 first.
  if (indicator.getElementType() == ElemKind::Int64ITy) {
    indicator = G_->createConvertTo(opName + ".int64ToInt32", indicator,
                                    ElemKind::Int32ITy);
  }
  auto indicatorAsFloat = G_->createConvertTo(opName + ".intToFloat",
                                              indicator, ElemKind::FloatTy);
  auto indicatorAsBool = G_->createConvertTo(
      opName + ".floatToBool", indicatorAsFloat, ElemKind::BoolTy);
  auto nonZero = G_->createNonZero(opName + ".nonzero", indicatorAsBool);
  auto nonZeroFixed = fixNonZero(G_, mod_, opName, nonZero);

  // One target position is needed per data row.
  RETURN_ERR_IF_NOT(numRows <= nonZeroFixed->getNthResult(0).dims()[0],
                    opErrMsg(op,
                             "The number of "
                             "non-zero elements in the indicator must be at "
                             "least that of the first dimension of data"));
  auto targetRows = G_->createSlice(opName + ".indices", nonZeroFixed, {0, 0},
                                    {numRows, 1});
  auto zeroFill = G_->createSplat(opName + ".zeros", scatterTy, 0);
  auto scattered = G_->createScatterData(opName + ".scatterData", zeroFill,
                                         targetRows, dataAs2D, true);
  auto result = G_->createReshape(opName + ".result", scattered, resultDims);
  RETURN_IF_ERR(addNodeAsOutput(op, result));
  return Error::success();
}
// BatchSparseToDense: only dense_last_dim == 1 is supported. The values are
// scattered into a {len(lengths), 1} tensor pre-filled with default_value,
// at the rows where lengths is non-zero.
if (typeName == "BatchSparseToDense") {
  NodeValue lengths;
  ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0)));
  NodeValue indices;
  ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
  NodeValue values;
  ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2)));

  // Width of the dense output; 1 when the argument is absent.
  dim_t lastDim = 1;
  if (dict.count("dense_last_dim")) {
    ASSIGN_VALUE_OR_RETURN_ERR(lastDim, loadInt(dict.at("dense_last_dim")));
  }
  RETURN_ERR_IF_NOT(
      lastDim == 1,
      opErrMsg(op, "Only output second dimension = 1 supported"));

  // Input validation: element kinds and ranks.
  RETURN_ERR_IF_NOT(
      lengths.getElementType() == ElemKind::Int32ITy ||
          lengths.getElementType() == ElemKind::Int64ITy,
      opErrMsg(op, "Lengths should be of int32 or int64 type."));
  RETURN_ERR_IF_NOT(lengths.dims().size() == 1,
                    opErrMsg(op, "Lengths should be 1D tensor."));
  RETURN_ERR_IF_NOT(
      indices.getElementType() == ElemKind::Int32ITy ||
          indices.getElementType() == ElemKind::Int64ITy,
      opErrMsg(op, "Indices should be of int32 or int64 type."));
  RETURN_ERR_IF_NOT(indices.dims().size() == 1,
                    opErrMsg(op, "Indices should be 1D tensor."));
  RETURN_ERR_IF_NOT(values.getElementType() == ElemKind::FloatTy,
                    opErrMsg(op, "Values should be of float type."));
  RETURN_ERR_IF_NOT(
      indices.dims()[0] == values.dims()[0],
      opErrMsg(op, "There should be the same number of values as indices."));

  // Fill value for the rows that receive no data; 0 when absent.
  float fillValue = 0.0;
  if (dict.count("default_value")) {
    ASSIGN_VALUE_OR_RETURN_ERR(fillValue, loadFloat(dict.at("default_value")));
  }

  // NonZero is fed a boolean tensor here, and the conversion chain used is
  // int -> float -> bool. Per the original rationale, only int32 -> fp16
  // conversions are available (fp16 clipping), hence int64 lengths are
  // narrowed to int32 first.
  if (lengths.getElementType() == ElemKind::Int64ITy) {
    lengths = G_->createConvertTo(opName + ".int64ToInt32", lengths,
                                  ElemKind::Int32ITy);
  }
  auto lengthsAsFloat =
      G_->createConvertTo(opName + ".intToFloat", lengths, ElemKind::FloatTy);
  auto lengthsAsBool = G_->createConvertTo(opName + ".floatToBool",
                                           lengthsAsFloat, ElemKind::BoolTy);
  auto nonZero = G_->createNonZero(opName + ".nonzero", lengthsAsBool);
  auto nonZeroFixed = fixNonZero(G_, mod_, opName, nonZero);

  // Keep one target row per provided value.
  auto valueCount = indices.dims()[0];
  auto targetRows = G_->createSlice(opName + ".indicesSlice", nonZeroFixed,
                                    {0, 0}, {valueCount, 1});

  // Dense {len(lengths), 1} background filled with the default value.
  ShapeVector denseDims{lengths.dims()[0], 1};
  auto denseTy = mod_.uniqueTypeWithNewShape(values.getType(), denseDims);
  auto background = G_->createSplat(opName + ".data", denseTy, fillValue);
  auto valuesAs2D =
      G_->createReshape(opName + ".reshape", values, {valueCount, 1});
  auto scattered = G_->createScatterData(opName + ".scatterData", background,
                                         targetRows, valuesAs2D, true);
  RETURN_IF_ERR(addNodeAsOutput(op, scattered));
  return Error::success();
}
// SparseLabelSplit: splits (lengths, indices, values) into per-label outputs.
// Output layout: numLabels label-value tensors, then numLabels example-id
// tensors, then (only if keep_gradient_offset_map is set) the gradient
// offset map.
if (typeName == "SparseLabelSplit") {
  NodeValue lengths;
  ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0)));
  NodeValue indices;
  ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
  NodeValue values;
  ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2)));

  RETURN_ERR_IF_NOT(dict.count("num_labels"),
                    opErrMsg(op, "num_labels was not provided."));
  // Load into a signed integer first: num_labels is used as a divisor and as
  // a loop bound below, so zero or negative values must be rejected before
  // converting to the unsigned dim_t.
  int64_t numLabelsArg = 0;
  ASSIGN_VALUE_OR_RETURN_ERR(numLabelsArg, loadInt(dict.at("num_labels")));
  RETURN_ERR_IF_NOT(numLabelsArg > 0,
                    opErrMsg(op, "num_labels should be positive."));
  const dim_t numLabels = static_cast<dim_t>(numLabelsArg);

  bool keepGradientOffsetMap = false;
  if (dict.count("keep_gradient_offset_map")) {
    ASSIGN_VALUE_OR_RETURN_ERR(keepGradientOffsetMap,
                               loadInt(dict.at("keep_gradient_offset_map")));
  }

  // Every output indexed below must exist on the op.
  const dim_t numRequiredOutputs =
      2 * numLabels + (keepGradientOffsetMap ? 1 : 0);
  RETURN_ERR_IF_NOT(
      static_cast<dim_t>(op.output_size()) >= numRequiredOutputs,
      opErrMsg(op, "Not enough outputs for the requested num_labels."));

  // Validating input types and shapes
  RETURN_ERR_IF_NOT(lengths.getElementType() == ElemKind::Int32ITy,
                    opErrMsg(op, "Lengths should be of int32 type."));
  RETURN_ERR_IF_NOT(lengths.dims().size() == 1 || lengths.dims().size() == 2,
                    opErrMsg(op, "Lengths should be 1D or 2D tensor."));
  RETURN_ERR_IF_NOT(indices.getElementType() == ElemKind::Int64ITy,
                    opErrMsg(op, "Indices should be of int64 type."));
  RETURN_ERR_IF_NOT(indices.dims().size() == 1 || indices.dims().size() == 2,
                    opErrMsg(op, "Indices should be 1D or 2D tensor."));
  RETURN_ERR_IF_NOT(values.getElementType() == ElemKind::FloatTy,
                    opErrMsg(op, "Values should be of float type."));
  RETURN_ERR_IF_NOT(values.dims().size() == 1 || values.dims().size() == 2,
                    opErrMsg(op, "Values should be 1D or 2D tensor."));
  RETURN_ERR_IF_NOT(
      indices.dims() == values.dims(),
      opErrMsg(op, "Indices and values should have the same shape."));

  // Optional conversion from 2D ({N, 1}) to 1D inputs
  if (lengths.dims().size() == 2) {
    RETURN_ERR_IF_NOT(
        lengths.dims()[1] == 1,
        opErrMsg(op, "Second dimension should be 1 in lengths."));
    lengths = G_->createReshape(opName + ".lengths1D", lengths,
                                {lengths.dims()[0]});
  }
  if (indices.dims().size() == 2) {
    RETURN_ERR_IF_NOT(
        indices.dims()[1] == 1,
        opErrMsg(op, "Second dimension should be 1 in indices."));
    indices = G_->createReshape(opName + ".indices1D", indices,
                                {indices.dims()[0]});
  }
  if (values.dims().size() == 2) {
    RETURN_ERR_IF_NOT(
        values.dims()[1] == 1,
        opErrMsg(op, "Second dimension should be 1 in values."));
    values =
        G_->createReshape(opName + ".values1D", values, {values.dims()[0]});
  }

  SparseLabelSplitNode *node =
      G_->createSparseLabelSplit(opName, lengths, indices, values, numLabels);

  // Split both stacked results into numLabels slices along dim 0.
  std::vector<SliceNode *> labelValueSlices;
  G_->createSplit(opName + ".splitLabelValues",
                  node->getNthResult(SparseLabelSplitNode::LabelValuesIdx),
                  numLabels, 0, {}, labelValueSlices);
  std::vector<SliceNode *> exampleIdSlices;
  G_->createSplit(opName + ".splitExampleIds",
                  node->getNthResult(SparseLabelSplitNode::ExampleIdsIdx),
                  numLabels, 0, {}, exampleIdSlices);

  // Each slice is reshaped to a 1D tensor of numItems elements.
  const auto numItems = indices.dims()[0] / numLabels;
  std::vector<Node *> labelValues;
  labelValues.reserve(labelValueSlices.size());
  for (auto slice : labelValueSlices) {
    labelValues.push_back(
        G_->createReshape(opName + ".reshapeLabelValue", slice, {numItems}));
  }
  std::vector<Node *> exampleIds;
  exampleIds.reserve(exampleIdSlices.size());
  for (auto slice : exampleIdSlices) {
    exampleIds.push_back(
        G_->createReshape(opName + ".reshapeExamplId", slice, {numItems}));
  }

  // Bind results to the op outputs in the layout described above.
  for (dim_t i = 0; i < numLabels; ++i) {
    nodeValueByName_[op.output(i)] = labelValues[i];
  }
  for (dim_t i = 0; i < numLabels; ++i) {
    nodeValueByName_[op.output(numLabels + i)] = exampleIds[i];
  }
  if (keepGradientOffsetMap) {
    nodeValueByName_[op.output(2 * numLabels)] =
        node->getNthResult(SparseLabelSplitNode::GradientOffsetMapIdx);
  }
  return Error::success();
}
// Log1p: lowered as log(x + 1) using a splat of ones, an Add and a Log node.
if (typeName == "Log1p") {
  NodeValue x;
  ASSIGN_VALUE_OR_RETURN_ERR(x, getNodeValueByName(op.input(0)));
  Node *oneSplat = G_->createSplat(opName + ".ones", x.getType(), 1.0f);
  Node *xPlusOne = G_->createAdd(opName + ".add", x, oneSplat);
  Node *logNode = G_->createLog(opName + ".log", xPlusOne);
  RETURN_IF_ERR(addNodeAsOutput(op, logNode));
  return Error::success();
}
// ReduceBackMean: mean over the innermost dimension of a single input that
// is at least 2D. Only num_reduce_dim == 1 (also the value used when the
// argument is absent) is accepted.
if (typeName == "ReduceBackMean") {
  const unsigned inputCount = op.input_size();
  RETURN_ERR_IF_NOT(inputCount == 1,
                    opErrMsg(op, "Only single input is supported."));

  NodeValue data;
  ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
  RETURN_ERR_IF_NOT(data.dims().size() >= 2,
                    opErrMsg(op, "Input should be at least 2D."));

  int reduceDimCount = 1;
  if (dict.count("num_reduce_dim")) {
    ASSIGN_VALUE_OR_RETURN_ERR(reduceDimCount,
                               loadInt(dict["num_reduce_dim"]));
  }
  // TODO: check maybe we can support more dimensions to be reduced
  RETURN_ERR_IF_NOT(reduceDimCount == 1,
                    opErrMsg(op, "Supporting reducing only one dimension."));

  const auto innermostAxis = data.dims().size() - 1;
  Node *meanNode = G_->createBatchedReduceMean(opName, data, innermostAxis);
  RETURN_IF_ERR(addNodeAsOutput(op, meanNode));
  return Error::success();
}
return MAKE_ERR(unexpectedNodeErrorMessage(op, "Unsupported operator."));
}