in src/webgpu/shader/execution/expression/call/builtin/subgroup_util.ts [126:306]
/**
 * Runs one randomized accuracy case for a subgroup builtin applied to a
 * floating-point input.
 *
 * The input buffer holds `kStride` values of `type`. Every value equals
 * `identity` except at two distinct, randomly chosen indices, which hold two
 * values drawn from the sparse f16/f32 scalar range. A single workgroup of size
 * `wgSize` is dispatched. Each invocation does two things:
 * - It writes `${operation}(inputs[lid])` to `outputs[lid]`.
 * - It records, as its subgroup id, the `local_invocation_index` of subgroup
 *   invocation 0 (obtained via `subgroupBroadcast(lid, 0)`).
 *
 * The metadata and outputs are then read back and handed to `checkAccuracy`
 * together with the chosen indices/values, `identity` and `intervalGen`.
 *
 * The case is skipped when `wgSize` exceeds the device's compute workgroup
 * limits.
 *
 * @param t The GPU test fixture.
 * @param seed Case index; must be less than `kNumCases`. Seeds in the first
 *   quarter restrict the chosen indices to [0, 16), seeds in the second quarter
 *   to [0, 64), and the remaining seeds to [0, kStride).
 * @param wgSize Workgroup size as `[x, y, z]`.
 * @param operation Name of the WGSL function invoked as `operation(value)`.
 * @param type Element type of the input and output arrays.
 * @param identity Value used for every input element other than the two chosen
 *   ones (presumably the identity element of `operation` — confirm per caller).
 * @param intervalGen Interval combiner forwarded to `checkAccuracy`.
 */
export async function runAccuracyTest(
  t: GPUTest,
  seed: number,
  wgSize: number[],
  operation: string,
  type: 'f16' | 'f32',
  identity: number,
  intervalGen: (x: number | FPInterval, y: number | FPInterval) => FPInterval
) {
  assert(seed < kNumCases);
  const prng = new PRNG(seed);

  // Compatibility mode has lower workgroup limits.
  const wgThreads = wgSize[0] * wgSize[1] * wgSize[2];
  const {
    maxComputeInvocationsPerWorkgroup,
    maxComputeWorkgroupSizeX,
    maxComputeWorkgroupSizeY,
    maxComputeWorkgroupSizeZ,
  } = t.device.limits;
  t.skipIf(
    maxComputeInvocationsPerWorkgroup < wgThreads ||
      maxComputeWorkgroupSizeX < wgSize[0] ||
      maxComputeWorkgroupSizeY < wgSize[1] ||
      maxComputeWorkgroupSizeZ < wgSize[2],
    'Workgroup size too large'
  );

  // Bias half the cases to lower indices since most subgroup sizes are <= 64.
  let indexLimit = kStride;
  if (seed < kNumCases / 4) {
    indexLimit = 16;
  } else if (seed < kNumCases / 2) {
    indexLimit = 64;
  }

  // Pick two distinct indices. Assuming prng.uniformInt(n) yields [0, n), idx2
  // is drawn from one fewer slot and shifted past idx1. Every index other than
  // idx1 is then equally likely. (Bumping only on equality would double the
  // weight of idx1 + 1 and make indexLimit - 1 almost never selected.)
  const idx1 = prng.uniformInt(indexLimit);
  let idx2 = prng.uniformInt(indexLimit - 1);
  if (idx2 >= idx1) {
    idx2++;
  }
  assert(idx1 !== idx2 && idx2 < indexLimit);

  // Select two random values.
  const range = type === 'f16' ? sparseScalarF16Range() : sparseScalarF32Range();
  const numVals = range.length;
  const val1 = range[prng.uniformInt(numVals)];
  const val2 = range[prng.uniformInt(numVals)];

  const extraEnables = type === 'f16' ? `enable f16;` : ``;
  const wgsl = `
enable subgroups;
${extraEnables}

@group(0) @binding(0)
var<storage> inputs : array<${type}>;

@group(0) @binding(1)
var<storage, read_write> outputs : array<${type}>;

struct Metadata {
  subgroup_id : array<u32, ${kStride}>,
}

@group(0) @binding(2)
var<storage, read_write> metadata : Metadata;

@compute @workgroup_size(${wgSize[0]}, ${wgSize[1]}, ${wgSize[2]})
fn main(
  @builtin(local_invocation_index) lid : u32,
) {
  metadata.subgroup_id[lid] = subgroupBroadcast(lid, 0);
  outputs[lid] = ${operation}(inputs[lid]);
}`;

  // Every element is the identity except the two chosen indices.
  const inputValues = [
    ...iterRange(kStride, x => {
      if (x === idx1) return val1;
      if (x === idx2) return val2;
      return identity;
    }),
  ];
  const inputData =
    type === 'f16' ? new Float16Array(inputValues) : new Float32Array(inputValues);

  const inputBuffer = t.makeBufferWithContents(
    inputData,
    GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
  );
  t.trackForCleanup(inputBuffer);

  // Zero-initialised and sized as kStride f32 values. That is at least as large
  // as kStride f16 values, so the same allocation serves both element types.
  const outputBuffer = t.makeBufferWithContents(
    new Float32Array(kStride),
    GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
  );
  t.trackForCleanup(outputBuffer);

  const numMetadata = kStride;
  const metadataBuffer = t.makeBufferWithContents(
    new Uint32Array(numMetadata),
    GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
  );
  // Track like the other buffers so it is released with the test.
  t.trackForCleanup(metadataBuffer);

  const pipeline = t.device.createComputePipeline({
    layout: 'auto',
    compute: {
      module: t.device.createShaderModule({
        code: wgsl,
      }),
      entryPoint: 'main',
    },
  });
  const bg = t.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      {
        binding: 0,
        resource: {
          buffer: inputBuffer,
        },
      },
      {
        binding: 1,
        resource: {
          buffer: outputBuffer,
        },
      },
      {
        binding: 2,
        resource: {
          buffer: metadataBuffer,
        },
      },
    ],
  });

  const encoder = t.device.createCommandEncoder({ label: 'runAccuracyTest' });
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bg);
  pass.dispatchWorkgroups(1, 1, 1);
  pass.end();
  t.queue.submit([encoder.finish()]);

  const metadataReadback = await t.readGPUBufferRangeTyped(metadataBuffer, {
    srcByteOffset: 0,
    type: Uint32Array,
    typedLength: numMetadata,
    method: 'copy',
  });
  const metadata = metadataReadback.data;

  let output: Float16Array | Float32Array;
  if (type === 'f16') {
    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Float16Array,
      typedLength: kStride,
      method: 'copy',
    });
    output = outputReadback.data;
  } else {
    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Float32Array,
      typedLength: kStride,
      method: 'copy',
    });
    output = outputReadback.data;
  }

  t.expectOK(checkAccuracy(metadata, output, [idx1, idx2], [val1, val2], identity, intervalGen));
}