export async function runAccuracyTest()

in src/webgpu/shader/execution/expression/call/builtin/subgroup_util.ts [126:306]


/**
 * Runs one pseudo-randomly generated accuracy case for a subgroup operation on
 * floating-point data.
 *
 * A single workgroup is dispatched. Every invocation applies `operation` to its
 * own element of an input array and stores the result, together with the value
 * of `subgroupBroadcast(lid, 0)`, so results can be attributed to subgroups.
 * The input array holds `identity` everywhere except at two distinct,
 * randomly chosen indices, which hold two randomly chosen values from the
 * sparse scalar range of `type`. The read-back data is handed to
 * `checkAccuracy` and its result is reported through `t.expectOK`.
 *
 * The case is skipped when `wgSize` exceeds the device's compute workgroup limits.
 *
 * @param t The test fixture providing the device, queue and buffer helpers.
 * @param seed Seed for the PRNG; must be less than `kNumCases`. Lower seeds
 *   restrict the two chosen indices to a smaller prefix of the input array.
 * @param wgSize Workgroup dimensions `[x, y, z]` used for `@workgroup_size`.
 * @param operation Text spliced into the shader as `${operation}(inputs[lid])`.
 * @param type Element type of the input and output arrays in the shader.
 * @param identity Value stored in every input slot other than the two chosen
 *   ones (presumably the identity element of `operation` - confirm with callers).
 * @param intervalGen Forwarded unchanged to `checkAccuracy` (presumably yields
 *   the acceptance interval for combining two operands - confirm there).
 * @throws If `seed >= kNumCases` (via `assert`).
 */
export async function runAccuracyTest(
  t: GPUTest,
  seed: number,
  wgSize: number[],
  operation: string,
  type: 'f16' | 'f32',
  identity: number,
  intervalGen: (x: number | FPInterval, y: number | FPInterval) => FPInterval
): Promise<void> {
  assert(seed < kNumCases);
  const prng = new PRNG(seed);

  // Compatibility mode has lower workgroup limits.
  const wgThreads = wgSize[0] * wgSize[1] * wgSize[2];
  const {
    maxComputeInvocationsPerWorkgroup,
    maxComputeWorkgroupSizeX,
    maxComputeWorkgroupSizeY,
    maxComputeWorkgroupSizeZ,
  } = t.device.limits;
  t.skipIf(
    maxComputeInvocationsPerWorkgroup < wgThreads ||
      maxComputeWorkgroupSizeX < wgSize[0] ||
      maxComputeWorkgroupSizeY < wgSize[1] ||
      maxComputeWorkgroupSizeZ < wgSize[2],
    'Workgroup size too large'
  );

  // Bias half the cases to lower indices since most subgroup sizes are <= 64.
  let indexLimit = kStride;
  if (seed < kNumCases / 4) {
    indexLimit = 16;
  } else if (seed < kNumCases / 2) {
    indexLimit = 64;
  }

  // Ensure two distinct indices are picked. The second draw uses a limit one
  // smaller than the first so that bumping it on a collision stays in range.
  const idx1 = prng.uniformInt(indexLimit);
  let idx2 = prng.uniformInt(indexLimit - 1);
  if (idx1 === idx2) {
    idx2++;
  }
  assert(idx2 < indexLimit);

  // Select two random values (drawn after the indices; the order of PRNG draws
  // determines the generated case for a given seed).
  const range = type === 'f16' ? sparseScalarF16Range() : sparseScalarF32Range();
  const numVals = range.length;
  const val1 = range[prng.uniformInt(numVals)];
  const val2 = range[prng.uniformInt(numVals)];

  const extraEnables = type === 'f16' ? `enable f16;` : ``;
  const wgsl = `
enable subgroups;
${extraEnables}

@group(0) @binding(0)
var<storage> inputs : array<${type}>;

@group(0) @binding(1)
var<storage, read_write> outputs : array<${type}>;

struct Metadata {
  subgroup_id : array<u32, ${kStride}>,
}

@group(0) @binding(2)
var<storage, read_write> metadata : Metadata;

@compute @workgroup_size(${wgSize[0]}, ${wgSize[1]}, ${wgSize[2]})
fn main(
  @builtin(local_invocation_index) lid : u32,
) {
  metadata.subgroup_id[lid] = subgroupBroadcast(lid, 0);
  outputs[lid] = ${operation}(inputs[lid]);
}`;

  // Identity everywhere except the two chosen slots; shared by both element types.
  const inputValues = [
    ...iterRange(kStride, x => {
      if (x === idx1) return val1;
      if (x === idx2) return val2;
      return identity;
    }),
  ];
  const inputData =
    type === 'f16' ? new Float16Array(inputValues) : new Float32Array(inputValues);

  const bufferUsage =
    GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE;

  const inputBuffer = t.makeBufferWithContents(inputData, bufferUsage);
  t.trackForCleanup(inputBuffer);

  // The zero fill is always written as f32, so for f16 the buffer is larger
  // than the kStride f16 values read back below.
  const outputBuffer = t.makeBufferWithContents(
    new Float32Array([...iterRange(kStride, () => 0)]),
    bufferUsage
  );
  t.trackForCleanup(outputBuffer);

  const numMetadata = kStride;
  const metadataBuffer = t.makeBufferWithContents(
    new Uint32Array([...iterRange(numMetadata, () => 0)]),
    bufferUsage
  );
  // Track the metadata buffer like the input and output buffers above.
  t.trackForCleanup(metadataBuffer);

  const pipeline = t.device.createComputePipeline({
    layout: 'auto',
    compute: {
      module: t.device.createShaderModule({
        code: wgsl,
      }),
      entryPoint: 'main',
    },
  });
  const bg = t.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      {
        binding: 0,
        resource: {
          buffer: inputBuffer,
        },
      },
      {
        binding: 1,
        resource: {
          buffer: outputBuffer,
        },
      },
      {
        binding: 2,
        resource: {
          buffer: metadataBuffer,
        },
      },
    ],
  });

  // A single workgroup is dispatched.
  const encoder = t.device.createCommandEncoder({ label: 'runAccuracyTest' });
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bg);
  pass.dispatchWorkgroups(1, 1, 1);
  pass.end();
  t.queue.submit([encoder.finish()]);

  const metadataReadback = await t.readGPUBufferRangeTyped(metadataBuffer, {
    srcByteOffset: 0,
    type: Uint32Array,
    typedLength: numMetadata,
    method: 'copy',
  });
  const metadata = metadataReadback.data;

  // Read the results back with the array type matching the shader element type.
  let output: Float16Array | Float32Array;
  if (type === 'f16') {
    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Float16Array,
      typedLength: kStride,
      method: 'copy',
    });
    output = outputReadback.data;
  } else {
    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Float32Array,
      typedLength: kStride,
      method: 'copy',
    });
    output = outputReadback.data;
  }

  t.expectOK(checkAccuracy(metadata, output, [idx1, idx2], [val1, val2], identity, intervalGen));
}