in ngraph_bridge/ngraph_builder.cc [2381:2458]
// Translates a TensorFlow SplitV op into an nGraph VariadicSplit.
//
// Inputs (by TF input index):
//   0: the tensor to split.
//   1: size_splits — must be statically known. At most one entry may be -1;
//      that entry is inferred so the sizes sum to the size of the split
//      dimension.
//   2: split_dim — must be statically known and contain exactly one element;
//      a negative value is offset by the input rank.
//
// Saves one output per size_splits entry into ng_op_map under op->name().
// When size_splits has a single entry, the input itself is saved unchanged.
//
// Returns InvalidArgument if split_dim does not have exactly one element, if
// size_splits contains more than one -1 or any other negative value, or if
// the sizes cannot sum to the size of the split dimension. Errors from
// GetInputNode / GetStaticInputVector / CheckAxisDimInRange are propagated.
static Status TranslateSplitVOp(
    const Node* op, const std::vector<const Tensor*>& static_input_map,
    Builder::OpMap& ng_op_map) {
  ng::Output<ng::Node> ng_input;
  TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, ng_input));
  ng::Shape shape = ng_input.get_shape();
  int rank = shape.size();

  std::vector<int64> split_dim_vec;
  TF_RETURN_IF_ERROR(
      GetStaticInputVector(op, 2, static_input_map, &split_dim_vec));
  // There should be at least one element specified as axis and not more than
  // one, as axis is 0-D.
  if (split_dim_vec.size() != 1) {
    return errors::InvalidArgument(
        "split_dim_tensor must have "
        "exactly one element.");
  }
  TF_RETURN_IF_ERROR(CheckAxisDimInRange(split_dim_vec, rank));
  int split_dim = split_dim_vec[0] + (split_dim_vec[0] < 0 ? (int64)rank : 0);
  // Keep the dimension size signed so comparisons and subtraction against the
  // (signed) split sizes below cannot silently wrap around.
  const int64 dim_size = static_cast<int64>(shape[split_dim]);

  std::vector<int> split_lengths_vec;
  TF_RETURN_IF_ERROR(
      GetStaticInputVector(op, 1, static_input_map, &split_lengths_vec));

  // Sum the explicitly given sizes and locate the (at most one) -1 entry.
  int64 length = 0;
  int neg_one_idx = -1;
  for (size_t i = 0; i < split_lengths_vec.size(); ++i) {
    const int split_size = split_lengths_vec[i];
    if (split_size == -1) {
      if (neg_one_idx != -1) {
        return errors::InvalidArgument("size_splits can only have one -1");
      }
      neg_one_idx = static_cast<int>(i);
    } else if (split_size < 0) {
      // Only -1 has a special meaning; any other negative size is malformed.
      return errors::InvalidArgument(
          "size_splits entries must be non-negative or -1");
    } else {
      length += split_size;
    }
  }
  const bool has_one_neg = neg_one_idx != -1;

  // Size splits must sum to the dimension of value along split_dim. With a
  // -1 present, the explicit sizes must not exceed the dimension, because
  // the -1 entry absorbs the (non-negative) remainder.
  if ((!has_one_neg && length != dim_size) ||
      (has_one_neg && length > dim_size)) {
    return errors::InvalidArgument(
        "The length of size_splits must sum to the value of the dimension "
        "along split_dim");
  }
  // Resolve the -1 wherever it is, including index 0.
  if (has_one_neg) {
    split_lengths_vec[neg_one_idx] = static_cast<int>(dim_size - length);
  }

  if (split_lengths_vec.size() == 1) {
    // A single split covers the whole input; forward it unchanged.
    SaveNgOp(ng_op_map, op->name(), ng_input);
    return Status::OK();
  }

  ng::Output<ng::Node> ng_split_dim = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::i32, ng::Shape{}, split_dim);
  ng::Output<ng::Node> ng_split_length = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::i32, ng::Shape{split_lengths_vec.size()},
      split_lengths_vec);
  auto ng_split = make_shared<opset::VariadicSplit>(ng_input, ng_split_dim,
                                                    ng_split_length);
  for (size_t i = 0; i < split_lengths_vec.size(); ++i) {
    auto out = ng_split->output(i);
    Builder::SetTracingInfo(op->name(), out);
    SaveNgOp(ng_op_map, op->name(), out);
  }
  return Status::OK();
}