in tfjs-core/src/ops/slice_util.ts [421:629]
/**
 * Resolves the arguments of a strided slice against the shape of the tensor
 * being sliced.
 *
 * The user-facing ("sparse") spec may contain one ellipsis and any number of
 * new axes, so it does not line up one-to-one with the input dimensions. It is
 * first expanded into a "dense" spec with one entry per input dimension (via
 * `buildDenseSpec`). Every begin/end is then made explicit and bounds-checked
 * (via `canonical`), and finally the output shape is assembled. The structure
 * (and the reference to Eigen below) suggests a port of TensorFlow's C++
 * strided-slice validation.
 *
 * @param xShape Shape of the tensor being sliced. A dimension of -1 is treated
 *     as unknown: it is not canonicalized and yields -1 in the processing
 *     shape (or 1 if that axis is shrunk).
 * @param begin Start index for each sparse dimension. Negative values count
 *     from the end of the dimension.
 * @param end Stop index (exclusive) for each sparse dimension.
 * @param strides Step for each sparse dimension; when null, every stride is 1.
 *     Strides must be non-zero.
 * @param beginMask Bit i set means begin[i] is ignored and the start of the
 *     full valid range is used instead (resolved by `canonical`).
 * @param endMask Bit i set means end[i] is ignored and the end of the full
 *     valid range is used instead (resolved by `canonical`).
 * @param ellipsisMask At most one bit may be set; that sparse position expands
 *     to cover all otherwise-unspecified input dimensions. When no bit is set,
 *     an implicit ellipsis is appended after the last sparse dimension.
 * @param newAxisMask Bit i set inserts a size-1 dimension into the output.
 * @param shrinkAxisMask Bit i set means begin[i] selects a single index and
 *     the dimension is dropped from the final shape.
 * @returns The canonical begin/end/strides (one entry per input dimension),
 *     the final output shape, `finalShapeSparse` (the final shape without the
 *     size-1 dimensions contributed by new axes), and the `isIdentity`,
 *     `sliceDim0` and `isSimpleSlice` optimization hints.
 * @throws Error if more than one ellipsis is given, a stride is zero, a shrunk
 *     axis has a non-positive stride, or a shrunk index is out of bounds
 *     (`buildDenseSpec` may throw as well).
 */
