// batchMatMulImpl()
// Excerpt from tfjs-backend-webgl/src/kernels/BatchMatMul_impl.ts [46:187]


/**
 * Batched matrix multiply of `a` and `b` on the WebGL backend, with optional
 * transposes and optional fused bias / activation (including prelu and
 * leakyrelu).
 *
 * The inputs are reshaped to rank-3 `[batch, rows, cols]` tensors and either:
 *   1. a "vector" fast path (one outer dim is 1 and the shared dim is large):
 *      computed as broadcast multiply + sum over the shared dim, or
 *   2. the general path: a packed WebGL matmul program with any fused ops.
 *
 * All intermediate tensors created here are disposed before returning; the
 * returned TensorInfo has the broadcasted output shape
 * `[...broadcast(batchA, batchB), outerShapeA, outerShapeB]`.
 */
export function batchMatMulImpl({
  a,
  b,
  transposeA,
  transposeB,
  backend,
  bias = null,
  preluActivationWeights = null,
  leakyreluAlpha = 0,
  activation = null
}: BatchMatMulConfig): TensorInfo {
  const aRank = a.shape.length;
  const bRank = b.shape.length;

  // "Inner" dims are the ones contracted over; "outer" dims survive into the
  // output. Which axis is which depends on the transpose flags.
  const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];

  const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];

  // Everything before the last two dims is treated as (broadcastable) batch.
  const outerDimsA = a.shape.slice(0, -2);
  const outerDimsB = b.shape.slice(0, -2);

  const batchDimA = util.sizeFromShape(outerDimsA);
  const batchDimB = util.sizeFromShape(outerDimsB);

  // Throws if the batch dims of a and b are not broadcast-compatible.
  const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(
      a.shape.slice(0, -2), b.shape.slice(0, -2));
  const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);

  util.assert(
      innerShapeA === innerShapeB,
      () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
          `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
          `${b.shape} and transposeA=${transposeA}` +
          ` and transposeB=${transposeB} must match.`);

  // Collapse each input to rank 3: [batch, rows, cols]. Note the shapes keep
  // the *stored* (pre-transpose) layout; the transpose flags are handled by
  // the downstream program / explicit transposes in the fast path.
  const a3dShape: [number, number, number] = transposeA ?
      [batchDimA, innerShapeA, outerShapeA] :
      [batchDimA, outerShapeA, innerShapeA];
  const b3dShape: [number, number, number] = transposeB ?
      [batchDimB, outerShapeB, innerShapeB] :
      [batchDimB, innerShapeB, outerShapeB];

  // The rest of the implementation is designed to operate on rank-3 tensors
  const a3d = reshape({inputs: {x: a}, backend, attrs: {shape: a3dShape}});
  const b3d = reshape({inputs: {x: b}, backend, attrs: {shape: b3dShape}});

  // Everything pushed here is disposed before returning.
  const intermediates: TensorInfo[] = [a3d, b3d];

  // With broadcastable batch dims, one of batchDimA/batchDimB may be 1; the
  // broadcasted batch size is the larger of the two.
  const batchDim = Math.max(batchDimA, batchDimB);
  const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];

  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  // leakyrelu is selected by the activation name; the (possibly default 0)
  // alpha value is uploaded as a scalar tensor in the general path below.
  const hasLeakyreluAlpha = activation === 'leakyrelu';
  const fusedActivation = activation != null ?
      mapActivationToShaderProgram(activation, true) :
      null;
  // Fused ops are only supported by the packed matmul program, so their
  // presence disables the vector fast path.
  const containsFusedOps = hasBias || hasPreluActivationWeights ||
      hasLeakyreluAlpha || fusedActivation != null;
  let out: TensorInfo;

  // Since the matrices are vectors, it is faster to call mul().sum()
  // because sum() is O(sqrt(N)) due to divide-and-conquer.
  // (MATMUL_SHARED_DIM_THRESHOLD is defined elsewhere — presumably the
  // minimum shared-dim size at which this pays off.)
  if ((outerShapeA === 1 || outerShapeB === 1) &&
      sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
    // Materialize the transposes so both operands are in logical
    // [batch, rows, cols] orientation for the element-wise multiply.
    let aVec = a3d;
    let bVec = b3d;
    if (transposeA) {
      aVec = transpose({inputs: {x: a3d}, backend, attrs: {perm: [0, 2, 1]}});
      intermediates.push(aVec);
    }
    if (transposeB) {
      bVec = transpose({inputs: {x: b3d}, backend, attrs: {perm: [0, 2, 1]}});
      intermediates.push(bVec);
    }

    // Exactly one operand is the "vector": if outerShapeB !== 1 then
    // outerShapeA === 1, so A is reshaped to a column [batch, shared, 1] and
    // broadcast-multiplied against B; otherwise B is reshaped to a row
    // [batch, 1, shared] and broadcast against A.
    const shouldReshapeA = outerShapeB !== 1;
    const shouldReshapeB = outerShapeB === 1;

    let aVec3d = aVec;
    if (shouldReshapeA) {
      aVec3d = reshape({
        inputs: {x: aVec},
        backend,
        attrs: {shape: [batchDim, sharedDim, 1]}
      });

      intermediates.push(aVec3d);
    }

    // Reduce over whichever axis holds the shared dim in the product.
    const axis = outerShapeB === 1 ? 2 : 1;

    let bVec3d = bVec;
    if (shouldReshapeB) {
      bVec3d = reshape({
        inputs: {x: bVec},
        backend,
        attrs: {shape: [batchDim, 1, sharedDim]}
      });

      intermediates.push(bVec3d);
    }

    // matmul as broadcast multiply + reduction over the shared dim.
    // keepDims preserves rank 3 so the final reshape to outShape works.
    const product = multiply({inputs: {a: aVec3d, b: bVec3d}, backend});
    out = sum({inputs: {x: product}, backend, attrs: {axis, keepDims: true}});
    intermediates.push(product);
  } else {
    // General path: packed WebGL matmul with transposes and any fused
    // bias/prelu/leakyrelu/activation baked into the shader.
    const dtype = upcastType(a.dtype, b.dtype);

    const program = new MatMulPackedProgram(
        a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA,
        transposeB, hasBias, fusedActivation, hasPreluActivationWeights,
        hasLeakyreluAlpha);

    // Extra inputs must be pushed in the order the program expects:
    // bias, then prelu weights, then the leakyrelu alpha scalar.
    const inputs: TensorInfo[] = [a3d, b3d];
    if (bias != null) {
      inputs.push(bias);
    }
    if (hasPreluActivationWeights) {
      inputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
      // Upload alpha as a float32 scalar tensor so the shader can read it.
      const $leakyreluAlpha = backend.makeTensorInfo(
          [], 'float32',
          util.createScalarValue(leakyreluAlpha as {} as 'float32', 'float32'));
      inputs.push($leakyreluAlpha);
      intermediates.push($leakyreluAlpha);
    }

    out = backend.runWebGLProgram(program, inputs, dtype);
  }

  // Restore the full broadcasted output shape, then release every
  // intermediate (including the rank-3 `out`, which outReshaped replaces).
  const outReshaped =
      reshape({inputs: {x: out}, backend, attrs: {shape: outShape}});
  intermediates.push(out);
  for (const i of intermediates) {
    backend.disposeIntermediateTensorInfo(i);
  }
  return outReshaped;
}