in gad/src/graph.rs [210:246]
/// Back-propagates `gradient`, seeded at node `gid`, through the recorded graph
/// and returns the resulting gradient store.
///
/// Consumes the graph (`self` is taken by value). Each visited node is cleared
/// after it has been processed.
///
/// Traversal order:
/// - Node ids are visited in decreasing order, using a max-heap of pending ids.
/// - Each id is processed at most once, even if several nodes list it as an input.
/// - An input whose id is not strictly below the last processed id is skipped.
///   NOTE(review): this presumably relies on inputs always having smaller ids
///   than the nodes that consume them (topological id order); confirm with the
///   node-recording code.
///
/// # Errors
///
/// - Returns `Error::missing_node` if a pending id has no entry in `self.nodes`.
/// - Propagates any error returned by a node's update function.
fn do_compute_gradients_once<D, G>(
    mut self,
    graph: &mut C::GradientAlgebra,
    gid: GradientId<D>,
    gradient: G,
) -> Result<C::GradientStore>
where
    C::GradientAlgebra: CoreAlgebra<D, Value = G>,
    C::GradientStore: GradientStore<GradientId<D>, G> + Default,
{
    // The store starts out holding only the seed gradient for `gid`.
    let mut gradients = C::GradientStore::default();
    gradients.insert(gid, gradient);

    // Max-heap: `pop` always yields the largest pending id.
    let mut pending = BinaryHeap::with_capacity(self.nodes.len());
    pending.push(gid.inner);

    // Lower bound of the ids processed so far. Any popped id at or above it has
    // either been handled already (a duplicate entry) or is out of order, so it
    // is skipped.
    let mut lowest_processed = gid.inner.next_id();

    while let Some(current) = pending.pop() {
        if current >= lowest_processed {
            continue;
        }
        lowest_processed = current;

        let node = self
            .nodes
            .get_mut(current)
            .ok_or_else(|| Error::missing_node(func_name!()))?;

        // Let the node write gradient contributions for its inputs into the store.
        if let Some(update) = &node.update_func {
            update(graph, &mut gradients, current)?;
        }

        // Queue every present input. Duplicates are filtered by the guard above.
        pending.extend(node.inputs.iter().flatten().copied());

        node.clear();
    }

    Ok(gradients)
}