in tfjs-backend-webgpu/src/shader_preprocessor.ts [74:647]
/**
 * Returns the opening of the WGSL `main` entry point.
 *
 * The text starts with the workgroup-size declaration produced by
 * `getWorkGroupSizeString()`, then opens `main` and copies the three builtin
 * arguments into `localId`, `globalId` and `numWorkgroups`. The function body
 * is left open: the caller appends the remaining statements and the closing
 * brace.
 */
export function getMainHeaderString(): string {
  const headerLines = [
    '',
    `${getWorkGroupSizeString()}`,
    'fn main(@builtin(local_invocation_id) LocalId : vec3<u32>,',
    '@builtin(global_invocation_id) GlobalId : vec3<u32>,',
    '@builtin(num_workgroups) NumWorkgroups: vec3<u32>) {',
    'localId = LocalId;',
    'globalId = GlobalId;',
    'numWorkgroups = NumWorkgroups;',
    '',
  ];
  return headerLines.join('\n');
}
/**
 * Returns the `main` header from `getMainHeaderString()` followed by a
 * statement that binds `index` to the result of `getGlobalIndex()`.
 */
export function getMainHeaderAndGlobalIndexString(): string {
  const mainHeader = getMainHeaderString();
  return '\n' + mainHeader + '\nlet index = getGlobalIndex();\n';
}
/**
 * Assembles the complete WGSL source for a program.
 *
 * The generated source is a newline-joined sequence of:
 *   1. `commonSnippet` (shared helper functions),
 *   2. the prelude: workgroup-size constants, the private copies of the
 *      invocation builtins, `getGlobalIndex()`, the uniform struct and the
 *      storage-buffer declarations,
 *   3. the index/coordinate helpers for the output shape,
 *   4. output setters (skipped for atomic programs) and input getters
 *      (skipped when the dispatch layout does not cover the full output rank),
 *   5. the program's own code from `program.getUserCode()`.
 *
 * @param inputInfo One entry per name in `program.variableNames`, in the same
 *     order; supplies the dtype and shape of each input buffer.
 * @param outputData dtype and shape of the output buffer.
 * @param program The program being compiled.
 * @param isFromPixel When true, a reduced shader is produced: only the output
 *     buffer (binding 0) and a fixed `Uniform` struct (binding 2) are
 *     declared, and no input accessors or output setters are generated.
 * @returns The WGSL source text.
 */
