fn run_unary_trial()

in crates/ratchet-core/src/ops/unary.rs [411:448]


    /// Runs one randomized correctness trial of a unary op on `device`.
    ///
    /// Draws a random `[B, M]` f32 tensor on the CPU and obtains a reference
    /// output for it from `ground_truth`. It then evaluates the same op on
    /// `device`, copies the result back to the CPU and compares it with
    /// `all_close`. The problem's `N` field is not used.
    ///
    /// For `Gelu`, the string `approximate="tanh"` is passed to `ground_truth`;
    /// other ops pass an empty argument string. `Gelu` and `Tanh` are compared
    /// with `atol = rtol = 5e-2`; every other op uses `1e-4`.
    ///
    /// NOTE(review): `Log` and `Sqrt` are fed `randn` samples, which can be
    /// negative. Whether `all_close` treats the resulting NaNs as equal is not
    /// visible here — confirm.
    ///
    /// # Errors
    /// Propagates any error from `ground_truth`, moving tensors between
    /// devices, building or resolving the op, or a failed `all_close` check.
    fn run_unary_trial(prob: UnaryProblem, device: Device) -> anyhow::Result<()> {
        let UnaryProblem { op, B: rows, M: cols, N: _ } = prob;
        let cpu_input = Tensor::randn::<f32>(shape![rows, cols], Device::CPU);

        // Extra argument string handed to the reference implementation.
        let reference_args = if matches!(op, UnaryOp::Gelu) {
            "approximate=\"tanh\""
        } else {
            ""
        };
        let expected = ground_truth(&cpu_input, &op, reference_args)?;

        let dev_input = cpu_input.to(&device)?;
        // Deliberately exhaustive (no `_` arm): a new `UnaryOp` variant will
        // fail to compile here until the test covers it.
        let pending = match op {
            UnaryOp::Gelu => dev_input.gelu()?,
            UnaryOp::Tanh => dev_input.tanh()?,
            UnaryOp::Exp => dev_input.exp()?,
            UnaryOp::Log => dev_input.log()?,
            UnaryOp::Sin => dev_input.sin()?,
            UnaryOp::Cos => dev_input.cos()?,
            UnaryOp::Abs => dev_input.abs()?,
            UnaryOp::Sqrt => dev_input.sqrt()?,
            UnaryOp::Relu => dev_input.relu()?,
            UnaryOp::Floor => dev_input.floor()?,
            UnaryOp::Ceil => dev_input.ceil()?,
            UnaryOp::Neg => dev_input.neg()?,
            UnaryOp::Silu => dev_input.silu()?,
            UnaryOp::Sigmoid => dev_input.sigmoid()?,
        };
        let computed = pending.resolve()?;

        // Looser tolerances for Gelu/Tanh; everything else is checked tightly.
        let (atol, rtol) = if matches!(op, UnaryOp::Gelu | UnaryOp::Tanh) {
            (5e-2, 5e-2)
        } else {
            (1e-4, 1e-4)
        };

        let actual = computed.to(&Device::CPU)?;
        expected.all_close(&actual, atol, rtol)?;
        Ok(())
    }