in crates/ratchet-core/src/ops/matmul/gemm.rs [361:460]
/// Emits the three WGSL helper functions used by the matmul kernel into
/// `builder`: `mm_readA`, `mm_readB` and `mm_write`.
///
/// `fits` is read as `(fit_a_outer, fit_inner, fit_b_outer)`. When both flags
/// relevant to a helper are `true`, the helper is emitted without a bounds
/// check; otherwise the access is wrapped in an `if` that compares `row`/`col`
/// against the shape values stored in `metadata`.
/// NOTE(review): presumably each flag says whether that dimension is an exact
/// multiple of the tile size — confirm against the caller.
///
/// `P` supplies the WGSL element type (via `P::render_type()`) used for the
/// return type of `mm_readB` and the value parameter of `mm_write`. `mm_readA`
/// uses the same type unless the LHS dtype is `Q8_0F` / `Q8_0H`, in which case
/// it uses the rendered `Vec4<f32>` / `Vec4<f16>` type instead.
///
/// This function currently always returns `Ok(())`.
fn write_readers_and_writers<P: WgslPrimitive>(
    &self,
    builder: &mut WgslKernelBuilder,
    fits: (bool, bool, bool),
) -> Result<(), OperationError> {
    let (fit_a_outer, fit_inner, fit_b_outer) = fits;
    let elem_ty = P::render_type();

    // ---- mm_readA -------------------------------------------------------
    // A transposed LHS swaps the row/col arguments passed to `getA`.
    let load_a = if self.trans_lhs {
        wgsl! { value = getA(batch, col, row); }
    } else {
        wgsl! { value = getA(batch, row, col); }
    };
    let a_unguarded = fit_a_outer && fit_inner;
    let read_a_body = match (a_unguarded, self.trans_lhs) {
        // No bounds check: emit the bare load.
        (true, _) => load_a,
        // Transposed: row is checked against `.z`, col against `.y`.
        (false, true) => wgsl! {
            if (row < metadata.lhs_shape.z && col < metadata.lhs_shape.y) {
                'load_a
            }
        },
        (false, false) => wgsl! {
            if (row < metadata.lhs_shape.y && col < metadata.lhs_shape.z) {
                'load_a
            }
        },
    };
    // Quantized LHS dtypes read through a vec4 type rather than `elem_ty`.
    let a_ty = match self.lhs.dt() {
        DType::Q8_0F(_) => Vec4::<f32>::render_type(),
        DType::Q8_0H(_) => Vec4::<f16>::render_type(),
        _ => elem_ty.clone(),
    };
    builder.write_global(wgsl! {
        fn mm_readA(batch: i32, row: i32, col: i32) -> 'a_ty {
            var value = 'a_ty(0.0);
            'read_a_body
            return value;
        }
    });

    // ---- mm_readB -------------------------------------------------------
    // Same scheme as for A, using the RHS transpose flag and RHS shape.
    let load_b = if self.trans_rhs {
        wgsl! { value = getB(batch, col, row); }
    } else {
        wgsl! { value = getB(batch, row, col); }
    };
    let b_unguarded = fit_inner && fit_b_outer;
    let read_b_body = match (b_unguarded, self.trans_rhs) {
        (true, _) => load_b,
        (false, true) => wgsl! {
            if (row < metadata.rhs_shape.z && col < metadata.rhs_shape.y) {
                'load_b
            }
        },
        (false, false) => wgsl! {
            if (row < metadata.rhs_shape.y && col < metadata.rhs_shape.z) {
                'load_b
            }
        },
    };
    builder.write_global(wgsl! {
        fn mm_readB(batch: i32, row: i32, col: i32) -> 'elem_ty {
            var value = 'elem_ty(0.0);
            'read_b_body
            return value;
        }
    });

    // ---- mm_write -------------------------------------------------------
    // The store is guarded unless both outer dimensions are flagged as fitting.
    let store_needs_guard = !(fit_a_outer && fit_b_outer);
    let write_body = if store_needs_guard {
        wgsl! {
            if (row < metadata.dim_lhs_outer && col < metadata.dim_rhs_outer) {
                var value = valueIn;
                let coords = vec3<i32>(batch, row, col);
                setOutputAtCoords(coords[0], coords[1], coords[2], value);
            }
        }
    } else {
        wgsl! {
            var value = valueIn;
            let coords = vec3<i32>(batch, row, col);
            setOutputAtCoords(coords[0], coords[1], coords[2], value);
        }
    };
    builder.write_global(wgsl! {
        fn mm_write(batch: i32, row: i32, col: i32, valueIn: 'elem_ty) {
            'write_body
        }
    });

    Ok(())
}