crates/ratchet-core/src/cpu/binary.rs (111 lines of code) (raw):
use crate::cpu::cpu_store_result;
use crate::{Binary, BinaryOp, CPUOperation, DType, OperationError, Tensor, TensorDType};
use core::marker::PhantomData;
use half::{bf16, f16};
use num_traits::NumOps;
#[inline]
pub(crate) fn binary_map<T: TensorDType, U: TensorDType>(
lhs: &[T],
rhs: &[T],
dst: &mut [U],
f: fn(T, T) -> U,
) {
assert_eq!(lhs.len(), dst.len());
assert_eq!(rhs.len(), dst.len());
for ((l, r), d) in lhs
.iter()
.copied()
.zip(rhs.iter().copied())
.zip(dst.iter_mut())
{
*d = f(l, r);
}
}
#[inline]
pub(crate) fn binary_map_inplace<T: TensorDType>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) {
assert_eq!(lhs.len(), rhs.len());
lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)| {
*l = f(*l, *r);
});
}
#[inline]
pub(crate) fn binary_apply<T: TensorDType, U: TensorDType>(
lhs: &Tensor,
rhs: &Tensor,
dst: &Tensor,
f: fn(T, T) -> U,
) -> Result<(), OperationError> {
let lhs = lhs.to_vec::<T>()?;
let rhs = rhs.to_vec::<T>()?;
let mut result = vec![U::zero(); dst.shape().numel()];
binary_map(&lhs, &rhs, &mut result, f);
cpu_store_result(dst, &result);
Ok(())
}
#[inline]
pub(crate) fn binary_apply_inplace<T: TensorDType>(
lhs: &Tensor,
rhs: &Tensor,
dst: &Tensor,
f: fn(T, T) -> T,
) -> Result<(), OperationError> {
let mut lhs = lhs.to_vec::<T>()?;
let rhs = rhs.to_vec::<T>()?;
binary_map_inplace(&mut lhs, &rhs, f);
cpu_store_result(dst, &lhs);
Ok(())
}
pub struct BinaryOps<T: TensorDType> {
dtype: PhantomData<T>,
}
macro_rules! impl_cpu_binary_op {
($method_name:ident, $dtype:ident, $op:expr) => {
fn $method_name(lhs: &Tensor, rhs: &Tensor, dst: Tensor) -> Result<Tensor, OperationError> {
binary_apply_inplace::<$dtype>(lhs, rhs, &dst, $op)?;
Ok(dst)
}
};
}
macro_rules! cpu_binary_op_fn {
($method_name:ident, $op:expr) => {
#[inline]
pub(crate) fn $method_name<T: TensorDType + NumOps>(lhs: &mut [T], rhs: &[T]) {
binary_map_inplace::<T>(lhs, rhs, $op);
}
};
}
cpu_binary_op_fn!(add, |lhs, rhs| lhs + rhs);
cpu_binary_op_fn!(sub, |lhs, rhs| lhs - rhs);
cpu_binary_op_fn!(mul, |lhs, rhs| lhs * rhs);
cpu_binary_op_fn!(div, |lhs, rhs| lhs / rhs);
macro_rules! impl_cpu_binary {
($dtype:ident) => {
impl BinaryOps<$dtype> {
impl_cpu_binary_op!(add, $dtype, |lhs, rhs| lhs + rhs);
impl_cpu_binary_op!(sub, $dtype, |lhs, rhs| lhs - rhs);
impl_cpu_binary_op!(mul, $dtype, |lhs, rhs| lhs * rhs);
impl_cpu_binary_op!(div, $dtype, |lhs, rhs| lhs / rhs);
pub fn apply(op: &Binary, dst: Tensor) -> Result<Tensor, OperationError> {
match op.op() {
BinaryOp::Add => Self::add(op.lhs(), op.rhs(), dst),
BinaryOp::Sub => Self::sub(op.lhs(), op.rhs(), dst),
BinaryOp::Mul => Self::mul(op.lhs(), op.rhs(), dst),
BinaryOp::Div => Self::div(op.lhs(), op.rhs(), dst),
}
}
}
};
}
impl CPUOperation for Binary {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => BinaryOps::<f32>::apply(self, dst),
DType::F16 => BinaryOps::<f16>::apply(self, dst),
DType::BF16 => BinaryOps::<bf16>::apply(self, dst),
_ => todo!(),
}
}
}
impl_cpu_binary!(f32);
impl_cpu_binary!(f16);
impl_cpu_binary!(bf16);