src/webgpu/shader/execution/expression/call/builtin/dot.cache.ts (99 lines of code) (raw):
import { ROArrayArray } from '../../../../../../common/util/types.js';
import { assert } from '../../../../../../common/util/util.js';
import { kValue } from '../../../../../util/constants.js';
import { FP } from '../../../../../util/floating_point.js';
import {
calculatePermutations,
sparseVectorI32Range,
sparseVectorI64Range,
sparseVectorU32Range,
vectorI32Range,
vectorI64Range,
vectorU32Range,
} from '../../../../../util/math.js';
import {
generateVectorVectorToI32Cases,
generateVectorVectorToI64Cases,
generateVectorVectorToU32Cases,
} from '../../case.js';
import { makeCaseCache } from '../../case_cache.js';
function ai_dot(x: bigint[], y: bigint[]): bigint | undefined {
assert(x.length === y.length, 'Cannot calculate dot for vectors of different lengths');
const multiplications = x.map((_, idx) => x[idx] * y[idx]);
if (multiplications.some(kValue.i64.isOOB)) return undefined;
const result = multiplications.reduce((prev, curr) => prev + curr);
if (kValue.i64.isOOB(result)) return undefined;
// The spec does not state the ordering of summation, so all the
// permutations are calculated and the intermediate results checked for
// going OOB. vec2 does not need permutations, since a + b === b + a.
// All the end results should be the same regardless of the order if the
// intermediate additions stay inbounds.
if (x.length !== 2) {
let wentOOB: boolean = false;
const permutations: ROArrayArray<bigint> = calculatePermutations(multiplications);
permutations.forEach(p => {
if (!wentOOB) {
p.reduce((prev, curr) => {
const next = prev + curr;
if (kValue.i64.isOOB(next)) {
wentOOB = true;
}
return next;
});
}
});
if (wentOOB) return undefined;
}
return !kValue.i64.isOOB(result) ? result : undefined;
}
function ci_dot(x: number[], y: number[]): number | undefined {
assert(x.length === y.length, 'Cannot calculate dot for vectors of different lengths');
return x.reduce((prev, _, idx) => prev + Math.imul(x[idx], y[idx]), 0);
}
// Cases: [f32|f16|abstract]_vecN_[non_]const
const float_cases = (['f32', 'f16', 'abstract'] as const)
.flatMap(trait =>
([2, 3, 4] as const).flatMap(N =>
([true, false] as const).map(nonConst => ({
[`${trait === 'abstract' ? 'abstract_float' : trait}_vec${N}_${
nonConst ? 'non_const' : 'const'
}`]: () => {
// Emit an empty array for not const abstract float, since they will never be run
if (trait === 'abstract' && nonConst) {
return [];
}
// vec3 and vec4 require calculating all possible permutations, so their runtime is much
// longer per test, so only using sparse vectors for them.
return FP[trait].generateVectorPairToIntervalCases(
N === 2 ? FP[trait].vectorRange(2) : FP[trait].sparseVectorRange(N),
N === 2 ? FP[trait].vectorRange(2) : FP[trait].sparseVectorRange(N),
nonConst ? 'unfiltered' : 'finite',
// dot has an inherited accuracy, so abstract is only expected to be as accurate as f32
FP[trait !== 'abstract' ? trait : 'f32'].dotInterval
);
},
}))
)
)
.reduce((a, b) => ({ ...a, ...b }), {});
const cases = {
...float_cases,
abstract_int_vec2: () => {
return generateVectorVectorToI64Cases(vectorI64Range(2), vectorI64Range(2), ai_dot);
},
abstract_int_vec3: () => {
return generateVectorVectorToI64Cases(sparseVectorI64Range(3), sparseVectorI64Range(3), ai_dot);
},
abstract_int_vec4: () => {
return generateVectorVectorToI64Cases(sparseVectorI64Range(4), sparseVectorI64Range(4), ai_dot);
},
i32_vec2: () => {
return generateVectorVectorToI32Cases(vectorI32Range(2), vectorI32Range(2), ci_dot);
},
i32_vec3: () => {
return generateVectorVectorToI32Cases(sparseVectorI32Range(3), sparseVectorI32Range(3), ci_dot);
},
i32_vec4: () => {
return generateVectorVectorToI32Cases(sparseVectorI32Range(4), sparseVectorI32Range(4), ci_dot);
},
u32_vec2: () => {
return generateVectorVectorToU32Cases(vectorU32Range(2), vectorU32Range(2), ci_dot);
},
u32_vec3: () => {
return generateVectorVectorToU32Cases(sparseVectorU32Range(3), sparseVectorU32Range(3), ci_dot);
},
u32_vec4: () => {
return generateVectorVectorToU32Cases(sparseVectorU32Range(4), sparseVectorU32Range(4), ci_dot);
},
};
export const d = makeCaseCache('dot', cases);