in gad/src/compare.rs [202:245]
/// Differentiable version of `select_argmax`.
///
/// The forward value is obtained by delegating to `self.eval().select_argmax` on the raw
/// data of the four arguments. Only `r0` and `r1` are recorded as inputs of the new node.
/// `v0` and `v1` only drive the comparison and never receive a gradient.
///
/// During back-propagation, the incoming gradient is routed with the same comparison:
/// - `r0` (when it has an id) is credited with `select_argmax(v0, v1, Some(gradient), None)`;
/// - `r1` (when it has an id) is credited with `select_argmax(v0, v1, None, Some(gradient))`.
///
/// # Errors
/// Returns any error from the forward evaluation. Inside the backward closure, errors from
/// the graph's `select_argmax` or from `add_gradient` are propagated.
fn select_argmax(
    &mut self,
    v0: &Value<D>,
    v1: &Value<D>,
    r0: Option<&Value<D>>,
    r1: Option<&Value<D>>,
) -> Result<Value<D>> {
    // Forward pass on the underlying data.
    let data = self.eval().select_argmax(
        v0.data(),
        v1.data(),
        r0.map(|r| r.data()),
        r1.map(|r| r.data()),
    )?;

    // Node inputs are the selectable values only (r0 first, then r1), skipping absent ones.
    let inputs = r0
        .into_iter()
        .chain(r1)
        .map(|r| r.input())
        .collect::<Vec<_>>();

    // State moved into the backward closure.
    let lhs = v0.clone();
    let rhs = v1.clone();
    let first_id = r0.and_then(Value::id);
    let second_id = r1.and_then(Value::id);

    let node = self.make_node(data, inputs, move |graph, store, gradient| {
        let a = graph.link(&lhs);
        let b = graph.link(&rhs);
        // Gradient flows to r0 wherever the comparison selected the first slot.
        if let Some(target) = first_id {
            let routed = graph.select_argmax(a, b, Some(&gradient), None)?;
            store.add_gradient(graph, target, &routed)?;
        }
        // ...and to r1 wherever it selected the second slot.
        if let Some(target) = second_id {
            let routed = graph.select_argmax(a, b, None, Some(&gradient))?;
            store.add_gradient(graph, target, &routed)?;
        }
        Ok(())
    });
    Ok(node)
}