fn softmax()

in crates/ratchet-core/src/cpu/softmax.rs [21:42]


fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError>
where
    T: TensorDType + Float + NumAssignOps,
{
    let src_shape = input.shape();
    let mut input = input.to_vec::<T>()?;
    let N = src_shape[dim];
    input.chunks_mut(N).for_each(|chunk| {
        let mut sum = T::zero();
        for j in 0..N {
            chunk[j] = chunk[j].exp();
            sum += chunk[j];
        }
        for j in 0..N {
            chunk[j] /= sum;
        }
    });

    cpu_store_result(dst, &input);

    Ok(())
}