in nnvm/src/core/symbolic.cc [275:465]
void Symbol::Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");
// The arguments that contain graphs.
Node* n = outputs[0].node.get();
FInputGraph fng = fgraph.get(n->op(), nullptr);
std::vector<uint32_t> garg_idx;
if (fng != nullptr) garg_idx = fng(n->attrs);
// The names of the arguments that contain graphs.
FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
std::vector<std::string> garg_names(garg_idx.size());
for (size_t i = 0; i < garg_idx.size(); i++) {
size_t idx = garg_idx[i];
if (idx < arg_names.size()) garg_names[i] = arg_names[idx];
}
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
// If the argument isn't a graph, it should have only one output.
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
if (garg_names.empty() ||
std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
if (!name.empty()) outputs[0].node->attrs.name = name;
// Atomic functor composition.
if (IsAtomic(outputs)) {
uint32_t n_req = n->num_inputs();
std::vector<const Symbol*> arg_vec(args.begin(), args.end());
std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
// If one of the input arguments is a graph, we need to remove it from the
// list.
if (fng != nullptr) {
std::vector<uint32_t> idxes = fng(n->attrs);
for (auto idx : idxes) {
const Symbol* sym;
if (idx < arg_vec.size()) {
sym = arg_vec[idx];
} else {
auto it = kwarg_map.find(arg_names[idx]);
CHECK(it != kwarg_map.end());
sym = it->second;
kwarg_map.erase(it);
}
if (n_req != kVarg) n_req--;
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
}
// Because idxes does not contain duplicates, the loop below functions well.
// Note that it is as slow as O(|idxes| * |args|),
// but given that |idxes| is small, it is just fine
sort(std::begin(idxes), std::end(idxes), std::greater<int>());
for (auto idx : idxes) {
if (idx < arg_vec.size()) {
arg_vec.erase(arg_vec.begin() + idx);
}
arg_names.erase(arg_names.begin() + idx);
}
}
if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(arg_vec.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req << ", provided " << arg_vec.size();
for (size_t i = 0; i < arg_vec.size(); ++i) {
n->inputs[i] = arg_vec[i]->outputs[0];
}
// switch to keyword argument matching
if (arg_vec.size() != n_req) {
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = arg_vec.size(); i < n_req; ++i) {
auto it = kwarg_map.find(arg_names[i]);
if (it != kwarg_map.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
n->inputs[i] = NodeEntry{CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
// copy attribute of parent over automatically created variables
n->inputs[i].node->attrs.dict = n->attrs.dict;
}
}
if (nmatched != kwarg_map.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwarg_map);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(arg_vec.size());
for (const Symbol* s : arg_vec) {
n->inputs.push_back(s->outputs[0]);
}
}
UpdateNodeVersion(n);
FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr);
if (fn != nullptr) {
for (size_t i = 0; i < n->inputs.size(); ++i) {
if (n->inputs[i].node->is_variable()) {
fn(n->attrs, n->inputs[i].node, i);
}
}
}
} else {
// general composition
CHECK_EQ(args.size(), 0U) << "General composition only support kwargs for now";
size_t nmatched = 0;
size_t arg_counter = 0;
std::unordered_map<Node*, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs,
&replace_map](const ObjectPtr& node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
++arg_counter;
} else {
// match kwargs
auto kit = kwargs.find(node->attrs.name);
if (kit != kwargs.end()) {
replace_map[node.get()] = &(kit->second->outputs[0]);
++nmatched;
}
}
}
};
DFSVisit(this->outputs, find_replace_map);
if (nmatched == kwargs.size() && arg_counter <= args.size()) {
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes](const ObjectPtr& node) {
// visit all the childs, find possible replacement
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) {
NodeEntry* e = &(node->inputs[i]);
if (e->node->is_variable()) {
auto iter = replace_map.find(e->node.get());
if (iter != replace_map.end()) {
replace_plan.push_back(std::make_pair(e, iter->second));
repl = true;
}
}
}
if (repl) update_nodes.push_back(node.get());
};
DFSVisit(this->outputs, find_replace_plan);
for (const auto& kv : replace_plan) {
*(kv.first) = *(kv.second);
}
for (Node* n : update_nodes) {
UpdateNodeVersion(n);
}
} else {
std::vector<std::string> keys = GetKeys(kwargs);
std::vector<std::string> arg_names = ListInputNames(kAll);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
}
// update outputs in case the composed variable is part of outputs.
for (size_t i = 0; i < outputs.size(); ++i) {
if (outputs[i].node->is_variable()) {
CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
const auto it = kwargs.find(outputs[i].node->attrs.name);
if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
}
}
}
}