in src/imperative/imperative.cc [445:757]
std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
                                           const std::vector<NDArray*>& ograds,
                                           const std::vector<NDArray*>& variables,
                                           bool is_train,
                                           bool retain_graph,
                                           bool create_graph) {
  using namespace nnvm;
  using namespace imperative;
  static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
  static const Op* copy_op = Op::Get("_copy");
  // Construct forward graph
  Graph graph;
  graph.outputs.reserve(outputs.size());
  for (const auto& i : outputs) {
    CHECK(!AGInfo::IsNone(*i))
        << "Cannot differentiate node because it is not in a computational graph. "
        << "You need to set is_recording to true or use autograd.record() to save "
        << "computational graphs for backward. If you want to differentiate the same "
        << "graph twice, you need to pass retain_graph=True to backward.";
    graph.outputs.emplace_back(i->autograd_entry_);
  }
  size_t num_forward_outputs = graph.outputs.size();
  // Prepare head gradients: create one dummy variable node per output;
  // a nullptr entry in `ograds` defaults to a head gradient of all ones.
  std::vector<NodeEntry> ograd_entries;
  ograd_entries.reserve(ograds.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    nnvm::ObjectPtr np = Node::Create();
    np->attrs.name = "_head_grad_" + std::to_string(i);
    ograd_entries.emplace_back(NodeEntry{np, 0, 0});
    AGInfo& info = AGInfo::Create(ograd_entries.back().node);
    info.ctx = outputs[i]->ctx();
    if (ograds[i] != nullptr) {
      info.outputs.emplace_back(*ograds[i]);
    } else {
      info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(), true, outputs[i]->dtype());
      if (info.outputs.back().shape().Size() != 0) {
        info.outputs.back() = static_cast<real_t>(1.0);
      }
    }
  }
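  // Frontend view (illustrative, not part of this file): in the Python API,
  // `y.backward()` behaves like `y.backward(mx.nd.ones_like(y))`; the
  // nullptr branch above implements that default head gradient.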
  // Get gradient graph. First build the list of entries to differentiate
  // with respect to: if `variables` is given, differentiate w.r.t. exactly
  // those arrays; otherwise w.r.t. every graph input whose grad_req is set.
  Symbol sym;
  sym.outputs = graph.outputs;
  std::vector<NodeEntry> xs;
  std::vector<NDArray*> x_grads;
  std::vector<OpReqType> x_reqs;
  if (!variables.empty()) {
    xs.reserve(variables.size());
    x_grads.reserve(variables.size());
    x_reqs.reserve(variables.size());
    for (size_t i = 0; i < variables.size(); ++i) {
      CHECK(!AGInfo::IsNone(*variables[i]) &&
            AGInfo::IsVariable(variables[i]->autograd_entry_.node))
          << "Cannot differentiate with respect to the " << i + 1 << "-th variable"
          << " because it does not require gradient.";
      xs.emplace_back(variables[i]->autograd_entry_);
      x_grads.push_back(new NDArray());
      x_reqs.push_back(kWriteTo);
    }
  } else {
    std::vector<ObjectPtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
    xs.reserve(args.size());
    x_grads.reserve(args.size());
    x_reqs.reserve(args.size());
    for (const auto& i : args) {
      AGInfo& info = AGInfo::Get(i);
      if (info.grad_req == kNullOp)
        continue;
      xs.emplace_back(NodeEntry{i, 0, 0});
      x_grads.push_back(&info.out_grads[0]);
      x_reqs.push_back(info.grad_req);
      info.fresh_out_grad = true;
    }
    CHECK_GT(xs.size(), 0)
        << "There are no inputs in the computation graph that require gradients.";
  }
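  // Frontend view (illustrative): the explicit-`variables` path backs calls
  // like `mx.autograd.grad(y, [x])`, which returns fresh gradient arrays,
  // while the empty-`variables` path backs `y.backward()`, which writes into
  // the buffers attached via `x.attach_grad()`.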
  // Also expose gradients of marked non-leaf (intermediate) variables as
  // extra outputs of the gradient pass.
  std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
  std::vector<NodeEntry> us;
  us.reserve(nleaf_vars.size());
  for (const auto& i : nleaf_vars) {
    us.emplace_back(NodeEntry{i, 0, 0});
  }
  Graph g_graph = pass::MXGradient(graph,
                                   graph.outputs,
                                   xs,
                                   ograd_entries,
                                   mxnet::AggregateGradient,
                                   nullptr,
                                   zero_ops,
                                   "_copy",
                                   ShapeVector(),
                                   DTypeVector(),
                                   us);
  CHECK_EQ(g_graph.outputs.size(), xs.size());
  // Append the gradient entries to the forward graph's outputs. A gradient
  // that is a bare variable (no op) is routed through a `_copy` node so that
  // every backward output is produced by an operator.
  for (const auto& e : g_graph.outputs) {
    if (e.node->op() == nullptr) {
      auto node = Node::Create();
      node->attrs.op = copy_op;
      node->inputs.push_back(e);
      graph.outputs.emplace_back(std::move(node));
    } else {
      graph.outputs.push_back(e);
    }
  }
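  // Illustrative layout of the merged graph after this loop, for y = x * w:
  //   graph.outputs[0] = y                  (forward output)
  //   graph.outputs[1] = dy/dx = ograd * w  (gradient output)
  //   graph.outputs[2] = dy/dw = ograd * x  (gradient output)
  // Gradients that are bare variables reach here through `_copy` nodes.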
  const auto& idx = graph.indexed_graph();
  // Get the number of nodes and entries used in the forward pass; in the
  // indexed graph all forward nodes precede the backward nodes appended by
  // the gradient pass.
  size_t num_forward_nodes = 0;
  size_t num_forward_entries = 0;
  for (size_t i = 0; i < num_forward_outputs; ++i) {
    num_forward_nodes =
        std::max(num_forward_nodes, static_cast<size_t>(idx.outputs()[i].node_id + 1));
    num_forward_entries =
        std::max(num_forward_entries, static_cast<size_t>(idx.entry_id(idx.outputs()[i])) + 1);
  }
  // Allocate buffers: one NDArray slot per graph entry, plus per-entry
  // reference counts and per-node op states.
  std::vector<NDArray> buff(idx.num_node_entries());
  std::vector<uint32_t> ref_count(buff.size(), 0);
  std::vector<OpStatePtr> states;
  std::vector<NDArray*> arrays;
  arrays.reserve(buff.size());
  for (auto& buffered_array : buff) {
    arrays.push_back(&buffered_array);
  }
  if (create_graph) {
    // The backward pass itself will be recorded so that higher-order
    // gradients can be taken later: copy the forward outputs into `buff`
    // (NDArray copies share the underlying data) and attach fresh autograd
    // entries to the copies.
    states.resize(num_forward_nodes);
    nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) {
      AGInfo& info = AGInfo::Get(n);
      states[idx.node_id(n.get())] = info.state;
      for (uint32_t i = 0; i < info.outputs.size(); ++i) {
        CHECK(idx.exist(n.get()));
        size_t nid = idx.node_id(n.get());
        size_t eid = idx.entry_id(nid, i);
        buff[eid] = info.outputs[i];
        buff[eid].autograd_entry_ = NodeEntry{n, i, 0};
        ref_count[eid] = 1;
      }
    });
    for (auto& ograd_entry : ograd_entries) {
      AGInfo& info = AGInfo::Get(ograd_entry.node);
      if (!idx.exist(ograd_entry.node.get()))
        continue;
      size_t eid = idx.entry_id(ograd_entry);
      buff[eid] = info.outputs[0];
      buff[eid].autograd_entry_ = ograd_entry;
    }
  } else {
    // No higher-order gradients requested: alias the recorded forward
    // outputs directly instead of copying them.
    states.reserve(num_forward_nodes);
    for (size_t i = 0; i < num_forward_nodes; ++i) {
      const AGInfo& info = dmlc::get<AGInfo>(idx[i].source->info);
      states.emplace_back(info.state);
      for (size_t j = 0; j < info.outputs.size(); ++j) {
        size_t eid = idx.entry_id(i, j);
        arrays[eid] = const_cast<NDArray*>(&(info.outputs[j]));
        if (retain_graph || info.grad_req != kNullOp)
          ref_count[eid] = 1;
      }
    }
    for (auto& ograd_entry : ograd_entries) {
      if (!idx.exist(ograd_entry.node.get()))
        continue;
      AGInfo& info = AGInfo::Get(ograd_entry.node);
      arrays[idx.entry_id(ograd_entry)] = &info.outputs[0];
    }
  }
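  // Frontend view (illustrative): create_graph=true is what second-order
  // gradients rely on, e.g.
  //   dy_dx = mx.autograd.grad(y, [x], create_graph=True)[0]
  //   dy_dx.backward()  # differentiates the first backward pass
  // The copies above re-enter the forward outputs into the recorded graph so
  // the backward pass itself can be differentiated.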
  // Bind the requested gradient arrays to the backward outputs.
  for (size_t i = num_forward_outputs; i < graph.outputs.size(); ++i) {
    size_t eid = idx.entry_id(graph.outputs[i]);
    arrays[eid] = x_grads[i - num_forward_outputs];
    ref_count[eid] = 1;
  }
  // Bind gradients of the queried non-leaf variables.
  const std::vector<NodeEntry>& us_grads = g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
  CHECK_EQ(us_grads.size(), us.size())
      << "Number of queried non-leaf variables and number of their gradients don't match.";
  for (size_t i = 0; i < us_grads.size(); ++i) {
    size_t eid = idx.entry_id(us_grads[i]);
    AGInfo& info = AGInfo::Get(us[i].node);
    if (arrays[eid]->dtype_ == -1) {
      arrays[eid] = &info.out_grads[0];
    } else {
      info.out_grads[0] = *arrays[eid];
    }
    ref_count[eid] = 1;
  }
  // Assign context
  auto vctx = PlaceDevice(idx);
  // Infer shape, dtype, and storage type, restricted to the backward
  // node/entry range; forward entries are already known.
  {
    std::pair<uint32_t, uint32_t> node_range, entry_range;
    node_range = {num_forward_nodes, idx.num_nodes()};
    entry_range = {num_forward_entries, idx.num_node_entries()};
    ShapeVector shapes;
    shapes.reserve(idx.num_node_entries());
    bool contain_unknown = false;
    for (const auto& i : arrays)
      shapes.emplace_back(i->shape());
    CheckAndInferShape(
        &graph, std::move(shapes), false, node_range, entry_range, &contain_unknown);
    DTypeVector dtypes;
    dtypes.reserve(idx.num_node_entries());
    for (const auto& i : arrays)
      dtypes.emplace_back(i->dtype());
    CheckAndInferType(&graph, std::move(dtypes), false, node_range, entry_range);
    StorageTypeVector stypes;
    stypes.reserve(idx.num_node_entries());
    for (const auto& i : arrays)
      stypes.emplace_back(i->storage_type());
    exec::DevMaskVector dev_mask;
    dev_mask.reserve(idx.num_nodes());
    for (const auto& i : vctx)
      dev_mask.emplace_back(i.dev_mask());
    CheckAndInferStorageType(
        &graph, std::move(dev_mask), std::move(stypes), false, node_range, entry_range);
  }
  // Calculate ref count: how many backward nodes consume each entry, so that
  // temporaries can be released as soon as their last consumer runs.
  for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
    for (const auto& j : idx[i].inputs) {
      ++ref_count[idx.entry_id(j)];
    }
  }
  // Assign reqs: unreferenced backward entries are skipped (kNullOp) and the
  // requested gradients keep their caller-specified reqs.
  std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
  for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) {
    if (ref_count[i] == 0)
      array_reqs[i] = kNullOp;
  }
  for (size_t i = num_forward_outputs; i < idx.outputs().size(); ++i) {
    size_t eid = idx.entry_id(idx.outputs()[i]);
    array_reqs[eid] = x_reqs[i - num_forward_outputs];
  }
  for (size_t i = 0; i < us_grads.size(); ++i) {
    size_t eid = idx.entry_id(us_grads[i]);
    AGInfo& info = AGInfo::Get(us[i].node);
    array_reqs[eid] = info.grad_req;
  }
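  // For example (illustrative), a parameter declared with grad_req='add'
  // reaches this point with its x_reqs entry set to kAddTo, so RunGraph
  // accumulates into the existing gradient buffer instead of overwriting it.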
  // Fetch the inferred shapes/dtypes/storage types and materialize any
  // still-empty backward arrays with them.
  const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape");
  const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
  const auto& stypes = graph.GetAttr<StorageTypeVector>("storage_type");
  const auto& dispatch_modes = graph.GetAttr<DispatchModeVector>("dispatch_mode");
  for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
    auto num_outputs = idx[i].source->num_outputs();
    for (size_t j = 0; j < num_outputs; ++j) {
      auto eid = idx.entry_id(i, j);
      if (arrays[eid]->is_none())
        arrays[eid]->ReInit(
            static_cast<NDArrayStorageType>(stypes[eid]), shapes[eid], vctx[i], dtypes[eid]);
    }
  }
  // Tag the backward arrays with profiler scope and node names.
  for (size_t nid = num_forward_nodes; nid < idx.num_nodes(); ++nid) {
    const nnvm::NodeAttrs& attrs = idx[nid].source->attrs;
    for (size_t oid = 0; oid < idx[nid].source->num_outputs(); ++oid) {
      size_t eid = idx.entry_id(nid, oid);
      arrays[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name);
    }
  }  // for (nid ∈ [num_forward_nodes, idx.num_nodes()))
  if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
    common::LogMemoryPlan(graph);
  }
  // Execution: run only the backward node range. Recording is enabled iff
  // create_graph is set; the previous recording/training/bulk-size state is
  // restored afterwards, including on failure.
  bool prev_recording = set_is_recording(create_graph);
  bool prev_training = set_is_training(is_train);
  int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
  try {
    RunGraph(retain_graph,
             idx,
             arrays,
             num_forward_nodes,
             idx.num_nodes(),
             std::move(array_reqs),
             std::move(ref_count),
             &states,
             dispatch_modes,
             is_recording());
  } catch (const dmlc::Error&) {
    Engine::Get()->set_bulk_size(prev_bulk_size);
    set_is_recording(prev_recording);
    set_is_training(prev_training);
    throw;  // rethrow the original exception rather than a sliced copy
  }
  Engine::Get()->set_bulk_size(prev_bulk_size);
  set_is_recording(prev_recording);
  set_is_training(prev_training);
  // Clear history unless the caller wants to backprop through this graph
  // again (retain_graph).
  if (!retain_graph) {
    nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) {
      AGInfo::Clear(n);
      n->inputs.clear();
    });
  }
  // With explicit variables, return the freshly allocated gradients;
  // otherwise the gradients were written into the attached buffers and
  // nothing needs to be returned.
  if (!variables.empty()) {
    return x_grads;
  }
  return {};
}
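
// Usage sketch (hypothetical caller, not part of this file; the Python
// frontend reaches this method through the MXAutogradBackwardEx C API):
//
//   std::vector<NDArray*> grads =
//       Imperative::Get()->Backward({&loss},    // outputs to differentiate
//                                   {nullptr},  // head grads; nullptr => ones
//                                   {},         // variables; empty => use grad_req
//                                   /*is_train=*/true,
//                                   /*retain_graph=*/false,
//                                   /*create_graph=*/false);
//
// With empty `variables` the gradients land in the arrays attached via
// grad_req/attach_grad and the returned vector is empty.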