std::vector<NDArray*> Imperative::Backward()

in src/imperative/imperative.cc [445:757]


std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
                                           const std::vector<NDArray*>& ograds,
                                           const std::vector<NDArray*>& variables,
                                           bool is_train,
                                           bool retain_graph,
                                           bool create_graph) {
  using namespace nnvm;
  using namespace imperative;
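  // Ops the gradient pass may use to materialize all-zero gradients, and the
  // identity op used further below to wrap gradient outputs that are plain
  // variables.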
  static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
  static const Op* copy_op = Op::Get("_copy");

  // Construct forward graph
  Graph graph;
  graph.outputs.reserve(outputs.size());
  for (const auto& i : outputs) {
    CHECK(!AGInfo::IsNone(*i))
        << "Cannot differentiate node because it is not in a computational graph. "
        << "You need to set is_recording to true or use autograd.record() to save "
        << "computational graphs for backward. If you want to differentiate the same "
        << "graph twice, you need to pass retain_graph=True to backward.";
    graph.outputs.emplace_back(i->autograd_entry_);
  }
  size_t num_forward_outputs = graph.outputs.size();

  // Prepare head gradients
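  // For every forward output create a placeholder "_head_grad_i" node. If the
  // caller supplied an output gradient it becomes that node's value; otherwise a
  // new array filled with ones is used as the head gradient.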
  std::vector<NodeEntry> ograd_entries;
  ograd_entries.reserve(ograds.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    nnvm::ObjectPtr np = Node::Create();
    np->attrs.name     = "_head_grad_" + std::to_string(i);
    ograd_entries.emplace_back(NodeEntry{np, 0, 0});
    AGInfo& info = AGInfo::Create(ograd_entries.back().node);
    info.ctx     = outputs[i]->ctx();
    if (ograds[i] != nullptr) {
      info.outputs.emplace_back(*ograds[i]);
    } else {
      info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(), true, outputs[i]->dtype());
      if (info.outputs.back().shape().Size() != 0) {
        info.outputs.back() = static_cast<real_t>(1.0);
      }
    }
  }

  // Get gradient graph
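  // Select the entries to differentiate with respect to: either the explicitly
  // requested variables, or every recorded read-only input whose grad_req is not
  // kNullOp.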
  Symbol sym;
  sym.outputs = graph.outputs;
  std::vector<NodeEntry> xs;
  std::vector<NDArray*> x_grads;
  std::vector<OpReqType> x_reqs;
  if (variables.size()) {
    xs.reserve(variables.size());
    x_grads.reserve(variables.size());
    x_reqs.reserve(variables.size());
    for (size_t i = 0; i < variables.size(); ++i) {
      CHECK(!AGInfo::IsNone(*variables[i]) &&
            AGInfo::IsVariable(variables[i]->autograd_entry_.node))
          << "Cannot differentiate with respect to the " << i + 1 << "-th variable"
          << " because it does not require gradient.";
      xs.emplace_back(variables[i]->autograd_entry_);
      x_grads.push_back(new NDArray());
      x_reqs.push_back(kWriteTo);
    }
  } else {
    std::vector<ObjectPtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
    xs.reserve(args.size());
    x_grads.reserve(args.size());
    x_reqs.reserve(args.size());
    for (const auto& i : args) {
      AGInfo& info = AGInfo::Get(i);
      if (info.grad_req == kNullOp)
        continue;
      xs.emplace_back(NodeEntry{i, 0, 0});
      x_grads.push_back(&info.out_grads[0]);
      x_reqs.push_back(info.grad_req);
      info.fresh_out_grad = true;
    }
    CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients.";
  }
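  // Also collect intermediate (non-leaf) variables so the gradient pass can return
  // gradient entries for them as well.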
  std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
  std::vector<NodeEntry> us;
  us.reserve(nleaf_vars.size());
  for (const auto& i : nleaf_vars) {
    us.emplace_back(NodeEntry{i, 0, 0});
  }

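  // Build the backward graph: gradients of the forward outputs with respect to xs,
  // seeded by ograd_entries, with multiple contributions to the same entry combined
  // by AggregateGradient.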
  Graph g_graph = pass::MXGradient(graph,
                                   graph.outputs,
                                   xs,
                                   ograd_entries,
                                   mxnet::AggregateGradient,
                                   nullptr,
                                   zero_ops,
                                   "_copy",
                                   ShapeVector(),
                                   DTypeVector(),
                                   us);
  CHECK_EQ(g_graph.outputs.size(), xs.size());
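  // Gradient outputs that are plain variables (no op) are wrapped in a _copy node
  // so they get a distinct entry in the combined forward/backward graph.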
  for (const auto& e : g_graph.outputs) {
    if (e.node->op() == nullptr) {
      auto node      = Node::Create();
      node->attrs.op = copy_op;
      node->inputs.push_back(e);
      graph.outputs.emplace_back(std::move(node));
    } else {
      graph.outputs.push_back(e);
    }
  }
  const auto& idx = graph.indexed_graph();
  // get number of nodes used in forward pass
  size_t num_forward_nodes   = 0;
  size_t num_forward_entries = 0;
  for (size_t i = 0; i < num_forward_outputs; ++i) {
    num_forward_nodes =
        std::max(num_forward_nodes, static_cast<size_t>(idx.outputs()[i].node_id + 1));
    num_forward_entries =
        std::max(num_forward_entries, static_cast<size_t>(idx.entry_id(idx.outputs()[i])) + 1);
  }

  // Allocate buffer
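  // buff provides backing storage for every graph entry. When create_graph is set,
  // the recorded forward outputs are copied into it so the backward pass itself can
  // be recorded; otherwise forward entries simply alias the arrays saved during the
  // forward pass.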
  std::vector<NDArray> buff(idx.num_node_entries());
  std::vector<uint32_t> ref_count(buff.size(), 0);
  std::vector<OpStatePtr> states;
  std::vector<NDArray*> arrays;
  arrays.reserve(buff.size());
  for (auto& buffered_array : buff) {
    arrays.push_back(&buffered_array);
  }
  if (create_graph) {
    states.resize(num_forward_nodes);
    nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) {
      AGInfo& info                 = AGInfo::Get(n);
      states[idx.node_id(n.get())] = info.state;
      for (uint32_t i = 0; i < info.outputs.size(); ++i) {
        CHECK(idx.exist(n.get()));
        size_t nid                = idx.node_id(n.get());
        size_t eid                = idx.entry_id(nid, i);
        buff[eid]                 = info.outputs[i];
        buff[eid].autograd_entry_ = NodeEntry{n, i, 0};
        ref_count[eid]            = 1;
      }
    });
    for (auto& ograd_entry : ograd_entries) {
      AGInfo& info = AGInfo::Get(ograd_entry.node);
      if (!idx.exist(ograd_entry.node.get()))
        continue;
      size_t eid                = idx.entry_id(ograd_entry);
      buff[eid]                 = info.outputs[0];
      buff[eid].autograd_entry_ = ograd_entry;
    }
  } else {
    states.reserve(num_forward_nodes);
    for (size_t i = 0; i < num_forward_nodes; ++i) {
      const AGInfo& info = dmlc::get<AGInfo>(idx[i].source->info);
      states.emplace_back(info.state);
      for (size_t j = 0; j < info.outputs.size(); ++j) {
        size_t eid  = idx.entry_id(i, j);
        arrays[eid] = const_cast<NDArray*>(&(info.outputs[j]));

        if (retain_graph || info.grad_req != kNullOp)
          ref_count[eid] = 1;
      }
    }
    for (auto& ograd_entry : ograd_entries) {
      if (!idx.exist(ograd_entry.node.get()))
        continue;
      AGInfo& info                      = AGInfo::Get(ograd_entry.node);
      arrays[idx.entry_id(ograd_entry)] = &info.outputs[0];
    }
  }
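  // The graph outputs appended after the forward outputs are the requested
  // gradients; point their entries at x_grads so results land where the caller
  // expects them.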
  for (size_t i = num_forward_outputs; i < graph.outputs.size(); ++i) {
    size_t eid     = idx.entry_id(graph.outputs[i]);
    arrays[eid]    = x_grads[i - num_forward_outputs];
    ref_count[eid] = 1;
  }
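  // Route gradients of the queried non-leaf variables into their out_grads slots:
  // back the entry with out_grads if it has no array yet, otherwise record the
  // existing array there.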
  const std::vector<NodeEntry>& us_grads = g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
  CHECK_EQ(us_grads.size(), us.size())
      << "Size of queried nleaf_vars and size of their gradients don't match.";
  for (size_t i = 0; i < us_grads.size(); i++) {
    size_t eid   = idx.entry_id(us_grads[i]);
    AGInfo& info = AGInfo::Get(us[i].node);
    if (arrays[eid]->dtype_ == -1) {
      arrays[eid] = &info.out_grads[0];
    } else {
      info.out_grads[0] = *arrays[eid];
    }
    ref_count[eid] = 1;
  }

  // Assign context
  auto vctx = PlaceDevice(idx);

  // Infer shape, dtype, and storage type for the newly added backward nodes and entries
  {
    std::pair<uint32_t, uint32_t> node_range, entry_range;
    node_range  = {num_forward_nodes, idx.num_nodes()};
    entry_range = {num_forward_entries, idx.num_node_entries()};

    ShapeVector shapes;
    shapes.reserve(idx.num_node_entries());
    bool contain_unknown = false;
    for (const auto& i : arrays)
      shapes.emplace_back(i->shape());
    CheckAndInferShape(&graph, std::move(shapes), false, node_range, entry_range, &contain_unknown);

    DTypeVector dtypes;
    dtypes.reserve(idx.num_node_entries());
    for (const auto& i : arrays)
      dtypes.emplace_back(i->dtype());
    CheckAndInferType(&graph, std::move(dtypes), false, node_range, entry_range);

    StorageTypeVector stypes;
    stypes.reserve(idx.num_node_entries());
    for (const auto& i : arrays)
      stypes.emplace_back(i->storage_type());
    exec::DevMaskVector dev_mask;
    dev_mask.reserve(idx.num_nodes());
    for (const auto& i : vctx)
      dev_mask.emplace_back(i.dev_mask());
    CheckAndInferStorageType(
        &graph, std::move(dev_mask), std::move(stypes), false, node_range, entry_range);
  }

  // Calculate ref count
  for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
    for (const auto& j : idx[i].inputs) {
      ++ref_count[idx.entry_id(j)];
    }
  }

  // Assign reqs
  std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
  for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) {
    if (ref_count[i] == 0)
      array_reqs[i] = kNullOp;
  }
  for (size_t i = num_forward_outputs; i < idx.outputs().size(); ++i) {
    size_t eid      = idx.entry_id(idx.outputs()[i]);
    array_reqs[eid] = x_reqs[i - num_forward_outputs];
  }
  for (size_t i = 0; i < us_grads.size(); i++) {
    size_t eid      = idx.entry_id(us_grads[i]);
    AGInfo& info    = AGInfo::Get(us[i].node);
    array_reqs[eid] = info.grad_req;
  }

  const auto& shapes         = graph.GetAttr<mxnet::ShapeVector>("shape");
  const auto& dtypes         = graph.GetAttr<DTypeVector>("dtype");
  const auto& stypes         = graph.GetAttr<StorageTypeVector>("storage_type");
  const auto& dispatch_modes = graph.GetAttr<DispatchModeVector>("dispatch_mode");

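  // Allocate any backward-pass arrays that are still empty, using the inferred
  // shape, dtype, and storage type on the device chosen by PlaceDevice, and tag
  // them with profiler storage information.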
  for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
    auto num_outputs = idx[i].source->num_outputs();
    for (size_t j = 0; j < num_outputs; ++j) {
      auto eid = idx.entry_id(i, j);
      if (arrays[eid]->is_none())
        arrays[eid]->ReInit(
            static_cast<NDArrayStorageType>(stypes[eid]), shapes[eid], vctx[i], dtypes[eid]);
    }
  }

  for (size_t nid = num_forward_nodes; nid < idx.num_nodes(); ++nid) {
    const nnvm::NodeAttrs& attrs = idx[nid].source->attrs;
    for (size_t oid = 0; oid < idx[nid].source->num_outputs(); ++oid) {
      size_t eid = idx.entry_id(nid, oid);
      arrays[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name);
    }
  }  // for (nid ∈ [num_forward_nodes, idx.num_nodes()))

  if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
    common::LogMemoryPlan(graph);
  }

  // Execution

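  // Temporarily switch recording/training state and the engine bulk size for the
  // backward run; the previous values are restored afterwards, including on error.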
  bool prev_recording = set_is_recording(create_graph);
  bool prev_training  = set_is_training(is_train);
  int prev_bulk_size  = Engine::Get()->set_bulk_size(backward_bulk_size_);

  try {
    RunGraph(retain_graph,
             idx,
             arrays,
             num_forward_nodes,
             idx.num_nodes(),
             std::move(array_reqs),
             std::move(ref_count),
             &states,
             dispatch_modes,
             is_recording());
  } catch (const dmlc::Error& e) {
    Engine::Get()->set_bulk_size(prev_bulk_size);
    set_is_recording(prev_recording);
    set_is_training(prev_training);
    throw e;
  }

  Engine::Get()->set_bulk_size(prev_bulk_size);
  set_is_recording(prev_recording);
  set_is_training(prev_training);

  // Clear history
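  // Unless retain_graph is set, drop the autograd metadata and node inputs so the
  // same graph cannot be differentiated a second time.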
  if (!retain_graph) {
    nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) {
      AGInfo::Clear(n);
      n->inputs.clear();
    });
  }

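  // Gradients for explicitly requested variables are returned to the caller;
  // otherwise they have already been written into each input's out_grads.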
  if (variables.size()) {
    return x_grads;
  }
  return {};
}
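
For orientation, here is a minimal caller-side sketch of this entry point. The setup of y is hypothetical: it is assumed to be an NDArray produced while autograd recording was enabled, so its autograd_entry_ is populated.

// Hypothetical usage sketch (not part of imperative.cc).
// y must have been computed while is_recording() was true.
std::vector<NDArray*> outputs{&y};
std::vector<NDArray*> ograds{nullptr};  // nullptr => head gradient of ones
std::vector<NDArray*> variables;        // empty => all recorded inputs with grad_req != kNullOp
std::vector<NDArray*> grads = Imperative::Get()->Backward(outputs,
                                                          ograds,
                                                          variables,
                                                          /*is_train=*/true,
                                                          /*retain_graph=*/false,
                                                          /*create_graph=*/false);
// With an empty variables list the returned vector is empty and the gradients
// are written into each input's recorded out_grads instead.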