in tfjs-core/src/ops/fused/mat_mul.ts [58:211]
/**
 * Fused matrix multiply: runs the `_FusedMatMul` kernel, which computes
 * `matMul(a, b)` (with optional transposes) together with an optional bias
 * add and activation, and wraps the call in `customGrad` so gradients flow
 * through the fused op.
 *
 * @param a First input tensor (or TensorLike, converted here).
 * @param b Second input tensor; its dtype is matched against `a`'s.
 * @param transposeA If true, `a` is transposed before multiplication.
 * @param transposeB If true, `b` is transposed before multiplication.
 * @param bias Optional bias added to the product; its shape must broadcast
 *     with the output shape (asserted below).
 * @param activation Activation applied to the result; defaults to 'linear'.
 * @param preluActivationWeights Weights tensor used by a 'prelu' activation.
 * @param leakyreluAlpha Alpha value used by a 'leakyrelu' activation.
 * @returns The activated (and optionally biased) matrix product.
 */
function fusedMatMul_({
a,
b,
transposeA = false,
transposeB = false,
bias,
activation = 'linear',
preluActivationWeights,
leakyreluAlpha,
}: {
a: Tensor|TensorLike,
b: Tensor|TensorLike,
transposeA?: boolean,
transposeB?: boolean,
bias?: Tensor|TensorLike,
activation?: Activation,
preluActivationWeights?: Tensor
leakyreluAlpha?: number
}): Tensor {
// When fusing is not applicable (decided by shouldFuse from the current
// gradient depth and the activation), fall back to the equivalent unfused
// sequence: matMul, then bias add, then activation.
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
let result = unfusedMatMul(a, b, transposeA, transposeB);
if (bias != null) {
result = add(result, bias);
}
return applyActivation(
result, activation, preluActivationWeights, leakyreluAlpha);
}
let $a = convertToTensor(a, 'a', 'fused matMul');
let $b = convertToTensor(b, 'b', 'fused matMul');
[$a, $b] = makeTypesMatch($a, $b);
// Shared (contraction) dimension of each operand, accounting for the
// transpose flags: the contracted axis swaps position when transposed.
const innerShapeA =
transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
const innerShapeB =
transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
// Output dimensions contributed by each operand (rows from a, cols from b).
const outerShapeA =
transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
const outerShapeB =
transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
// Leading (batch) dimensions, flattened to a single batch size for the
// rank-3 tensors handed to the kernel.
const outerDimsA = $a.shape.slice(0, -2);
const outerDimsB = $b.shape.slice(0, -2);
const batchDimA = util.sizeFromShape(outerDimsA);
const batchDimB = util.sizeFromShape(outerDimsB);
// The contraction dimensions must agree or the product is undefined.
util.assert(
innerShapeA === innerShapeB,
() => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
`${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
`${$b.shape} and transposeA=${transposeA}` +
` and transposeB=${transposeB} must match.`);
// Batch dims broadcast against each other to form the leading output dims;
// the trailing two output dims are [outerShapeA, outerShapeB].
const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(
$a.shape.slice(0, -2), $b.shape.slice(0, -2));
const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
// Collapse each operand to rank 3 ([batch, rows, cols]); a transposed
// operand keeps its stored layout, with inner/outer axes swapped to match.
const a3D: Tensor3D = transposeA ?
reshape($a, [batchDimA, innerShapeA, outerShapeA]) :
reshape($a, [batchDimA, outerShapeA, innerShapeA]);
const b3D: Tensor3D = transposeB ?
reshape($b, [batchDimB, outerShapeB, innerShapeB]) :
reshape($b, [batchDimB, innerShapeB, outerShapeB]);
let $bias: Tensor;
if (bias != null) {
$bias = convertToTensor(bias, 'bias', 'fused matMul');
[$bias] = makeTypesMatch($bias, $a);
// Asserts (via the broadcast helper) that the bias broadcasts with the
// output shape; the broadcast shape itself is not needed here.
broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape);
}
let $preluActivationWeights: Tensor;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(
preluActivationWeights, 'prelu weights', 'fused matMul');
}
// Gradient function shared by both customGrad variants below. `saved` holds
// [a3D, b3D, y] plus $bias when a bias was provided (see the save() calls).
const grad = (dy: Tensor3D, saved: Tensor[]) => {
const [a3D, b3D, y, $bias] = saved;
// we reshape dy because the result of the forward pass is not
// necessarily going to be a 3d tensor due to a reshape done at the end of
// the customOp.
const dyActivation =
getFusedDyActivation(reshape(dy, y.shape), y, activation);
let aDer: Tensor;
let bDer: Tensor;
// Each transpose combination has its own pair of gradient products
// (written here as literal matMul calls; e.g. the first branch computes
// aDer = dy * b^T and bDer = a^T * dy).
if (!transposeA && !transposeB) {
aDer = unfusedMatMul(dyActivation, b3D, false, true);
bDer = unfusedMatMul(a3D, dyActivation, true, false);
} else if (!transposeA && transposeB) {
// aDer = dy * b, bDer = dy^T * a
aDer = unfusedMatMul(dyActivation, b3D, false, false);
bDer = unfusedMatMul(dyActivation, a3D, true, false);
} else if (transposeA && !transposeB) {
// aDer = b * dy^T, bDer = a * dy
aDer = unfusedMatMul(b3D, dyActivation, false, true);
bDer = unfusedMatMul(a3D, dyActivation, false, false);
} else {
// aDer = b^T * dy^T, bDer = dy^T * a^T
aDer = unfusedMatMul(b3D, dyActivation, true, true);
bDer = unfusedMatMul(dyActivation, a3D, true, true);
}
if (bias != null) {
// Bias gradient is computed by the shared fused-ops helper (it must
// account for the bias having been broadcast in the forward pass).
const biasDer = getFusedBiasGradient($bias, dyActivation);
return [aDer, bDer, biasDer];
} else {
return [aDer, bDer];
}
};
// Kernel inputs/attrs. NOTE: the customGrad callbacks below capture these
// objects directly; their a3D/b3D/$bias parameters shadow the outer names
// and are only used for save(), not for the kernel call.
const inputs: _FusedMatMulInputs = {
a: a3D,
b: b3D,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
const attrs: _FusedMatMulAttrs =
{transposeA, transposeB, activation, leakyreluAlpha};
// Depending on the params passed in we will have a different number of
// inputs and thus a different number of elements in the gradient.
if (bias == null) {
const customOp =
customGrad((a3D: Tensor3D, b3D: Tensor3D, save: GradSaveFunc) => {
const res =
// tslint:disable-next-line: no-unnecessary-type-assertion
ENGINE.runKernel(
_FusedMatMul, inputs as {} as NamedTensorMap,
attrs as {} as NamedAttrMap) as Tensor;
save([a3D, b3D, res]);
// Restore the full broadcast output shape; the kernel works on the
// flattened rank-3 view.
return {value: reshape(res, outShape), gradFunc: grad};
});
return customOp(a3D, b3D);
} else {
// With a bias, $bias is passed as a third input so customGrad expects
// (and grad returns) a third gradient element.
const customOpWithBias = customGrad(
(a3D: Tensor3D, b3D: Tensor3D, $bias: Tensor, save: GradSaveFunc) => {
const res =
// tslint:disable-next-line: no-unnecessary-type-assertion
ENGINE.runKernel(
_FusedMatMul, inputs as {} as NamedTensorMap,
attrs as {} as NamedAttrMap) as Tensor;
save([a3D, b3D, res, $bias]);
return {value: reshape(res, outShape), gradFunc: grad};
});
return customOpWithBias(a3D, b3D, $bias);
}
}