in gad/src/array_compare.rs [146:168]
/// Softmax of `v`, recorded as a graph node with a reverse-mode gradient rule.
///
/// The forward value is produced by the evaluator's `softmax_as(v.data(), rdims)`.
/// `rdims` is presumably the reduced shape whose collapsed axes the softmax
/// normalizes over: the backward pass reduces with `sum_as(_, rdims)` and tiles
/// the result back to `v.dims()`. TODO confirm against the trait docs.
///
/// Backward rule: with `y = softmax(v)` and incoming gradient `g`, this adds
/// `g * y - tile(sum_as(g * y, rdims)) * y` to `v`'s gradient. That is the
/// vector-Jacobian product `y ⊙ (g − Σ(y ⊙ g))`.
fn softmax_as(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
    let forward = self.eval().softmax_as(v.data(), rdims)?;
    // Owned copies moved into the backward closure below.
    let input = v.clone();
    let input_dims = input.dims();
    let node = self.make_node(forward, vec![v.input()], move |graph, store, upstream| {
        // An input without an id receives no gradient.
        let id = match input.id() {
            Some(id) => id,
            None => return Ok(()),
        };
        let x = graph.link(&input);
        // Rebuild y = softmax(x) as a node in `graph` rather than reusing the forward data.
        let y = graph.softmax_as(x, rdims)?;
        let gy = graph.mul(&upstream, &y)?;
        // Per-slice sum of g * y, broadcast back to the input shape and scaled by y.
        let correction = {
            let total = graph.sum_as(&gy, rdims)?;
            let spread = graph.tile_as(&total, input_dims)?;
            graph.mul(&spread, &y)?
        };
        let grad = graph.sub(&gy, &correction)?;
        store.add_gradient::<D, _>(graph, id, &grad)?;
        Ok(())
    });
    Ok(node)
}