fn select_argmax()

in gad/src/compare.rs [94:114]


        /// Routes values from `r0` / `r1` according to an elementwise comparison of `v0` and `v1`.
        ///
        /// Wherever `v0 >= v1` the output takes the element of `r0`, otherwise the element of
        /// `r1`; ties therefore go to `r0`. An absent `r0` or `r1` behaves like zero on that
        /// side, and when both are absent the result is a zero array with the dimensions of `v0`.
        ///
        /// # Errors
        ///
        /// Propagates any error reported by `self.check().select_argmax(..)` for the supplied
        /// dimensions; in that case no comparison or selection is performed.
        fn select_argmax(
            &mut self,
            v0: &af::Array<T>,
            v1: &af::Array<T>,
            r0: Option<&af::Array<T>>,
            r1: Option<&af::Array<T>>,
        ) -> Result<af::Array<T>> {
            // Gather the optional operand shapes up front so the checker sees `Option<&Dim4>`.
            let left_dims = r0.map(|arr| arr.dims());
            let right_dims = r1.map(|arr| arr.dims());
            self.check()
                .select_argmax(&v0.dims(), &v1.dims(), left_dims.as_ref(), right_dims.as_ref())?;

            // Mask of positions where the first candidate wins (the trailing `false` is
            // ArrayFire's `batch` flag).
            let first_wins = af::ge(v0, v1, false);

            let chosen = match (r0, r1) {
                (None, None) => af::constant(T::zero(), v0.dims()),
                (Some(left), None) => af::selectr(left, &first_wins, 0.0),
                (None, Some(right)) => af::selectl(0.0, &first_wins, right),
                (Some(left), Some(right)) => af::select(left, &first_wins, right),
            };
            Ok(chosen)
        }