fn max_as()

in gad/src/array_compare.rs [123:139]


            /// Reduces `v` to the dimensions `rdims` by taking maxima and records the
            /// result as a new graph node with `v` as its only input.
            ///
            /// Backward pass: the upstream gradient is tiled back to `v`'s dimensions and
            /// multiplied element-wise by the value returned from `argmax_as(v, rdims)`.
            /// The product is then accumulated as `v`'s gradient.
            ///
            /// # Errors
            /// Propagates any error from the evaluator's `max_as` in the forward pass.
            /// Errors raised by the graph operations during the backward pass are
            /// returned from the backward closure.
            fn max_as(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                // Forward reduction; an error here aborts before any node is recorded.
                let reduced = self.eval().max_as(v.data(), rdims)?;
                // The new node depends on `v` alone.
                let inputs = vec![v.input()];
                // The backward closure needs its own handle to `v`.
                let captured = v.clone();
                let node = self.make_node(reduced, inputs, move |graph, store, upstream| {
                    // Inputs without an id (presumably constants) receive no gradient.
                    let id = match captured.id() {
                        Some(id) => id,
                        None => return Ok(()),
                    };
                    let linked = graph.link(&captured);
                    // NOTE(review): presumably a 0/1 mask shaped like `v` that marks the
                    // maximal entries along the reduced dimensions; confirm in `argmax_as`.
                    let selector = graph.argmax_as(linked, rdims)?;
                    // Tile the upstream gradient up to `v`'s dimensions.
                    let expanded = graph.tile_as(&upstream, linked.dims())?;
                    // Route the gradient only through the selected entries.
                    let routed = graph.mul(&expanded, &selector)?;
                    store.add_gradient::<D, _>(graph, id, &routed)?;
                    Ok(())
                });
                Ok(node)
            }