export function makeShader(
    inputInfo: InputInfo[], outputData: {dtype: DataType, shape: number[]},
    program: ProgramParams, isFromPixel = false): string {
  const prefixSnippets: string[] = [];
  prefixSnippets.push(`
let workGroupSizeX = ${program.workGroupSize[0]}u;
let workGroupSizeY = ${program.workGroupSize[1]}u;
let workGroupSizeZ = ${program.workGroupSize[2]}u;
var<private> localId: vec3<u32>;
var<private> globalId: vec3<u32>;
var<private> numWorkgroups: vec3<u32>;
// Only used when the y/z dimension of workgroup size is 1.
fn getGlobalIndex() -> i32 {
if (numWorkgroups.y == 1u && numWorkgroups.z == 1u) {
return i32(globalId.x);
}
let localInvocationIndex = localId.z * workGroupSizeX * workGroupSizeY +
localId.y * workGroupSizeX + localId.x;
let workGroupID = (globalId - localId)/vec3<u32>(
workGroupSizeX, workGroupSizeY, workGroupSizeZ);
return i32((workGroupID.z * numWorkgroups.x * numWorkgroups.y +
workGroupID.y * numWorkgroups.x + workGroupID.x) *
(workGroupSizeX * workGroupSizeY * workGroupSizeZ) +
localInvocationIndex);
}
`);

  if (isFromPixel) {
    // Reduced layout: output buffer plus a fixed uniform struct only.
    prefixSnippets.push(`
struct Matrix0 {
numbers: array<${mapToWgslTypes(outputData.dtype, program.isVec4)}>;
};
struct Uniform {
size : i32;
numChannels : i32;
outShapeStrides : vec2<i32>;
dispatchSize : vec3<u32>;
};
@group(0) @binding(0) var<storage, write> result : Matrix0;
@group(0) @binding(2) var<uniform> uniforms: Uniform;
`);
    return [
      commonSnippet,
      prefixSnippets.join('\n'),
      getCoordsFromIndexSnippet(outputData.shape),
      program.getUserCode(),
    ].join('\n');
  }

  // Uniform struct: NAN, one `<name>Shape` field per input, the output shape
  // and strides, then the optional `size` field and program-specific fields.
  const uniformParts: string[] = ['struct Uniforms { NAN : f32; '];
  program.variableNames.forEach((x, i) => {
    uniformParts.push(`${x.charAt(0).toLowerCase() + x.slice(1)}Shape : ${
        getCoordsDataType(inputInfo[i].shape.length)}; `);
  });
  uniformParts.push(
      `outShape : ${getCoordsDataType(outputData.shape.length)} ; `);
  // The strides field has one component fewer than the output shape.
  const stridesLength = outputData.shape.length - 1;
  uniformParts.push(`
outShapeStrides: ${getCoordsDataType(stridesLength)}; `);
  if (program.size) {
    uniformParts.push('size : i32; ');
  }
  if (program.uniforms) {
    uniformParts.push(program.uniforms);
  }
  uniformParts.push('};');
  prefixSnippets.push(uniformParts.join(''));

  // Output buffer (binding 0). Atomic programs get an atomic<i32> array that
  // is bound read_write; all others get a write-only typed array.
  if (program.atomic) {
    prefixSnippets.push(`
struct Matrix0 {
numbers: array<atomic<i32>>;
};
@group(0) @binding(0) var<storage, read_write> result : Matrix0;
`);
  } else {
    prefixSnippets.push(`
struct Matrix0 {
numbers: array<${mapToWgslTypes(outputData.dtype, program.isVec4)}>;
};
@group(0) @binding(0) var<storage, write> result : Matrix0;
`);
  }

  // Input buffers occupy bindings 1..N in `variableNames` order.
  program.variableNames.forEach((x, i) => {
    prefixSnippets.push(`
struct Matrix${1 + i} {
numbers: array<${mapToWgslTypes(inputInfo[i].dtype, program.isVec4)}>;
};
@group(0) @binding(${1 + i}) var<storage, read> ${x} : Matrix${1 + i};
`);
  });

  // The uniform struct is never empty (it always carries NAN, outShape and
  // outShapeStrides), so its binding is emitted unconditionally, right after
  // the last input binding.
  prefixSnippets.push(`
@group(0) @binding(${
      1 + program.variableNames.length}) var<uniform> uniforms : Uniforms;
`);

  const [coordsSnippet, dispatchLayoutRank] =
      getOutputCoordsSnippet(outputData.shape, program.dispatchLayout);

  const sources = [
    commonSnippet,
    prefixSnippets.join('\n'),
    getCoordsFromIndexSnippet(outputData.shape),
    coordsSnippet,
    getOutputIndexFromCoordsSnippet(outputData.shape.length)
  ];
  if (!program.atomic) {
    sources.push(setOutputSnippet(
        outputData.shape, outputData.dtype, program.isVec4));
  }

  if (dispatchLayoutRank === outputData.shape.length) {
    // Input snippet is only meaningful when the output isn't getting
    // implicitly reshaped (like it does in conv2d_matmul).
    const isFlatDispatchLayout =
        program.dispatchLayout.x.length === outputData.shape.length;
    const inputSnippet =
        inputInfo
            .map(
                x => getInputSnippet(
                    x, outputData.shape, program.isVec4, isFlatDispatchLayout))
            .join('\n');
    sources.push(inputSnippet);
  }

  sources.push(program.getUserCode());
  return sources.join('\n');
}
/**
 * WGSL helper functions placed at the very start of every shader that
 * `makeShader` generates:
 *   - `coordsInBounds2D/3D/4D`: true when every component of `coord` is >= 0
 *     and < the matching component of `shape`.
 *   - `getIndexFromCoords1D/2D/3D/4D`: flattens coordinates into a linear
 *     index as a dot product with strides derived from `shape`.
 *   - `idiv`: integer division whose quotient is decremented by one when
 *     `sign` is negative and the remainder is non-zero.
 *   - `isNanCustom` / `isNanCustomVec4`: report a value as NaN when it is
 *     neither greater than, less than, nor equal to zero.
 *
 * The template contents are emitted verbatim into the shader, so nothing
 * inside the literal may be edited without changing the generated source.
 */
