in gad/src/compare.rs [94:114]
/// Elementwise "argmax routing": for each element, yields the value from `r0`
/// where `v0 >= v1` and the value from `r1` otherwise.
///
/// An absent `r0` or `r1` behaves as if it were filled with zeros. If both are
/// absent, the result is an all-zero array with the dimensions of `v0`.
/// Ties (`v0 == v1`) route to `r0`.
///
/// # Errors
///
/// Returns whatever error `self.check().select_argmax(..)` reports for the
/// given dimensions. NOTE(review): presumably this is a shape-compatibility
/// check; confirm against the checker implementation.
fn select_argmax(
    &mut self,
    v0: &af::Array<T>,
    v1: &af::Array<T>,
    r0: Option<&af::Array<T>>,
    r1: Option<&af::Array<T>>,
) -> Result<af::Array<T>> {
    // The checker only needs shapes, so extract the optional dims up front.
    let r0_dims = r0.map(|arr| arr.dims());
    let r1_dims = r1.map(|arr| arr.dims());
    self.check()
        .select_argmax(&v0.dims(), &v1.dims(), r0_dims.as_ref(), r1_dims.as_ref())?;

    // Mask is true exactly where the first operand wins (ties included).
    let v0_wins = af::ge(v0, v1, false);
    let routed = match (r0, r1) {
        (None, None) => af::constant(T::zero(), v0.dims()),
        (Some(on_true), None) => af::selectr(on_true, &v0_wins, 0.0),
        (None, Some(on_false)) => af::selectl(0.0, &v0_wins, on_false),
        (Some(on_true), Some(on_false)) => af::select(on_true, &v0_wins, on_false),
    };
    Ok(routed)
}