export function sliceInfo(
    xShape: number[], begin: number[], end: number[], strides: number[],
    beginMask: number, endMask: number, ellipsisMask: number,
    newAxisMask: number, shrinkAxisMask: number): SliceInfo {
  let stridesNonNull: number[];
  if (strides == null) {
    stridesNonNull = new Array(begin.length);
    stridesNonNull.fill(1);
  } else {
    stridesNonNull = strides;
  }

  // Only one non-zero bit is allowed in ellipsisMask, which means ellipsisMask
  // is a power of 2. Use bit compares to ensure ellipsisMask is 0 or a power
  // of 2. When i is a power of 2, i & (i - 1) is always 0.
  // Also ref:
  // https://stackoverflow.com/questions/600293/how-to-check-if-a-number-is-a-power-of-2
  if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) {
    throw new Error('Multiple ellipses in slice is not allowed.');
  }

  // Step 1: Account for ellipsis and new axis.
  // Check for an ellipsis and count how many new axes appear after it.
  let ellipsisSeen = false;
  const sparseSpec: StridedSliceSparseSpec = {
    dims: stridesNonNull.length,
    numAddAxisAfterEllipsis: 0,
    begin: begin.slice(),
    end: end.slice(),
    strides: stridesNonNull.slice(),
    beginMask,
    endMask,
    ellipsisMask,
    newAxisMask,
    shrinkAxisMask
  };
  for (let i = 0; i < sparseSpec.dims; i++) {
    if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) {
      sparseSpec.numAddAxisAfterEllipsis++;
    }
    if ((1 << i) & ellipsisMask) {
      ellipsisSeen = true;
    }
  }
  // If no ellipsis was given, insert one at the end.
  if (!ellipsisSeen) {
    sparseSpec.ellipsisMask |= (1 << sparseSpec.dims);
    // Count the implicit ellipsis as an extra sparse dimension so that the
    // dense-spec expansion below visits it.
    sparseSpec.dims++;
  }

  // Step 2: Make a sparse spec into a full index spec.
  //
  // The sparse spec does not correspond to the number of dimensions.
  // Make a dense spec that corresponds to the number of dimensions.
  //
  // For example suppose foo[...,3:] on foo.shape = [2, 2, 3] then we need to
  // produce the missing beginMask for the first two dimensions i.e. from
  // beginMaskSpec = 0, endMaskSpec = 2, we achieve beginMask = 6 (110),
  // endMask = 7 (111).
  const denseSpec: StridedSliceDenseSpec = {
    dims: xShape.length,
    beginMask: 0,
    endMask: 0,
    beginValid: false,
    endValid: false
  };
  buildDenseSpec(sparseSpec, denseSpec);

  // Step 3: Make implicit ranges (non-zero beginMasks and endMasks) explicit
  // and bounds check.
  // isIdentity: every dimension is taken whole with stride 1.
  // sliceDim0: only dimension 0 may be partially taken (with stride 1).
  // isSimpleSlice: every stride is 1.
  let isIdentity = true;
  let sliceDim0 = true;
  let isSimpleSlice = true;
  const processingShape: number[] = [];
  for (let i = 0; i < xShape.length; ++i) {
    const strideI = denseSpec.strides[i];
    if (strideI === 0) {
      throw new Error(`strides[${i}] must be non-zero`);
    }
    const bitI = 1 << i;
    const shrinkI = !!(denseSpec.shrinkAxisMask & bitI);
    const dimI = xShape[i];
    if (dimI === -1) {
      // Unknown dimension: nothing can be canonicalized.
      processingShape.push(shrinkI ? 1 : -1);
      continue;
    }

    const masks = [denseSpec.beginMask & bitI, denseSpec.endMask & bitI];
    const validRange = [strideI > 0 ? 0 : -1, strideI > 0 ? dimI : dimI - 1];
    if (shrinkI && strideI <= 0) {
      throw new Error('only stride 1 allowed on non-range indexing.');
    }
    isSimpleSlice = isSimpleSlice && strideI === 1;

    const beginAndEndMasked = !!(masks[0] && masks[1]);
    if (denseSpec.beginValid && denseSpec.endValid) {
      if (shrinkI) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular foo[-1] produces sparseBegin = -1, sparseEnd = 0,
        // and canonical puts these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin + 1.
        const beginI = denseSpec.begin[i];
        const xFwd = beginI < 0 ? dimI + beginI : beginI;
        denseSpec.begin[i] = xFwd;
        denseSpec.end[i] = xFwd + 1;
        if (xFwd < 0 || xFwd >= dimI) {
          throw new Error(
              `slice index ${xFwd} of dimension ${i} out of bounds.`);
        }
      } else {
        denseSpec.begin[i] = canonical(
            denseSpec.begin[i], 0, strideI, dimI, masks, validRange);
        denseSpec.end[i] =
            canonical(denseSpec.end[i], 1, strideI, dimI, masks, validRange);
      }
      // Update optimization values.
      const takeAllInDimension = strideI === 1 && denseSpec.begin[i] === 0 &&
          denseSpec.end[i] === dimI;
      isIdentity = isIdentity && takeAllInDimension;
      sliceDim0 =
          sliceDim0 && ((i === 0 && strideI === 1) || takeAllInDimension);
    } else {
      isIdentity = isIdentity && strideI === 1 && beginAndEndMasked;
      sliceDim0 =
          sliceDim0 && ((i === 0 && strideI === 1) || beginAndEndMasked);
    }

    // Compute the processing shape (the intermediate Eigen will produce).
    let intervalLength = 0;
    let knownInterval = false;
    if (denseSpec.beginValid && denseSpec.endValid) {
      intervalLength = denseSpec.end[i] - denseSpec.begin[i];
      knownInterval = true;
    } else if (shrinkI) {
      // The dimension is still known as 1 for the processingShape, but will be
      // discarded for the final shape.
      intervalLength = 1;
      knownInterval = true;
    } else if (beginAndEndMasked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dimI >= 0) {
        intervalLength = strideI < 0 ? -dimI : dimI;
        knownInterval = true;
      }
    }
    if (knownInterval) {
      let sizeI: number;
      // Hold zero if the interval is degenerate (empty, or pointing against
      // the stride direction), otherwise round the element count up.
      if (intervalLength === 0 || (intervalLength < 0) !== (strideI < 0)) {
        sizeI = 0;
      } else {
        sizeI = Math.trunc(intervalLength / strideI) +
            (intervalLength % strideI !== 0 ? 1 : 0);
      }
      processingShape.push(sizeI);
    } else {
      processingShape.push(-1);
    }
  }

  // Step 4: Compute the final shape.
  //
  // newAxis will increase dimension by 1 (with a one-size dimension);
  // slices like foo[3, ...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  //
  // finalShapeSparse is built in the same pass, from the same entries minus
  // the new axes. Filtering finalShape afterwards by position would misalign
  // with finalShapeGatherIndices, because shrunk axes have a gather entry but
  // no finalShape entry.
  const finalShape: number[] = [];
  const finalShapeSparse: number[] = [];
  for (const gatherIndex of denseSpec.finalShapeGatherIndices) {
    if (gatherIndex >= 0) {
      const size = processingShape[gatherIndex];
      finalShape.push(size);
      finalShapeSparse.push(size);
    } else if (gatherIndex === NEW_AXIS) {
      finalShape.push(1);
    }
    // Any other negative index (presumably a shrink-axis marker from
    // buildDenseSpec) contributes no output dimension.
  }

  return {
    finalShapeSparse,
    finalShape,
    isIdentity,
    sliceDim0,
    isSimpleSlice,
    begin: denseSpec.begin,
    end: denseSpec.end,
    strides: denseSpec.strides
  };
}