in crates/ratchet-core/src/ops/conv.rs [40:123]
/// Renders the WGSL source for this convolution kernel.
///
/// Each invocation writes one element of `Y`:
/// - `workgroup_id.y` selects the filter slice (`workgroup_id.y * F_numel`) and the bias
///   (`B[workgroup_id.y]`). This is presumably one output channel per workgroup row.
/// - `workgroup_id.x` and `local_invocation_id.x` select the position along the length dimension.
///
/// The workgroup first stages its filter slice from `W` into the workgroup array `F`. Each
/// invocation then accumulates a width-3 window over all `Cin` input channels and adds the bias.
///
/// NOTE(review): a few assumptions are not checked here:
/// - The kernel width is fixed at 3, via the `vec3` accumulators and the three `F` reads.
/// - `F` holds at most 4096 elements, and nothing here checks that `F_numel` fits.
/// - The two boundary branches in `main` appear to assume `padding == 1`.
///
/// Confirm these against the op's validation.
///
/// # Errors
/// Propagates failures from GPU device lookup, binding registration, metadata construction, and
/// the final kernel build.
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    let device = dst.device().try_gpu()?;
    let mut kernel_builder = WgslKernelBuilder::new(
        workgroup_size.clone(),
        rvec![
            BuiltIn::GlobalInvocationId,
            BuiltIn::LocalInvocationId,
            BuiltIn::WorkgroupId,
        ],
        device.compute_features().clone(),
    );
    self.register_bindings::<P>(&mut kernel_builder, inplace)?;
    kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);
    let dt = P::T::DT;
    // Workgroup-shared staging buffer for this workgroup's filter slice.
    kernel_builder.write_global(wgsl! {
        var<workgroup> F: array<'dt, 4096u>;
    });
    // Note: `//` comments inside `wgsl!` are dropped by the Rust lexer and never reach the
    // emitted shader source.
    kernel_builder.write_global(wgsl! {
        // Computes one output element. Only window taps `start..=end` are read from `X`; the
        // remaining taps stay zero, which implements zero padding at the edges.
        // `filter_index` is unused here because the filters are read from `F`.
        fn inner(input_index: u32, filter_index: u32, output_index: u32, bias_index: u32, start: u32, end: u32) {
            // `0.0` is an abstract float, so the splat is valid for f32 and f16 alike. A concrete
            // `0f` only type-checks when 'dt is f32.
            var inp = vec3<'dt>(0.0);
            var kernel = vec3<'dt>(0.0);
            var acc = vec3<'dt>(0.0);
            for(var i = 0u; i < metadata.Cin; i++) {
                // Shift the window left by `padding`. For the first position this wraps (u32);
                // starting at tap `start = 1` wraps the index back into range.
                let input_start = input_index + (i * metadata.Lin) - metadata.padding;
                // Only populate the taps between the provided indices (edge padding).
                for(var j = start; j <= end; j++) {
                    inp[j] = X[input_start + j];
                }
                let filter_start = i * metadata.KS;
                kernel.x = F[filter_start];
                kernel.y = F[filter_start + 1u];
                kernel.z = F[filter_start + 2u];
                acc = fma(inp, kernel, acc);
            }
            Y[output_index] = acc.x + acc.y + acc.z + B[bias_index];
        }
        // Each thread may load more than one element into shared memory.
        fn load_filters_into_smem(local_invocation_id: vec3<u32>, filter_index: u32) {
            let windex = filter_index + (local_invocation_id.x * metadata.Fperthread);
            let findex = (local_invocation_id.x * metadata.Fperthread);
            for(var i=0u; i < metadata.Fperthread; i++) {
                if findex + i < metadata.F_numel {
                    F[findex + i] = W[windex + i];
                }
            }
        }
    });
    let wgsx = workgroup_size.x.render();
    kernel_builder.write_main(wgsl!{
        let input_index = (workgroup_id.x * 'wgsx + local_invocation_id.x) * metadata.stride;
        let filter_index = (workgroup_id.y * metadata.F_numel);
        load_filters_into_smem(local_invocation_id, filter_index);
        workgroupBarrier();
        if input_index >= metadata.Lin {
            // Return only after loading, because every thread may be needed to populate F.
            return;
        }
        let output_index = (workgroup_id.x * 'wgsx + local_invocation_id.x) + (workgroup_id.y * metadata.Lout);
        let bias_index = workgroup_id.y;
        if input_index == metadata.Lin - metadata.padding {
            // Right edge: the third tap would fall past the end of the row.
            inner(input_index, filter_index, output_index, bias_index, 0u, 1u);
        } else if input_index == 0u {
            // Left edge: the first tap would fall before the start of the row.
            inner(input_index, filter_index, output_index, bias_index, 1u, 2u);
        } else {
            inner(input_index, filter_index, output_index, bias_index, 0u, 2u);
        }
    });
    Ok(kernel_builder.build()?)
}