in candle-core/src/cpu_backend/mod.rs [1181:1264]
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w].
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
let dst_s0 = p.c_out * out_h * out_w;
let dst_s1 = out_h * out_w;
let dst_s2 = out_w;
let dst_s3 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
let cont_s0 = p.i_h * p.i_w * p.c_in;
let cont_s1 = p.i_w * p.c_in;
let cont_s2 = p.c_in;
for b_idx in 0..p.b_size {
for h_idx in 0..p.i_h {
for w_idx in 0..p.i_w {
for c_idx in 0..p.c_in {
let src_idx =
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
}
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
})
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for inp_y in 0..p.i_h {
for inp_x in 0..p.i_w {
let out_x = inp_x * p.stride + k_x * p.dilation;
let out_y = inp_y * p.stride + k_y * p.dilation;
if out_x < p.padding || out_y < p.padding {
continue;
}
let out_x = out_x - p.padding;
let out_y = out_y - p.padding;
if out_x < out_w && out_y < out_h {
let inp_cont = &inp_cont
[b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
let dst_idx = b_idx * dst_s0
+ out_y * dst_s2
+ out_x * dst_s3
+ dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(
inp_cont.as_ptr(),
k_cont.as_ptr(),
&mut d,
p.c_in,
)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
}
})
}
}
Ok(dst)
}