const commonSnippet = `
// Checks whether coordinates lie within the bounds of the shape.
fn coordsInBounds2D(coord : vec2<i32>, shape : vec2<i32>) -> bool {
return all(coord >= vec2<i32>(0)) && all(coord < shape);
}
fn coordsInBounds3D(coord : vec3<i32>, shape : vec3<i32>) -> bool {
return all(coord >= vec3<i32>(0)) && all(coord < shape);
}
fn coordsInBounds4D(coord : vec4<i32>, shape : vec4<i32>) -> bool {
return all(coord >= vec4<i32>(0)) && all(coord < shape);
}
fn getIndexFromCoords1D(coord : i32, shape : i32) -> i32 {
return coord;
}
fn getIndexFromCoords2D(coords : vec2<i32>, shape : vec2<i32>) -> i32 {
return dot(coords, vec2<i32>(shape.y, 1));
}
fn getIndexFromCoords3D(coords : vec3<i32>, shape : vec3<i32>) -> i32 {
return dot(coords, vec3<i32>(shape.y * shape.z, shape.z, 1));
}
fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
return dot(coords, vec4<i32>(
shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
}
fn idiv(a: i32, b: i32, sign: f32) -> i32 {
var res: i32 = a / b;
let mod: i32 = a % b;
if (sign < 0. && mod != 0) {
res = res - 1;
}
return res;
}
fn isNanCustom(val : f32) -> bool {
if (val > 0.0) {
return false;
}
if (val < 0.0) {
return false;
}
if (val == 0.0) {
return false;
}
return true;
}
fn isNanCustomVec4(val : vec4<f32>) -> vec4<bool> {
return vec4<bool>(isNanCustom(val[0]), isNanCustom(val[1]), isNanCustom(val[2]), isNanCustom(val[3]));
}
`;
/**
 * Produces the WGSL function `getOutputIndexFromCoords`, which flattens
 * output coordinates into a linear index using `uniforms.outShapeStrides`.
 *
 * Ranks 0 and 1 share an identity implementation; ranks 2-4 use a dot product
 * with the strides. Any other rank triggers `util.assert` and, if that
 * returns, yields an empty string.
 */
function getOutputIndexFromCoordsSnippet(outRank: number) {
  if (outRank === 0 || outRank === 1) {
    return `
fn getOutputIndexFromCoords(coords : i32) -> i32 {
return coords;
}
`;
  }
  if (outRank === 2) {
    return `
fn getOutputIndexFromCoords(coords : vec2<i32>) -> i32 {
return dot(coords, vec2<i32>(uniforms.outShapeStrides, 1));
}
`;
  }
  if (outRank === 3) {
    return `
fn getOutputIndexFromCoords(coords : vec3<i32>) -> i32 {
return dot(coords, vec3<i32>(uniforms.outShapeStrides.x, uniforms.outShapeStrides.y, 1));
}
`;
  }
  if (outRank === 4) {
    return `
fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 {
return dot(coords, vec4<i32>(
uniforms.outShapeStrides.x, uniforms.outShapeStrides.y, uniforms.outShapeStrides.z, 1));
}
`;
  }
  util.assert(false, () => `Unsupported ${outRank}D shape`);
  return '';
}
/**
 * Produces the WGSL functions that write to the output buffer.
 *
 * `setOutputAtIndex` / `setOutputAtIndexI32` are always generated and store a
 * value at a flat buffer index after converting it to the buffer's element
 * type. For output rank >= 2, `setOutputAtCoords` / `setOutputAtCoordsI32`
 * are also generated; they flatten the coordinates via
 * `getOutputIndexFromCoords` and, in vec4 mode, divide the flat index by 4
 * before storing.
 *
 * @param outShape Output shape; only its rank is used.
 * @param outBufferType dtype of the output buffer.
 * @param isVec4 Whether values are written as 4-component vectors.
 */
