void Symbol::Compose()

in nnvm/src/core/symbolic.cc [275:465]


void Symbol::Compose(const array_view<const Symbol*>& args,
                     const std::unordered_map<std::string, const Symbol*>& kwargs,
                     const std::string& name) {
  static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
  static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
  static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");

  // The arguments that contain graphs.
  Node* n = outputs[0].node.get();
  FInputGraph fng = fgraph.get(n->op(), nullptr);
  std::vector<uint32_t> garg_idx;
  if (fng != nullptr) garg_idx = fng(n->attrs);

  // The names of the arguments that contain graphs.
  FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
  auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
  std::vector<std::string> garg_names(garg_idx.size());
  for (size_t i = 0; i < garg_idx.size(); i++) {
    size_t idx = garg_idx[i];
    if (idx < arg_names.size()) garg_names[i] = arg_names[idx];
  }

  // parameter check.
  for (size_t i = 0; i < args.size(); ++i) {
    // If the argument isn't a graph, it should have only one output.
    if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
      CHECK_EQ(args[i]->outputs.size(), 1U)
          << "Argument " << i << " is a tuple, single value is required";
  }
  for (const auto& kv : kwargs) {
    if (garg_names.empty() ||
        std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
      CHECK_EQ(kv.second->outputs.size(), 1U)
          << "Keyword Argument " << kv.first << " is a tuple, single value is required";
  }
  // assign new name
  if (!name.empty()) outputs[0].node->attrs.name = name;

  // Atomic functor composition.
  if (IsAtomic(outputs)) {
    uint32_t n_req = n->num_inputs();
    std::vector<const Symbol*> arg_vec(args.begin(), args.end());
    std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
    // If one of the input arguments is a graph, we need to remove it from the
    // list.
    if (fng != nullptr) {
      std::vector<uint32_t> idxes = fng(n->attrs);
      for (auto idx : idxes) {
        const Symbol* sym;
        if (idx < arg_vec.size()) {
          sym = arg_vec[idx];
        } else {
          auto it = kwarg_map.find(arg_names[idx]);
          CHECK(it != kwarg_map.end());
          sym = it->second;
          kwarg_map.erase(it);
        }
        if (n_req != kVarg) n_req--;
        n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
      }
      // Because idxes does not contain duplicates, the loop below functions well.
      // Note that it is as slow as O(|idxes| * |args|),
      // but given that |idxes| is small, it is just fine
      sort(std::begin(idxes), std::end(idxes), std::greater<int>());
      for (auto idx : idxes) {
        if (idx < arg_vec.size()) {
          arg_vec.erase(arg_vec.begin() + idx);
        }
        arg_names.erase(arg_names.begin() + idx);
      }
    }

    if (n_req != kVarg) {
      n->inputs.resize(n_req);
      CHECK_LE(arg_vec.size(), n_req)
          << "Incorrect number of arguments, requires " << n_req << ", provided " << arg_vec.size();
      for (size_t i = 0; i < arg_vec.size(); ++i) {
        n->inputs[i] = arg_vec[i]->outputs[0];
      }
      // switch to keyword argument matching
      if (arg_vec.size() != n_req) {
        if (arg_names.size() != n_req) {
          LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
        }
        size_t nmatched = 0;
        for (size_t i = arg_vec.size(); i < n_req; ++i) {
          auto it = kwarg_map.find(arg_names[i]);
          if (it != kwarg_map.end() && it->first == arg_names[i]) {
            n->inputs[i] = it->second->outputs[0];
            ++nmatched;
          } else {
            n->inputs[i] = NodeEntry{CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
            // copy attribute of parent over automatically created variables
            n->inputs[i].node->attrs.dict = n->attrs.dict;
          }
        }

        if (nmatched != kwarg_map.size()) {
          n->inputs.clear();
          std::vector<std::string> keys = GetKeys(kwarg_map);
          array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
                                       dmlc::BeginPtr(arg_names) + arg_names.size());
          KeywordArgumentMismatch("Symbol.Compose", keys, view);
        }
      }
    } else {
      CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
      n->inputs.reserve(arg_vec.size());
      for (const Symbol* s : arg_vec) {
        n->inputs.push_back(s->outputs[0]);
      }
    }
    UpdateNodeVersion(n);

    FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr);
    if (fn != nullptr) {
      for (size_t i = 0; i < n->inputs.size(); ++i) {
        if (n->inputs[i].node->is_variable()) {
          fn(n->attrs, n->inputs[i].node, i);
        }
      }
    }
  } else {
    // general composition
    CHECK_EQ(args.size(), 0U) << "General composition only support kwargs for now";
    size_t nmatched = 0;
    size_t arg_counter = 0;
    std::unordered_map<Node*, const NodeEntry*> replace_map;
    // replace map stores the existing replacement plan for arguments node
    auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs,
                             &replace_map](const ObjectPtr& node) {
      if (node->is_variable()) {
        if (arg_counter < args.size()) {
          replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
          ++arg_counter;
        } else {
          // match kwargs
          auto kit = kwargs.find(node->attrs.name);
          if (kit != kwargs.end()) {
            replace_map[node.get()] = &(kit->second->outputs[0]);
            ++nmatched;
          }
        }
      }
    };
    DFSVisit(this->outputs, find_replace_map);

    if (nmatched == kwargs.size() && arg_counter <= args.size()) {
      std::vector<Node*> update_nodes;
      std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
      auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes](const ObjectPtr& node) {
        // visit all the childs, find possible replacement
        bool repl = false;
        for (size_t i = 0; i < node->inputs.size(); ++i) {
          NodeEntry* e = &(node->inputs[i]);
          if (e->node->is_variable()) {
            auto iter = replace_map.find(e->node.get());
            if (iter != replace_map.end()) {
              replace_plan.push_back(std::make_pair(e, iter->second));
              repl = true;
            }
          }
        }
        if (repl) update_nodes.push_back(node.get());
      };
      DFSVisit(this->outputs, find_replace_plan);

      for (const auto& kv : replace_plan) {
        *(kv.first) = *(kv.second);
      }
      for (Node* n : update_nodes) {
        UpdateNodeVersion(n);
      }
    } else {
      std::vector<std::string> keys = GetKeys(kwargs);
      std::vector<std::string> arg_names = ListInputNames(kAll);
      array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
                                   dmlc::BeginPtr(arg_names) + arg_names.size());
      KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
    }

    // update outputs in case the composed variable is part of outputs.
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (outputs[i].node->is_variable()) {
        CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
        const auto it = kwargs.find(outputs[i].node->attrs.name);
        if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
      }
    }
  }
}