fn select_argmax()

in gad/src/compare.rs [202:245]


            /// Selects between the optional values `r0` / `r1` according to a comparison
            /// of `v0` and `v1`, recording the operation for differentiation.
            ///
            /// The forward value is delegated to `self.eval().select_argmax` on the raw data.
            /// Only `r0` and `r1` (when present) are registered as inputs of the new node.
            /// `v0` and `v1` are used only to decide the selection and receive no gradient.
            ///
            /// During backpropagation, the incoming gradient is fed back through
            /// `select_argmax` once per differentiable selectable value:
            /// - For `r0`, the gradient goes in the first optional slot.
            /// - For `r1`, it goes in the second slot.
            ///
            /// Each result is accumulated into the store under that value's id.
            /// (Presumably this routes the gradient to whichever side won the comparison;
            /// confirm against the evaluator's `select_argmax`.)
            ///
            /// # Errors
            /// Propagates any error from the forward evaluation. The backward closure
            /// propagates errors from `select_argmax` and `add_gradient`.
            fn select_argmax(
                &mut self,
                v0: &Value<D>,
                v1: &Value<D>,
                r0: Option<&Value<D>>,
                r1: Option<&Value<D>>,
            ) -> Result<Value<D>> {
                let forward = self.eval().select_argmax(
                    v0.data(),
                    v1.data(),
                    r0.map(|sel| sel.data()),
                    r1.map(|sel| sel.data()),
                )?;

                // Node inputs: the present selectable values, in r0-then-r1 order.
                let node_inputs = r0
                    .into_iter()
                    .chain(r1)
                    .map(|sel| sel.input())
                    .collect::<Vec<_>>();

                // State captured by the backward closure.
                let lhs = v0.clone();
                let rhs = v1.clone();
                let first_target = r0.and_then(Value::id);
                let second_target = r1.and_then(Value::id);

                // The closure stays inline so its argument types are inferred from
                // `make_node`'s signature.
                let node = self.make_node(forward, node_inputs, move |graph, store, gradient| {
                    let lhs_link = graph.link(&lhs);
                    let rhs_link = graph.link(&rhs);
                    if let Some(target) = first_target {
                        let routed =
                            graph.select_argmax(lhs_link, rhs_link, Some(&gradient), None)?;
                        store.add_gradient(graph, target, &routed)?;
                    }
                    if let Some(target) = second_target {
                        let routed =
                            graph.select_argmax(lhs_link, rhs_link, None, Some(&gradient))?;
                        store.add_gradient(graph, target, &routed)?;
                    }
                    Ok(())
                });
                Ok(node)
            }