in tensorflow/tensorflow/core/grappler/optimizers/gemm_optimizer.cc [283:595]
// Fuses groups of identical binary ops that all consume slices of the same
// Unpack node. Each group is rewritten as:
//   Pack(lhs slices) + Pack(rhs slices) -> one batched binary op -> Unpack,
// so N per-slice ops collapse into a single batched op.
// MatMul/BatchMatMul/BatchMatMulV2 groups become one BatchMatMulV2,
// IndicatorMatMul becomes ParallelIndicatorMatMul, and any other fusible
// binary op keeps its type (the packed leading dimension is expected to
// broadcast through it -- TODO confirm for every op in GetBinaryOps()).
// Returns true iff the graph was modified.
bool FuseBinaryOpsAfterUnpack(Graph* graph) {
  // Monotonic counter used to build unique names for the nodes we create.
  // NOTE(review): not thread-safe; assumes the optimizer runs serially.
  static int count = 0;
  VLOG(2) << "FuseBinaryOpsAfterUnpack";
  bool changed = false;
  // Snapshot the node list up front: the rewrite below adds and removes
  // nodes, which would invalidate iteration over graph->nodes().
  std::vector<Node*> nodes(graph->num_nodes());
  int i = 0;
  for (Node* node : graph->nodes()) {
    nodes[i++] = node;
  }
  const std::unordered_set<string> binary_op_set = GetBinaryOps();
  for (Node* node : nodes) {
    // Nodes removed by an earlier rewrite leave stale pointers behind.
    if (!graph->IsValidNode(node).ok()) continue;
    if (node->type_string() != "Unpack") continue;
    // Collect the Unpack's consumers; they are fusible only if they all have
    // the same binary op type and agree on their non-Unpack inputs.
    std::vector<Node*> binary_ops;
    std::map<string, std::vector<Node*>> binary_ops_m;
    // Inputs at index >= 2 that must be shared verbatim by every consumer.
    std::vector<std::pair<Node*, int>> common_extra_input;
    const Node* another_node = nullptr;  // shared "other side" operand
    bool can_fuse = true;
    string binary_type;
    for (Node* out : node->out_nodes()) {
      string out_type = out->type_string();
      if (binary_type.empty()) {
        // The first consumer fixes the op type the rest must match.
        binary_type = out_type;
        if (binary_op_set.find(binary_type) == binary_op_set.end()) {
          can_fuse = false;
          break;
        }
      }
      if (out_type != binary_type) {
        can_fuse = false;
        break;
      }
      if (out->num_inputs() < 2) {
        can_fuse = false;
        break;
      }
      bool has_same_another_input = true;
      bool another_input_from_unpack_or_const = true;
      bool has_common_extra_input = true;
      for (int in_idx = 0; in_idx < out->num_inputs(); ++in_idx) {
        const Edge* edge = nullptr;
        // NOTE(review): the returned Status is ignored, matching the rest of
        // this pass; a lookup failure would leave `edge` null.
        out->input_edge(in_idx, &edge);
        Node* n = edge->src();
        if (in_idx < 2) {
          if (n != node) {
            if (another_node == nullptr) {
              // The non-Unpack operand must come from an Unpack or a Const.
              string type = n->type_string();
              if (type == "Unpack" || type == "Const") {
                another_node = n;
              } else {
                another_input_from_unpack_or_const = false;
                break;
              }
            } else {
              if (n != another_node) {
                has_same_another_input = false;
              }
            }
          }
        } else {
          // Record extra inputs from the first consumer seen, then require
          // later consumers to feed the exact same (node, output) pairs.
          size_t vec_idx = in_idx - 2;
          if (vec_idx >= common_extra_input.size()) {
            common_extra_input.emplace_back(
                std::make_pair(n, edge->src_output()));
          } else {
            if (n != common_extra_input[vec_idx].first ||
                edge->src_output() != common_extra_input[vec_idx].second) {
              has_common_extra_input = false;
            }
          }
        }
      }
      if (has_same_another_input && another_input_from_unpack_or_const &&
          has_common_extra_input)
        binary_ops.push_back(out);
    }
    if (!can_fuse || binary_ops.size() < 2) continue;
    // Group the candidates by their non-Unpack operand: by source node name
    // when the operand comes from an Unpack, or by the constant's shape when
    // it is a Const (equal-shaped constants can be stacked together).
    for (Node* b : binary_ops) {
      for (int in_idx = 0; in_idx < std::min(b->num_inputs(), 2); ++in_idx) {
        const Node* n = nullptr;
        b->input_node(in_idx, &n);
        if (n != node) {
          string key;
          if (another_node->type_string() == "Unpack") {
            key = n->name();
          } else {
            TensorShapeProto s = n->def().attr().at("value").
                tensor().tensor_shape();
            for (int d = 0; d < s.dim_size(); d++) {
              key += std::to_string(s.dim(d).size()) + ",";
            }
          }
          if (binary_ops_m.find(key) != binary_ops_m.end()) {
            binary_ops_m[key].push_back(b);
          } else {
            std::vector<Node*> ins;
            ins.push_back(b);
            binary_ops_m[key] = ins;
          }
          break;
        }
      }
    }
    VLOG(2) << "FuseBinaryOpsAfterUnpack: found pattern";
    // Fuse binary_ops in each group.
    // For each group, do:
    // (1) add two Pack nodes to stack inputs on both sides respectively,
    // (2) add a new node to replace old binary_ops, and
    // (3) add a Unpack node to split result.
    DataType dtype = node->output_type(0);
    std::map<string, std::vector<Node*>>::iterator iter;
    iter = binary_ops_m.begin();
    while (iter != binary_ops_m.end()) {
      std::vector<Node*>* binary_ops_group = &(iter->second);
      // BUG FIX: a group of fewer than two ops is not worth fusing. The
      // original code only advanced the iterator here without `continue`,
      // so it still fused the singleton group and then incremented the
      // iterator a second time at the loop's end, skipping the next group.
      if (binary_ops_group->size() < 2) {
        ++iter;
        continue;
      }
      // Order the ops by which Unpack slice they consume, so slice k of the
      // new Unpack corresponds to the k-th original op.
      std::sort(binary_ops_group->begin(), binary_ops_group->end(),
                [](Node* a, Node* b) {
                  const Edge* e = nullptr;
                  a->input_edge(0, &e);
                  int a_src_output = e->src_output();
                  b->input_edge(0, &e);
                  int b_src_output = e->src_output();
                  return a_src_output < b_src_output;
                });
      // inputs[0]/inputs[1] gather the lhs/rhs edges of every op in order.
      std::vector<const Edge*> inputs[2];
      VLOG(2) << "the following ops are fused: ";
      for (Node* b : *binary_ops_group) {
        for (int side = 0; side < 2; ++side) {
          VLOG(2) << b->name();
          const Edge* e = nullptr;
          b->input_edge(side, &e);
          inputs[e->dst_input()].push_back(e);
        }
      }
      // Add two Pack nodes to group on two sides, respectively.
      Node* packs[2];
      string pack_names[2];
      string prefix = "GemmOptimizer/FuseBinaryOpsAfterUnpack/" +
          std::to_string(count++);
      pack_names[0] = prefix + "/Pack_0";
      pack_names[1] = prefix + "/Pack_1";
      Status status;
      for (int side = 0; side < 2; side++) {
        std::vector<NodeDefBuilder::NodeOut> pack_inputs;
        for (const Edge* e : inputs[side]) {
          string s = e->src()->name();
          pack_inputs.emplace_back(s, e->src_output(), dtype);
        }
        NodeDefBuilder pack_builder(pack_names[side], "Pack");
        pack_builder.Input(pack_inputs);
        NodeDef pack_node;
        status =
            pack_builder
                .Attr("N", static_cast<int>(inputs[side].size()))
                .Attr("T", dtype)
                .Attr("axis", 0)
                .Finalize(&pack_node);
        if (!status.ok()) {
          LOG(ERROR) << "Pack node construction failed with" << status;
          return false;
        }
        pack_node.set_device(inputs[side][0]->src()->def().device());
        packs[side] = graph->AddNode(pack_node, &status);
        if (!status.ok()) {
          LOG(ERROR) << "Adding node failed " << status;
          return false;
        }
        packs[side]->set_assigned_device_name(
            inputs[side][0]->src()->assigned_device_name());
        // Graph::AddNode does not create edges from the NodeDef input list,
        // so connect each stacked operand explicitly.
        for (unsigned int j = 0; j < inputs[side].size(); j++) {
          graph->AddEdge(inputs[side][j]->src(),
                         inputs[side][j]->src_output(), packs[side], j);
        }
      }
      // Add the replacement batched op (e.g. a new BatchMatMulV2).
      std::vector<NodeDefBuilder::NodeOut> binary_op_inputs;
      binary_op_inputs.emplace_back(pack_names[0], 0, dtype);
      binary_op_inputs.emplace_back(pack_names[1], 0, dtype);
      if (!common_extra_input.empty()) {
        for (size_t k = 0; k < common_extra_input.size(); ++k) {
          const Node* extra_node = common_extra_input[k].first;
          int extra_src_idx = common_extra_input[k].second;
          binary_op_inputs.emplace_back(extra_node->name(), extra_src_idx,
                                        extra_node->output_type(extra_src_idx));
        }
      }
      string binary_op_name = prefix;
      string type = (*binary_ops_group)[0]->type_string();
      string new_type;
      if (type == "MatMul" ||
          type == "BatchMatMul" ||
          type == "BatchMatMulV2") {
        binary_op_name += "/BatchMatMulV2";
        new_type = "BatchMatMulV2";
      } else if (type == "IndicatorMatMul") {
        binary_op_name += "/ParallelIndicatorMatMul";
        new_type = "ParallelIndicatorMatMul";
      } else {
        binary_op_name += "/" + type;
        new_type = type;
      }
      NodeDefBuilder binary_op_builder(binary_op_name, new_type);
      for (size_t k = 0; k < binary_op_inputs.size(); ++k) {
        binary_op_builder.Input(binary_op_inputs[k]);
      }
      NodeDef binary_op_node;
      // Carry the transpose/adjoint attributes over from the original ops;
      // they are guaranteed identical within a group only by construction
      // of the pattern match above -- TODO confirm mixed-attr groups cannot
      // reach this point.
      bool transpose_a = false;
      bool transpose_b = false;
      if (type == "MatMul") {
        transpose_a =
            (*binary_ops_group)[0]->def().attr().at("transpose_a").b();
        transpose_b =
            (*binary_ops_group)[0]->def().attr().at("transpose_b").b();
      } else if (type == "BatchMatMul" || type == "BatchMatMulV2" ||
                 type == "IndicatorMatMul") {
        transpose_a = (*binary_ops_group)[0]->def().attr().at("adj_x").b();
        transpose_b = (*binary_ops_group)[0]->def().attr().at("adj_y").b();
      }
      if (type == "MatMul" ||
          type == "BatchMatMul" ||
          type == "BatchMatMulV2") {
        status =
            binary_op_builder
                .Attr("adj_x", transpose_a)
                .Attr("adj_y", transpose_b)
                .Attr("T", dtype)
                .Finalize(&binary_op_node);
      } else if (type == "IndicatorMatMul") {
        status =
            binary_op_builder
                .Attr("adj_x", transpose_a)
                .Attr("adj_y", transpose_b)
                .Attr("parallel_num",
                      static_cast<int>((*binary_ops_group).size()))
                .Attr("T", dtype)
                .Finalize(&binary_op_node);
      } else {
        status =
            binary_op_builder
                .Attr("T", dtype)
                .Finalize(&binary_op_node);
      }
      if (!status.ok()) {
        LOG(ERROR) << "BatchMatMulV2 node construction failed with" << status;
        return false;
      }
      binary_op_node.set_device((*binary_ops_group)[0]->def().device());
      Node* binary_op = graph->AddNode(binary_op_node, &status);
      if (!status.ok()) {
        LOG(ERROR) << "Adding node failed " << status;
        return false;
      }
      binary_op->set_assigned_device_name((*binary_ops_group)[0]->
          assigned_device_name());
      graph->AddEdge(packs[0], 0, binary_op, 0);
      graph->AddEdge(packs[1], 0, binary_op, 1);
      // BUG FIX: also wire the shared extra inputs. Graph::AddNode does not
      // create edges from the NodeDef's input list (the Pack wiring above
      // relies on the same fact), so without these edges inputs 2..N of the
      // fused op were left unconnected.
      for (size_t k = 0; k < common_extra_input.size(); ++k) {
        graph->AddEdge(common_extra_input[k].first,
                       common_extra_input[k].second, binary_op,
                       static_cast<int>(k) + 2);
      }
      // Add an Unpack node to split the batched result back into slices.
      string unpack_name = prefix + "/Unpack";
      NodeDefBuilder::NodeOut unpack_input(binary_op_name, 0, dtype);
      NodeDefBuilder unpack_builder(unpack_name, "Unpack");
      unpack_builder.Input(unpack_input);
      NodeDef unpack_node;
      status =
          unpack_builder
              .Attr("num", static_cast<int>((*binary_ops_group).size()))
              .Attr("T", dtype)
              .Attr("axis", 0)
              .Finalize(&unpack_node);
      if (!status.ok()) {
        LOG(ERROR) << "Unpack node construction failed with" << status;
        return false;
      }
      unpack_node.set_device((*binary_ops_group)[0]->def().device());
      Node* unpack = graph->AddNode(unpack_node, &status);
      if (!status.ok()) {
        LOG(ERROR) << "Adding node failed " << status;
        return false;
      }
      unpack->set_assigned_device_name((*binary_ops_group)[0]->
          assigned_device_name());
      graph->AddEdge(binary_op, 0, unpack, 0);
      // Add edges to forward split results to nodes after original
      // binary_ops, and remove original binary_ops.
      int index = 0;
      for (Node* b : *binary_ops_group) {
        // Copy the fan-out first: UpdateEdge/RemoveNode mutate b's edge set.
        std::vector<Node*> dst_nodes;
        std::vector<int> dst_inputs;
        for (const Edge* e : b->out_edges()) {
          dst_nodes.push_back(e->dst());
          dst_inputs.push_back(e->dst_input());
        }
        for (unsigned int k = 0; k < dst_nodes.size(); k++) {
          graph->UpdateEdge(unpack, index, dst_nodes[k], dst_inputs[k]);
        }
        graph->RemoveNode(b);
        index++;
      }
      changed = true;
      ++iter;
    }
  }
  return changed;
}