function setOutputSnippet(
    outShape: number[], outBufferType: DataType, isVec4: boolean): string {
  const outRank = outShape.length;
  const wgslType = mapToWgslTypes(outBufferType, isVec4);
  // The vec4 and scalar variants differ only in the value types and in how
  // the flat index maps to a buffer slot.
  const floatValueType = isVec4 ? 'vec4<f32>' : 'f32';
  const intValueType = isVec4 ? 'vec4<i32>' : 'i32';
  const bufferIndex = isVec4 ? 'flatIndex / 4' : 'flatIndex';

  const indexSetters =
      `fn setOutputAtIndex(flatIndex : i32, value : ${floatValueType}) {
result.numbers[flatIndex] = ${wgslType}(value);
}
fn setOutputAtIndexI32(flatIndex : i32, value : ${intValueType}) {
result.numbers[flatIndex] = ${wgslType}(value);
}`;
  if (outRank < 2) {
    return indexSetters;
  }

  const dims = ['d0', 'd1', 'd2', 'd3'].slice(0, outRank);
  const coordsType = getCoordsDataType(outRank);
  const params = dims.map(d => `${d} : i32`).join(', ');
  const coordsCtor = `${coordsType}(${dims.join(', ')})`;
  const coordSetters = `
fn setOutputAtCoords(${params}, value : ${floatValueType}) {
let flatIndex = getOutputIndexFromCoords(${coordsCtor});
setOutputAtIndex(${bufferIndex}, value);
}
fn setOutputAtCoordsI32(${params}, value : ${intValueType}) {
let flatIndex = getOutputIndexFromCoords(${coordsCtor});
setOutputAtIndexI32(${bufferIndex}, value);
}
`;
  return indexSetters + coordSetters;
}
/**
 * Produces all WGSL accessors for one input: the by-coordinates getter
 * always, and the by-output getters only when the input's rank does not
 * exceed the output's rank.
 */
function getInputSnippet(
    inputInfo: InputInfo, outShape: number[], isVec4: boolean,
    isFlatDispatchLayout: boolean): string {
  const atCoords = getInputAtCoordsSnippet(inputInfo, isVec4);
  if (inputInfo.shape.length > outShape.length) {
    return atCoords;
  }
  const byOutput = getInputByOutputSnippet(
      inputInfo, outShape, isVec4, isFlatDispatchLayout);
  return atCoords + byOutput;
}
/**
 * Produces the WGSL accessor `get<Name>(...)` that reads one element of an
 * input buffer, where `<Name>` is the input's name with its first letter
 * upper-cased.
 *
 * For a rank-0 input the accessor takes no arguments and reads element 0.
 * Otherwise it takes one `i32` per dimension (`d0`..), flattens them with
 * `getIndexFromCoords<rank>D` against `uniforms.<name>Shape`, and reads that
 * element; in vec4 mode the flat index is divided by 4. The value is
 * converted to `f32` (or `vec4<f32>`) before being returned.
 *
 * @param inputInfo Name and shape of the input buffer.
 * @param isVec4 Whether the buffer is accessed as 4-component vectors.
 */
function getInputAtCoordsSnippet(
    inputInfo: InputInfo, isVec4: boolean): string {
  const texName = inputInfo.name;
  const rank = inputInfo.shape.length;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);

  // Scalars need no coordinates: always read the first element.
  if (rank < 1) {
    if (isVec4) {
      return `
fn ${funcName}() -> vec4<f32> {
return vec4<f32>(${texName}.numbers[0]);
}
`;
    }
    return `
fn ${funcName}() ->f32 {
return f32(${texName}.numbers[0]);
}
`;
  }

  // From here on rank >= 1, so the helper suffix is always `${rank}D`.
  const type = getCoordsDataType(rank);
  const dims = ['d0', 'd1', 'd2', 'd3'].slice(0, rank);
  const inputs = dims.map(d => `${d} : i32`).join(', ');
  const shapeStr =
      `uniforms.${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape`;
  const rankStr = `${rank}D`;

  if (isVec4) {
    return `
fn ${funcName}(${inputs}) -> vec4<f32> {
return vec4<f32>(${texName}.numbers[getIndexFromCoords${rankStr}(${type}(${
        dims.join(',')}),
${shapeStr}) / 4]);
}
`;
  }
  return `
fn ${funcName}(${inputs}) -> f32 {
return f32(${texName}.numbers[getIndexFromCoords${rankStr}(${type}(${
      dims.join(',')}),
${shapeStr})]);
}
`;
}
/**
 * Produces the WGSL accessors `get<Name>ByOutputIndex(globalIndex)` and
 * `get<Name>ByOutputCoords(coords)`, which read the input element that
 * corresponds to a given output position.
 *
 * Three cases are handled:
 *   1. The input shape equals the output shape and the dispatch layout is
 *      flat: the output index is used directly as the input index.
 *   2. The input is rank 0: both accessors forward to `get<Name>()`.
 *   3. Otherwise: output coordinates are mapped onto the input by zeroing the
 *      broadcast dimensions and dropping the leading `outRank - inRank`
 *      dimensions, then flattened with `getIndexFromCoords<inRank>D`.
 *
 * Every accessor converts the loaded value to `f32` / `vec4<f32>` so that the
 * returned expression matches the declared return type whatever the buffer's
 * element type is.
 *
 * @param inputInfo Name and shape of the input buffer.
 * @param outShape Shape of the output.
 * @param isVec4 Whether buffers are accessed as 4-component vectors.
 * @param isFlatDispatchLayout Whether all output dimensions are dispatched
 *     along x.
 */
