export function packScalarsToVector()

in src/webgpu/shader/execution/expression/expression.ts [1277:1352]


/**
 * Packs scalar test cases into vector test cases, so that several scalar cases
 * are evaluated by a single packed case.
 *
 * Each run of `vectorWidth` consecutive scalar cases becomes one packed case:
 * - Inputs: lane `i` of each packed parameter holds that parameter from the
 *   `i`'th case in the run.
 * - Expected value: the packed comparator checks lane `i` of the result with
 *   the comparator built from that case's expected value.
 *
 * If `cases.length` is not a multiple of `vectorWidth`, the lanes of the final
 * packed case that have no corresponding scalar case are filled by repeating
 * the last scalar case.
 *
 * @param parameterTypes scalar types of the parameters; every entry must be a ScalarType
 * @param resultType scalar return type; must be a ScalarType
 * @param cases scalar cases to pack
 * @param vectorWidth number of lanes per packed vector; must be a positive integer
 * @returns the packed cases, plus the vectorized parameter types and result type
 * @throws Error if a parameter type or the result type is not a ScalarType, or if
 *   `vectorWidth` is not a positive integer
 */
export function packScalarsToVector(
  parameterTypes: Array<Type>,
  resultType: Type,
  cases: Case[],
  vectorWidth: number
): { cases: Case[]; parameterTypes: Array<Type>; resultType: Type } {
  // Validate that the parameters and return type are all vectorizable
  for (let i = 0; i < parameterTypes.length; i++) {
    const ty = parameterTypes[i];
    if (!(ty instanceof ScalarType)) {
      throw new Error(
        `packScalarsToVector() can only be used on scalar parameter types, but the ${i}'th parameter type is a ${ty}`
      );
    }
  }
  if (!(resultType instanceof ScalarType)) {
    throw new Error(
      `packScalarsToVector() can only be used with a scalar return type, but the return type is a ${resultType}`
    );
  }
  // A non-positive width would never advance `caseIdx` below, so the packing
  // loop would never terminate. A fractional width would produce sparse arrays.
  if (!Number.isInteger(vectorWidth) || vectorWidth < 1) {
    throw new Error(
      `packScalarsToVector() requires a positive integer vectorWidth, but got ${vectorWidth}`
    );
  }

  const packedCases: Array<Case> = [];
  const packedParameterTypes = parameterTypes.map(p => Type.vec(vectorWidth, p as ScalarType));
  const packedResultType = Type.vec(vectorWidth, resultType);

  // Lanes past the end of `cases` reuse the last case as padding.
  const clampCaseIdx = (idx: number) => Math.min(idx, cases.length - 1);

  let caseIdx = 0;
  while (caseIdx < cases.length) {
    // Construct the vectorized inputs from the scalar cases
    const packedInputs = new Array<VectorValue>(parameterTypes.length);
    for (let paramIdx = 0; paramIdx < parameterTypes.length; paramIdx++) {
      const inputElements = new Array<ScalarValue>(vectorWidth);
      for (let lane = 0; lane < vectorWidth; lane++) {
        const input = cases[clampCaseIdx(caseIdx + lane)].input;
        // A single-parameter case may store its input as a bare value instead of an array.
        inputElements[lane] = (input instanceof Array ? input[paramIdx] : input) as ScalarValue;
      }
      packedInputs[paramIdx] = new VectorValue(inputElements);
    }

    // Gather the per-lane comparators for the packed case
    const laneComparators = new Array<ComparatorImpl>(vectorWidth);
    for (let lane = 0; lane < vectorWidth; lane++) {
      laneComparators[lane] = toComparator(cases[clampCaseIdx(caseIdx + lane)].expected).compare;
    }
    const comparators: Comparator = {
      compare: (got: Value) => {
        let matched = true;
        const gotElements = new Array<string>(vectorWidth);
        const expectedElements = new Array<string>(vectorWidth);
        for (let lane = 0; lane < vectorWidth; lane++) {
          const d = laneComparators[lane]((got as VectorValue).elements[lane]);
          matched = matched && d.matched;
          gotElements[lane] = d.got;
          expectedElements[lane] = d.expected;
        }
        return {
          matched,
          got: `${packedResultType}(${gotElements.join(', ')})`,
          expected: `${packedResultType}(${expectedElements.join(', ')})`,
        };
      },
      kind: 'packed',
    };

    // Append the new packed case
    packedCases.push({ input: packedInputs, expected: comparators });
    caseIdx += vectorWidth;
  }

  return {
    cases: packedCases,
    parameterTypes: packedParameterTypes,
    resultType: packedResultType,
  };
}