function fusedMatMul_()

in tfjs-core/src/ops/fused/mat_mul.ts [58:211]


function fusedMatMul_({
  a,
  b,
  transposeA = false,
  transposeB = false,
  bias,
  activation = 'linear',
  preluActivationWeights,
  leakyreluAlpha,
}: {
  a: Tensor|TensorLike,
  b: Tensor|TensorLike,
  transposeA?: boolean,
  transposeB?: boolean,
  bias?: Tensor|TensorLike,
  activation?: Activation,
  preluActivationWeights?: Tensor,
  leakyreluAlpha?: number
}): Tensor {
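    // Fall back to the unfused op chain (matMul -> add -> activation) when
    // the ops cannot be fused, e.g. inside a gradient computation with an
    // activation whose gradient is not supported in fused form.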
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
      let result = unfusedMatMul(a, b, transposeA, transposeB);
      if (bias != null) {
        result = add(result, bias);
      }

      return applyActivation(
                 result, activation, preluActivationWeights, leakyreluAlpha);
    }

    let $a = convertToTensor(a, 'a', 'fused matMul');
    let $b = convertToTensor(b, 'b', 'fused matMul');
    [$a, $b] = makeTypesMatch($a, $b);

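    // For y = a · b, the inner (contraction) dims of a and b must match;
    // the outer dims become the last two dims of the output.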
    const innerShapeA =
        transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
    const innerShapeB =
        transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];

    const outerShapeA =
        transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
    const outerShapeB =
        transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];

    const outerDimsA = $a.shape.slice(0, -2);
    const outerDimsB = $b.shape.slice(0, -2);
    const batchDimA = util.sizeFromShape(outerDimsA);
    const batchDimB = util.sizeFromShape(outerDimsB);

    util.assert(
        innerShapeA === innerShapeB,
        () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
            `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
            `${$b.shape} and transposeA=${transposeA}` +
            ` and transposeB=${transposeB} must match.`);

    const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(
        $a.shape.slice(0, -2), $b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
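    // e.g. a: [2, 1, m, k] and b: [1, 3, k, n] broadcast their batch dims to
    // [2, 3], giving outShape [2, 3, m, n].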

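    // Collapse all batch dims into a single one so the kernel always sees
    // rank-3 inputs: [batch, M, K] x [batch, K, N].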
    const a3D: Tensor3D = transposeA ?
        reshape($a, [batchDimA, innerShapeA, outerShapeA]) :
        reshape($a, [batchDimA, outerShapeA, innerShapeA]);
    const b3D: Tensor3D = transposeB ?
        reshape($b, [batchDimB, outerShapeB, innerShapeB]) :
        reshape($b, [batchDimB, innerShapeB, outerShapeB]);

    let $bias: Tensor;
    if (bias != null) {
      $bias = convertToTensor(bias, 'bias', 'fused matMul');
      [$bias] = makeTypesMatch($bias, $a);

      broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape);
    }

    let $preluActivationWeights: Tensor;
    if (preluActivationWeights != null) {
      $preluActivationWeights = convertToTensor(
          preluActivationWeights, 'prelu weights', 'fused matMul');
    }

    const grad = (dy: Tensor3D, saved: Tensor[]) => {
      const [a3D, b3D, y, $bias] = saved;
      // We reshape dy because the result of the forward pass is not
      // necessarily a 3D tensor, due to the reshape done at the end of the
      // customOp.
      const dyActivation =
          getFusedDyActivation(reshape(dy, y.shape), y, activation);
      let aDer: Tensor;
      let bDer: Tensor;

      if (!transposeA && !transposeB) {
        // y = a · b:      dA = dy · b^T,   dB = a^T · dy
        aDer = unfusedMatMul(dyActivation, b3D, false, true);
        bDer = unfusedMatMul(a3D, dyActivation, true, false);
      } else if (!transposeA && transposeB) {
        // y = a · b^T:    dA = dy · b,     dB = dy^T · a
        aDer = unfusedMatMul(dyActivation, b3D, false, false);
        bDer = unfusedMatMul(dyActivation, a3D, true, false);
      } else if (transposeA && !transposeB) {
        // y = a^T · b:    dA = b · dy^T,   dB = a · dy
        aDer = unfusedMatMul(b3D, dyActivation, false, true);
        bDer = unfusedMatMul(a3D, dyActivation, false, false);
      } else {
        // y = a^T · b^T:  dA = b^T · dy^T, dB = dy^T · a^T
        aDer = unfusedMatMul(b3D, dyActivation, true, true);
        bDer = unfusedMatMul(dyActivation, a3D, true, true);
      }

      if (bias != null) {
        const biasDer = getFusedBiasGradient($bias, dyActivation);
        return [aDer, bDer, biasDer];
      } else {
        return [aDer, bDer];
      }
    };

    const inputs: _FusedMatMulInputs = {
      a: a3D,
      b: b3D,
      bias: $bias,
      preluActivationWeights: $preluActivationWeights
    };
    const attrs: _FusedMatMulAttrs =
        {transposeA, transposeB, activation, leakyreluAlpha};

    // Depending on the params passed in, we will have a different number of
    // inputs and thus a different number of elements in the gradient.
    if (bias == null) {
      const customOp =
          customGrad((a3D: Tensor3D, b3D: Tensor3D, save: GradSaveFunc) => {
            const res =
                // tslint:disable-next-line: no-unnecessary-type-assertion
                ENGINE.runKernel(
                    _FusedMatMul, inputs as {} as NamedTensorMap,
                    attrs as {} as NamedAttrMap) as Tensor;

            save([a3D, b3D, res]);

            return {value: reshape(res, outShape), gradFunc: grad};
          });
      return customOp(a3D, b3D);
    } else {
      const customOpWithBias = customGrad(
          (a3D: Tensor3D, b3D: Tensor3D, $bias: Tensor, save: GradSaveFunc) => {
            const res =
                // tslint:disable-next-line: no-unnecessary-type-assertion
                ENGINE.runKernel(
                    _FusedMatMul, inputs as {} as NamedTensorMap,
                    attrs as {} as NamedAttrMap) as Tensor;

            save([a3D, b3D, res, $bias]);

            return {value: reshape(res, outShape), gradFunc: grad};
          });

      return customOpWithBias(a3D, b3D, $bias);
    }
}
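
For reference, here is a minimal usage sketch of the public entry point that wraps this op (tf.fused.matMul). The shapes, values, and the 'relu' activation are illustrative only, assuming the standard @tensorflow/tfjs import:

import * as tf from '@tensorflow/tfjs';

// a: [2, 3] · b: [3, 4] -> [2, 4], with a broadcast bias and a fused ReLU.
const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
const b = tf.tensor2d([[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1]]);
const bias = tf.tensor1d([0.5, 0.5, 0.5, 0.5]);

const y = tf.fused.matMul({a, b, bias, activation: 'relu'});
y.print();  // shape [2, 4]

Because the bias add and the activation run inside a single kernel, this avoids materializing the intermediate matMul and add results that the unfused fallback above would produce.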