in tfjs-core/src/ops/fused/depthwise_conv2d.ts [90:259]
// Depthwise 2D convolution with optional bias add and activation, dispatched
// as a single FusedDepthwiseConv2D kernel. When `shouldFuse` reports that
// fusing is not applicable, the same result is assembled from the unfused
// depthwise conv op followed by `add` and `applyActivation`.
//
// A rank-3 `x` is reshaped to rank 4 (leading dimension of 1) for the kernel
// and the result is reshaped back to rank 3 before being returned.
function fusedDepthwiseConv2d_<T extends Tensor3D|Tensor4D>({
  x,
  filter,
  strides,
  pad,
  dataFormat = 'NHWC',
  dilations = [1, 1],
  dimRoundingMode,
  bias,
  activation = 'linear',
  preluActivationWeights,
  leakyreluAlpha
}: {
  x: T|TensorLike,
  filter: Tensor4D|TensorLike,
  strides: [number, number]|number,
  pad: 'valid'|'same'|number,
  dataFormat?: 'NHWC'|'NCHW',
  dilations?: [number, number]|number,
  dimRoundingMode?: 'floor'|'round'|'ceil',
  bias?: Tensor|TensorLike,
  activation?: Activation,
  preluActivationWeights?: Tensor,
  leakyreluAlpha?: number
}): T {
  // Unfused fallback: run the plain ops one after another.
  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
    let result = unfusedDepthwiseConv2d(
        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
    if (bias != null) {
      result = add(result, bias);
    }
    return applyActivation(
               result, activation, preluActivationWeights, leakyreluAlpha) as T;
  }

  const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
  const $filter =
      convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');

  // The kernel operates on rank-4 input; remember whether a batch dimension
  // was added so that it can be dropped again from the result.
  let x4D = $x as Tensor4D;
  let reshapedTo4D = false;
  if ($x.rank === 3) {
    reshapedTo4D = true;
    x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }

  util.assert(
      x4D.rank === 4,
      () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` +
          `rank ${x4D.rank}.`);
  util.assert(
      $filter.rank === 4,
      () => `Error in fused depthwiseConv2d: filter must be rank 4, ` +
          `but got rank ${$filter.rank}.`);
  // NOTE(review): the channel count is always read from x4D.shape[3], i.e.
  // independent of `dataFormat` -- confirm whether 'NCHW' is expected to reach
  // the fused path.
  util.assert(
      x4D.shape[3] === $filter.shape[2],
      () => `Error in fused depthwiseConv2d: number of input channels ` +
          `(${x4D.shape[3]}) must match the inChannels dimension in ` +
          `filter ${$filter.shape[2]}.`);

  // The parameter default only covers `undefined`; an explicit `null` is
  // normalized here.
  if (dilations == null) {
    dilations = [1, 1];
  }
  util.assert(
      conv_util.eitherStridesOrDilationsAreOne(strides, dilations),
      () =>
          'Error in fused depthwiseConv2d: Either strides or dilations must ' +
          `be 1. Got strides ${strides} and dilations '${dilations}'`);
  conv_util.checkPadOnDimRoundingMode(
      'fused depthwiseConv2d', pad, dimRoundingMode);
  const convInfo = conv_util.computeConv2DInfo(
      x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode,
      true /* depthwise */);

  let $bias: Tensor;
  if (bias != null) {
    // Report conversion problems under this op's name (was 'fused conv2d').
    $bias = convertToTensor(bias, 'bias', 'fused depthwiseConv2d');
    [$bias] = makeTypesMatch($bias, $x);
    // Called for its validation only: the bias shape is checked against the
    // convolution output shape; the returned shape is not used.
    broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
  }

  let $preluActivationWeights: Tensor;
  if (preluActivationWeights != null) {
    $preluActivationWeights = convertToTensor(
        preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
  }

  // Gradient function shared by both custom ops below. `saved` holds
  // [filter, x4D, y] and, when a bias was supplied, the bias as a 4th entry.
  const grad = (dy: Tensor4D, saved: Tensor[]) => {
    util.assert(
        conv_util.tupleValuesAreOne(dilations),
        () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' +
            `greater than 1 are not yet supported. Got dilations ` +
            `'${dilations}'`);
    const [savedFilter, savedX, y, savedBias] = saved;

    const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;

    const xDer = depthwiseConv2dNativeBackpropInput(
        (savedX as Tensor4D).shape, dyActivation, savedFilter as Tensor4D,
        strides, pad, dilations, dimRoundingMode);
    const filterDer = depthwiseConv2dNativeBackpropFilter(
        savedX as Tensor4D, dyActivation, (savedFilter as Tensor4D).shape,
        strides, pad, dilations, dimRoundingMode);

    if (savedBias != null) {
      const biasDer = getFusedBiasGradient($bias, dyActivation);
      return [xDer, filterDer, biasDer];
    }
    return [xDer, filterDer];
  };

  const inputs: FusedDepthwiseConv2DInputs = {
    x: x4D,
    filter: $filter,
    bias: $bias,
    preluActivationWeights: $preluActivationWeights
  };
  const attrs: FusedDepthwiseConv2DAttrs = {
    strides,
    pad,
    dataFormat,
    dilations,
    dimRoundingMode,
    activation,
    leakyreluAlpha
  };

  // Runs the fused kernel on the `inputs`/`attrs` assembled above.
  const runFusedKernel = (): Tensor4D => {
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res: Tensor4D = ENGINE.runKernel(
        FusedDepthwiseConv2D, inputs as {} as NamedTensorMap,
        attrs as {} as NamedAttrMap);
    return res;
  };

  // Drops the leading dimension again if the caller passed a rank-3 input.
  const restoreInputRank = (res: Tensor4D): Tensor4D|Tensor3D => {
    if (!reshapedTo4D) {
      return res;
    }
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) as
        Tensor3D;
  };

  // Depending on the params passed in we will have a different number of
  // inputs and thus a different number of elements in the gradient.
  // In both variants the rank-4 kernel output is saved for the gradient
  // before any reshape back to rank 3.
  if (bias == null) {
    const customOp = customGrad(
        (xIn: Tensor4D, filterIn: Tensor4D, save: GradSaveFunc) => {
          const res = runFusedKernel();
          save([filterIn, xIn, res]);
          return {value: restoreInputRank(res), gradFunc: grad};
        });
    return customOp(x4D, $filter) as T;
  }

  const customOpWithBias = customGrad(
      (xIn: Tensor4D, filterIn: Tensor4D, biasIn: Tensor,
       save: GradSaveFunc) => {
        const res = runFusedKernel();
        save([filterIn, xIn, res, biasIn]);
        return {value: restoreInputRank(res), gradFunc: grad};
      });
  return customOpWithBias(x4D, $filter, $bias) as T;
}