export function getInputByOutputSnippet(
    inputInfo: InputInfo, outShape: number[], isVec4: boolean,
    isFlatDispatchLayout: boolean): string {
  const texName = inputInfo.name;
  const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
  const funcName = 'get' + texFuncSnippet + 'ByOutput';
  const inRank = inputInfo.shape.length;
  const outRank = outShape.length;
  const type = getCoordsDataType(outRank);

  // If the inShape equals the outShape and the dispatch layout is flat, the
  // global index can be used directly as the buffer index and no coords
  // conversion between the two shapes is needed.
  if (util.arraysEqual(inputInfo.shape, outShape) && isFlatDispatchLayout) {
    if (isVec4) {
      return `
fn ${funcName}Index(globalIndex : i32) -> vec4<f32> {
return vec4<f32>(${texName}.numbers[globalIndex]);
}
fn ${funcName}Coords(coords : ${type}) -> vec4<f32> {
return vec4<f32>(${texName}.numbers[${
          outRank > 1 ? 'getOutputIndexFromCoords(coords)' : 'coords'} / 4]);
}
`;
    }
    return `
fn ${funcName}Index(globalIndex : i32) -> f32 {
return f32(${texName}.numbers[globalIndex]);
}
fn ${funcName}Coords(coords : ${type}) -> f32 {
return f32(${texName}.numbers[${
        outRank > 1 ? 'getOutputIndexFromCoords(coords)' : 'coords'}]);
}
`;
  }

  // A scalar input has a single element regardless of the output position.
  if (inRank === 0) {
    if (isVec4) {
      return `
fn ${funcName}Index(globalIndex : i32) -> vec4<f32> {
return get${texFuncSnippet}();
}
fn ${funcName}Coords(coords : ${type}) -> vec4<f32> {
return get${texFuncSnippet}();
}
`;
    }
    return `
fn ${funcName}Index(globalIndex : i32) -> f32{
return get${texFuncSnippet}();
}
fn ${funcName}Coords(coords : ${type}) -> f32{
return get${texFuncSnippet}();
}
`;
  }

  // From here on inRank >= 1.
  const broadcastDims =
      backend_util.getBroadcastDims(inputInfo.shape, outShape);
  const rankDiff = outRank - inRank;

  // Zero every output coordinate that the input broadcasts over. With a
  // scalar `coords` (outRank < 2) the whole value is reset instead of a
  // component.
  const coordsSnippet = (outRank < 2 && broadcastDims.length >= 1) ?
      'coords = 0;' :
      broadcastDims.map(d => `coords[${d + rankDiff}] = 0;`).join('\n');

  // Scalar `coords` are passed through; vector coords are narrowed to the
  // trailing `inRank` components.
  let unpackedCoordsSnippet = 'coords';
  if (outRank >= 2) {
    const coordsType = getCoordsDataType(inRank);
    const coordsValues =
        inputInfo.shape.map((_, i) => `coords[${i + rankDiff}]`).join(', ');
    unpackedCoordsSnippet = `${coordsType}(${coordsValues})`;
  }

  const shapeStr =
      `uniforms.${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape`;
  const rankStr = `${inRank}D`;

  if (isVec4) {
    // The explicit vec4<f32>(...) conversion mirrors the other accessors and
    // keeps the return expression valid for non-f32 vec4 buffers.
    return `
fn ${funcName}Index(globalIndex : i32) -> vec4<f32> {
var coords = getCoordsFromIndex(globalIndex);
${coordsSnippet}
return vec4<f32>(${texName}.numbers[getIndexFromCoords${rankStr}(${
        unpackedCoordsSnippet}, ${shapeStr}) / 4]);
}
fn ${funcName}Coords(coordsIn : ${type}) -> vec4<f32> {
var coords = coordsIn;
${coordsSnippet}
return vec4<f32>(${texName}.numbers[getIndexFromCoords${rankStr}(${
        unpackedCoordsSnippet}, ${shapeStr}) / 4]);
}
`;
  }

  return `
fn ${funcName}Index(globalIndex : i32) -> f32 {
var coords = getCoordsFromIndex(globalIndex);
${coordsSnippet}
return f32(${texName}.numbers[getIndexFromCoords${rankStr}(${
      unpackedCoordsSnippet}, ${shapeStr})]);
}
fn ${funcName}Coords(coordsIn : ${type}) -> f32 {
var coords = coordsIn;
${coordsSnippet}
return f32(${texName}.numbers[getIndexFromCoords${rankStr}(${
      unpackedCoordsSnippet}, ${shapeStr})]);
}
`;
}
/**
 * Builds the WGSL function `getOutputCoords()` for a given dispatch layout.
 *
 * When every output dimension is listed under `x`, the coordinates are
 * recovered from the flat global index via `getCoordsFromIndex`. Otherwise
 * each dispatch axis (x, y, z) contributes the dimensions assigned to it: a
 * single dimension is read straight from `globalId`, while several
 * dimensions are peeled off one at a time using strides computed from
 * `uniforms.outShape`.
 *
 * @param outShape Output shape; only its rank is used here.
 * @param dispatchLayout Output dimensions assigned to each dispatch axis.
 * @returns The WGSL snippet and the number of dimensions it produces.
 */
