in libraries/value/src/CachingStrategies.cpp [355:593]
// Implements BLAS-style "TCOPY" caching for a loop nest: copies M' x N' submatrices
// of the input matrix into a contiguous cache of row-major stripe blocks, then
// redirects the nest's inner kernels to read through a repositioned reference into
// that cache instead of the original matrix.
void BLASTCopy::HandleCachingImpl(LoopNest& nest)
{
    /* BLAS T COPY:
       suppose input matrix is M x N, cache size is M' x N', stripeSize = 4
       so cache successive M'x4 row-major submatrices from the input matrix
            0  1  2  3 16 17 18 19         0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 ...
            4  5  6  7 20 21 22 23   ->   16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
            8  9 10 11 24 25 26 27
           12 13 14 15 28 29 30 31
       Need 2 layers of caching:
           at M x N level, build up cache values
           at stripeSize level, set up pointer and memory layout
    */
    ValidateInputDimensionality(_value, _shape, _order);

    // Unpack the stripe size, the stripe splitting index, and the boundary-condition
    // handling mode from the strategy's type-erased extra parameters.
    auto extraParams = std::any_cast<std::tuple<int, Index, BoundaryConditionHandling>>(_extra);
    int stripeSize;
    Index stripeSplitIndex;
    BoundaryConditionHandling boundaryHandling;
    std::tie(stripeSize, stripeSplitIndex, boundaryHandling) = extraParams;

    if (boundaryHandling == BoundaryConditionHandling::ZeroPadding && _shape[1] % stripeSize != 0)
    {
        // To avoid an odd repeated edge case, enforce that the number of cache columns is a multiple of the stripe size
        // So the base 3D cache view can represent the full cache
        throw InputException(InputExceptionErrors::invalidSize, "The number of cache columns must be a multiple of the cache stripe size");
    }

    // Cache structure:
    // Lift the 2D submatrix into a 3D array to set up the cache simply
    // The first dimension identifies which cached column block to use
    // The second two dimensions identify the element inside of that cached submatrix block
    // Index mapping: input ( i, j ) -> cache ( j / stripeSize, i, j % stripeSize )
    //                cache ( i, j, k ) -> input ( j, i * stripeSize + k )

    // Boundary handling
    // There are 4 boundary scenarios (possibly all 4 can happen in a single input matrix + cache size combination
    // while iterating over the matrix):
    //     |-------N-------|
    //     |----N'---|
    // _ _ *---------------*
    // |   |         |     |
    // | M'|    1    |  2  |
    // |   |         |     |
    // M _ |_________|_____|
    //     |    3    |  4  |
    //     |         |     |
    // _   *---------------*
    // 1 : The cache has exactly as many rows and columns as the input matrix chunk
    //     - This is the simple case, leave the cache as { M' x N' }
    // 2 : The cache has more columns than the input matrix but fewer rows
    //     - re-view the cache to be { M' x remainingColumns }
    // 3 : The cache has more rows than the input matrix but fewer columns
    //     - re-view the cache to be { remainingRows x N' }
    // 4 : The cache has more rows and columns than the input matrix
    //     - re-view the cache to be { remainingRows x remainingColumns }
    // Note: it is assumed that the input matrix is stepped over in splits based on the
    // cache size given, so the cache can never be smaller than the input matrix chunk
    // Since the matrix and cache sizes are known ahead of time, we can compute all of the boundary
    // condition layouts that are needed:
    // remainingRows = M % M'
    // remainingColumns = N % N'
    auto inputMatrix = Matrix(_value);
    int inputRows = inputMatrix.Rows();
    int inputCols = inputMatrix.Columns();
    int remainingRows = inputRows % _shape[0];
    int remainingCols = inputCols % _shape[1];
    // Round the leftover column count up to a whole number of stripes so the lifted
    // 3D layout (chunks x rows x stripeSize) can still describe the boundary cache.
    int roundedRemainingCols = RoundUpToMultiple(remainingCols, stripeSize);
    // we don't need to round up remainingRows since stripe size only applies to columns in BLASTCopy

    // Builds the lifted 3D cache layout { cols / stripeSize, rows, stripeSize } in
    // canonical dimension order; this is the layout the fill kernel writes through.
    auto generateTCOPYCacheLayout = [stripeSize](int rows, int cols) {
        auto cacheDimOrder = DimensionOrder{ 0, 1, 2 };
        auto liftedShape = MemoryShape{ cols / stripeSize, rows, stripeSize };
        auto cacheLayout = MemoryLayout{ liftedShape, cacheDimOrder };
        return cacheLayout;
    };
    // Builds the 2D row-major view of a single stripe ({ rows x stripeSize }) that the
    // inner kernels index through. NOTE(review): the `cols` parameter is unused here —
    // a single stripe view never spans more than stripeSize columns, so only `rows`
    // matters; the parameter is presumably kept for signature symmetry with
    // generateTCOPYCacheLayout.
    auto generateTCOPYCacheViewLayout = [stripeSize](int rows, int cols) {
        auto cacheViewLayout = MemoryLayout{ { rows, stripeSize }, RowMajorMatrixOrder };
        return cacheViewLayout;
    };

    auto baseCacheLayout = generateTCOPYCacheLayout(_shape[0], _shape[1]); // The non-boundary-case 3D lifted shape
    auto baseCacheViewLayout = generateTCOPYCacheViewLayout(_shape[0], _shape[1]);

    // "Boundary" condition 1 is the general case (i.e. non-boundary case)
    auto boundaryConditionCacheLayout1 = baseCacheLayout;
    auto cacheViewLayout1 = baseCacheViewLayout;

    // Boundary condition 2, re-view to M' x remainingColumns
    auto boundaryConditionCacheLayout2 = generateTCOPYCacheLayout(_shape[0], roundedRemainingCols);
    auto cacheViewLayout2 = generateTCOPYCacheViewLayout(_shape[0], roundedRemainingCols);

    // Boundary condition 3, re-view to remainingRows x N'
    auto boundaryConditionCacheLayout3 = generateTCOPYCacheLayout(remainingRows, _shape[1]);
    auto cacheViewLayout3 = generateTCOPYCacheViewLayout(remainingRows, _shape[1]);

    // Boundary condition 4, re-view to remainingRows x remainingColumns
    auto boundaryConditionCacheLayout4 = generateTCOPYCacheLayout(remainingRows, roundedRemainingCols);
    auto cacheViewLayout4 = generateTCOPYCacheViewLayout(remainingRows, roundedRemainingCols);

    // Allocate the cache once at the full (non-boundary) size; boundary cases re-view
    // this same storage with the smaller layouts above rather than re-allocating.
    auto cacheName = UniqueName("BLASTCopyCache");
    _rawCache = StaticAllocate(cacheName, _value.GetBaseType(), baseCacheLayout);
    Array liftedCache(_rawCache);

    // cacheRef is the indirection the inner kernels will read through; the view-init
    // kernel repositions it before each stripe is consumed.
    auto cacheRef = _rawCache.Reference();
    cacheRef.SetLayout(baseCacheViewLayout);
    cacheRef.SetName(cacheName + "_Ref");

    // Kernel 1: fill the lifted cache from the input matrix at the top of each
    // (i, j) cache-sized block of the nest.
    auto cacheFillKernel = loopnests::Kernel(cacheName + "_Fill_Cache_Kernel")
                               .Inputs(_value, liftedCache)
                               .Indices(_kernelIndices)
                               .Define([remainingRows, remainingCols, stripeSize, shape = _shape, inputRows, inputCols, boundaryConditionCacheLayout1, boundaryConditionCacheLayout2, boundaryConditionCacheLayout3, boundaryConditionCacheLayout4](value::Matrix input, value::Array cache, value::Scalar i, value::Scalar j) {
                                   // We may need to re-view the cache to a smaller layout if we have less
                                   // data to cache than we have available space in the cache.
                                   // If we re-view the cache then we can keep the smaller cached data
                                   // physically contiguous while still using the same looping APIs
                                   Scalar kernelRemainingRows = inputRows - i;
                                   Scalar kernelRemainingCols = inputCols - j;
                                   Scalar notEnoughRows = shape[0] > kernelRemainingRows;
                                   Scalar notEnoughCols = shape[1] > kernelRemainingCols;

                                   // Zero the whole cache first so that, in boundary cases where fewer
                                   // real columns are copied than the rounded-up layout holds, the
                                   // untouched tail reads back as zeros (zero padding).
                                   ZeroMemory(cache);

                                   // Generate the cache fill loop in a parameterized lambda so we can emit the different layout versions independently
                                   auto cacheFillLoop = [&](MemoryLayout cacheFillLayout, int rows, int cols) {
                                       auto cacheFillView = cache.GetValue();
                                       cacheFillView.SetLayout(cacheFillLayout);
                                       auto reViewedCache = Array(cacheFillView);

                                       // Copy whole stripes: cols / stripeSize full stripe-column chunks.
                                       ForRange(Scalar{ 0 }, Scalar{ cols / stripeSize }, [&](Scalar stripeColumnChunk) {
                                           ForRange(Scalar{ 0 }, Scalar{ rows }, [&](Scalar row) {
                                               ForRange(Scalar{ 0 }, Scalar{ stripeSize }, [&](Scalar stripeColumn) {
                                                   reViewedCache({ stripeColumnChunk, row, stripeColumn }) = input(i + row, j + stripeColumnChunk * stripeSize + stripeColumn);
                                               });
                                           });
                                       });

                                       // Copy the final partial stripe (cols % stripeSize leftover columns),
                                       // if any; the rest of that stripe keeps the zeros from ZeroMemory.
                                       auto finalColumnChunk = Scalar{ cols / stripeSize };
                                       ForRange(Scalar{ 0 }, Scalar{ rows }, [&](Scalar row) {
                                           ForRange(Scalar{ 0 }, Scalar{ cols % stripeSize }, [&](Scalar stripeColumn) {
                                               reViewedCache({ finalColumnChunk, row, stripeColumn }) = input(i + row, j + finalColumnChunk * stripeSize + stripeColumn);
                                           });
                                       });
                                   };

                                   // Emit all of the different loops individually since the cache layouts are set at emit-time
                                   If(notEnoughRows,
                                      [&]() {
                                          If(notEnoughCols,
                                             [&]() {
                                                 // Boundary condition 4
                                                 cacheFillLoop(boundaryConditionCacheLayout4, remainingRows, remainingCols);
                                             })
                                              .Else(
                                                  [&]() {
                                                      // Boundary condition 3
                                                      cacheFillLoop(boundaryConditionCacheLayout3, remainingRows, shape[1]);
                                                  });
                                      })
                                       .ElseIf(notEnoughCols,
                                               [&]() {
                                                   // Boundary condition 2
                                                   cacheFillLoop(boundaryConditionCacheLayout2, shape[0], remainingCols);
                                               })
                                       .Else(
                                           [&]() {
                                               // Boundary condition 1
                                               cacheFillLoop(boundaryConditionCacheLayout1, shape[0], shape[1]);
                                           });
                               });
    auto& underlyingNest = nest.GetUnderlyingLoopNest();
    // Run the fill kernel as a prologue of the outer cache-block indices.
    underlyingNest.AddKernel(cacheFillKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, _atIndices, {} });

    // Kernel 2: before each stripe is consumed, point cacheRef at the correct
    // row-major stripe slice of the cache. It runs on the kernel indices plus the
    // stripe split index so it fires once per stripe.
    std::vector<Index> viewInitKernelIndices;
    viewInitKernelIndices.assign(_kernelIndices.begin(), _kernelIndices.end());
    viewInitKernelIndices.push_back(stripeSplitIndex);
    auto viewInitKernel = loopnests::Kernel(cacheName + "_View_Init_Kernel")
                              .Inputs(liftedCache, cacheRef)
                              .Indices(viewInitKernelIndices)
                              .Define([shape = _shape, stripeSize, inputRows, inputCols, cacheViewLayout1, cacheViewLayout2, cacheViewLayout3, cacheViewLayout4, boundaryConditionCacheLayout1, boundaryConditionCacheLayout2, boundaryConditionCacheLayout3, boundaryConditionCacheLayout4](value::Array cache, value::Value cacheRef, value::Scalar i, value::Scalar j, value::Scalar jStripe) {
                                  // To set up the view for the kernel to use, we need to set up the cacheRef reference
                                  // so that a kernel indexing with (i, j) winds up in the right spot, pointing into the
                                  // cached row-major submatrix that is the (j / stripeSize, ALL, ALL) slice of the cache array
                                  // We may need to re-view the cache view to a smaller layout if we are in one of the boundary conditions
                                  Scalar remainingRows = inputRows - i;
                                  Scalar remainingCols = inputCols - j;
                                  Scalar notEnoughRows = shape[0] > remainingRows;
                                  Scalar notEnoughCols = shape[1] > remainingCols;

                                  auto cacheViewFn = [&](MemoryLayout cacheLayout, MemoryLayout viewLayout) {
                                      // Re-View the cache so we can index into the correct cached stripe
                                      auto cacheView = cache.GetValue();
                                      cacheView.SetLayout(cacheLayout);
                                      auto cacheStripe = jStripe % shape[1]; // If N > N', make sure we index into the re-initialized cache position
                                      auto indexedCacheView = cacheView.Offset({ cacheStripe / stripeSize, 0, 0 });

                                      // Re-View the indexed cache as a 2-D matrix so we can position the offset pointer for use in the inner kernels
                                      indexedCacheView.SetLayout(viewLayout);
                                      // Offset backwards by the current global (i, j) so that when the inner
                                      // kernels index cacheRef with global coordinates they land at the
                                      // stripe's origin — i.e. cacheRef(i, j) maps to stripe element (0, 0).
                                      auto offsetIndexedCacheView = indexedCacheView.Offset({ -1 * i, -1 * j });
                                      offsetIndexedCacheView.SetLayout(viewLayout);
                                      cacheRef.SetLayout(viewLayout);
                                      cacheRef = offsetIndexedCacheView.Reference();
                                  };

                                  // Emit all of the views and offsets individually since the cache layouts are set at emit-time
                                  If(notEnoughRows,
                                     [&]() {
                                         If(notEnoughCols,
                                            [&]() {
                                                // Boundary condition 4
                                                cacheViewFn(boundaryConditionCacheLayout4, cacheViewLayout4);
                                            })
                                             .Else(
                                                 [&]() {
                                                     // Boundary condition 3
                                                     cacheViewFn(boundaryConditionCacheLayout3, cacheViewLayout3);
                                                 });
                                     })
                                      .ElseIf(notEnoughCols,
                                              [&]() {
                                                  // Boundary condition 2
                                                  cacheViewFn(boundaryConditionCacheLayout2, cacheViewLayout2);
                                              })
                                      .Else(
                                          [&]() {
                                              // Boundary condition 1
                                              cacheViewFn(boundaryConditionCacheLayout1, cacheViewLayout1);
                                          });
                              });
    // The view-init kernel runs as a prologue of the stripe split index, and the
    // inner kernels' references to the original matrix are renamed to go through
    // cacheRef everywhere except inside the two cache-management kernels themselves.
    underlyingNest.AddKernel(viewInitKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, { stripeSplitIndex }, {} });
    underlyingNest.RenameVariable(_value, cacheRef, _atIndices, { cacheFillKernel, viewInitKernel });
}