in crates/ratchet-core/src/ops/unary.rs [411:448]
/// Runs a single unary-op trial: applies `prob.op` on `device` and checks the
/// output against a reference produced by `ground_truth`.
///
/// The input is a `[B, M]` tensor created with `Tensor::randn` on the CPU. The
/// reference is computed from that CPU tensor *before* it is copied to `device`.
/// The device result is resolved, copied back to the CPU, and compared with
/// `all_close`. `Gelu` and `Tanh` use a looser tolerance (5e-2) than every other
/// op (1e-4). `Gelu` additionally passes `approximate="tanh"` to the reference.
///
/// NOTE(review): `prob.N` is destructured but ignored, so inputs are 2-D only —
/// confirm this is intentional.
/// NOTE(review): `randn` presumably samples a normal distribution, so `Log` and
/// `Sqrt` will see negative inputs and may produce NaN. Whether that passes
/// depends on how `all_close` treats NaN — confirm.
///
/// # Errors
/// Propagates any error from:
/// - the reference computation;
/// - the device transfers;
/// - the op itself or `resolve`;
/// - a failed `all_close` comparison.
fn run_unary_trial(prob: UnaryProblem, device: Device) -> anyhow::Result<()> {
    let UnaryProblem {
        op,
        B: batch,
        M: rows,
        N: _,
    } = prob;

    let host_input = Tensor::randn::<f32>(shape![batch, rows], Device::CPU);

    // Only GELU needs an extra keyword for the reference implementation.
    let reference_args = if matches!(op, UnaryOp::Gelu) {
        "approximate=\"tanh\""
    } else {
        ""
    };
    let expected = ground_truth(&host_input, &op, reference_args)?;

    let device_input = host_input.to(&device)?;
    let lazy_output = match op {
        UnaryOp::Gelu => device_input.gelu()?,
        UnaryOp::Tanh => device_input.tanh()?,
        UnaryOp::Exp => device_input.exp()?,
        UnaryOp::Log => device_input.log()?,
        UnaryOp::Sin => device_input.sin()?,
        UnaryOp::Cos => device_input.cos()?,
        UnaryOp::Abs => device_input.abs()?,
        UnaryOp::Sqrt => device_input.sqrt()?,
        UnaryOp::Relu => device_input.relu()?,
        UnaryOp::Floor => device_input.floor()?,
        UnaryOp::Ceil => device_input.ceil()?,
        UnaryOp::Neg => device_input.neg()?,
        UnaryOp::Silu => device_input.silu()?,
        UnaryOp::Sigmoid => device_input.sigmoid()?,
    };
    let resolved = lazy_output.resolve()?;

    // GELU (tanh approximation) and tanh get a wider tolerance than the rest.
    let (atol, rtol) = if matches!(op, UnaryOp::Gelu | UnaryOp::Tanh) {
        (5e-2, 5e-2)
    } else {
        (1e-4, 1e-4)
    };

    let actual = resolved.to(&Device::CPU)?;
    expected.all_close(&actual, atol, rtol)?;
    Ok(())
}