export function getOutputCoordsSnippet(
    outShape: number[],
    dispatchLayout: {x: number[], y?: number[], z?: number[]}):
    [string, number] {
  const outRank = outShape.length;
  const xDims = dispatchLayout.x;
  const yDims = dispatchLayout.y === undefined ? [] : dispatchLayout.y;
  const zDims = dispatchLayout.z === undefined ? [] : dispatchLayout.z;

  // Flat layout: derive all coordinates from the global index.
  if (xDims.length === outRank) {
    const flatType = getCoordsDataType(outRank);
    return [
      `fn getOutputCoords() -> ${flatType}{
let globalIndex = getGlobalIndex();
return getCoordsFromIndex(globalIndex);
}
`,
      outRank
    ];
  }

  const statements: string[] = [];
  let rank = 0;
  [xDims, yDims, zDims].forEach((axisDims, axis) => {
    if (axisDims.length === 0) {
      return;
    }
    rank += axisDims.length;

    if (axisDims.length === 1) {
      statements.push(`let d${axisDims[0]} = i32(globalId[${axis}]);`);
      return;
    }

    // Several dimensions share this axis: divide out each stride in turn and
    // carry the remainder forward; the last remainder is the final dimension.
    const strides = symbolicallyComputeStrides(axisDims, 'uniforms.outShape');
    statements.push(`var index${axis} = i32(globalId[${axis}]);`);
    const lastStride = strides.length - 1;
    for (let j = 0; j <= lastStride; j++) {
      const current = `d${axisDims[j]}`;
      statements.push(`let ${current} = index${axis} / ${strides[j]};`);
      const remainder = `index${axis} - ${current} * ${strides[j]};`;
      statements.push(
          j === lastStride ? `let d${axisDims[j + 1]} = ${remainder}` :
                             `index${axis} = ${remainder}`);
    }
  });

  const dimensions = Array.from({length: rank}, (_, i) => `d${i}`);
  const dtype = getCoordsDataType(rank);
  const returnStatement = dimensions.length === 0 ?
      `return ${dtype}(0); }` :
      `return ${dtype}(${dimensions.join(',')}); }`;
  const snippet = `fn getOutputCoords() -> ${dtype} {
${statements.join('')}
` + returnStatement;

  return [snippet, rank];
}