void BLASTCopy::HandleCachingImpl(LoopNest& nest)

in libraries/value/src/CachingStrategies.cpp [355:593]


    /// @brief Install the BLAS "T copy" caching scheme on the given loop nest for this matrix operand.
    ///
    /// Registers two prologue kernels on the nest:
    ///  - a "fill" kernel (run at the outer cache indices `_atIndices`) that copies the current
    ///    M' x N' submatrix into a lifted 3-D cache of row-major column stripes, zeroing the cache
    ///    first so boundary (partial) submatrices are zero-padded, and
    ///  - a "view init" kernel (run at the stripe-split index) that re-views and back-offsets the
    ///    cache so inner kernels can keep indexing with the original global (i, j) indices.
    /// Finally it renames the original matrix value to the cache reference for the affected kernels,
    /// so they transparently read from the cache instead of the source matrix.
    ///
    /// @param nest The loop nest being decorated with this caching strategy.
    /// @throws InputException if zero-padding is requested but the cache column count `_shape[1]`
    ///         is not a multiple of the stripe size (see comment at the check below).
    void BLASTCopy::HandleCachingImpl(LoopNest& nest)
    {
        /* BLAS T COPY:
        suppose input matrix is M x N, cache size is M' x N', stripeSize = 4
        so cache successive M'x4 row-major submatrices from the input matrix

         0  1  2  3 16 17 18 19      0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 ...
         4  5  6  7 20 21 22 23 ->  16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
         8  9 10 11 24 25 26 27
        12 13 14 15 28 29 30 31

        Need 2 layers of caching:
        at M x N level, build up cache values
        at stripeSize level, set up pointer and memory layout
         */

        ValidateInputDimensionality(_value, _shape, _order);

        // get block size, stripe size, and stripe slitting index from extras
        // _extra carries (stripeSize, stripeSplitIndex, boundaryHandling) packed by the caller
        auto extraParams = std::any_cast<std::tuple<int, Index, BoundaryConditionHandling>>(_extra);
        int stripeSize;
        Index stripeSplitIndex;
        BoundaryConditionHandling boundaryHandling;
        std::tie(stripeSize, stripeSplitIndex, boundaryHandling) = extraParams;

        if (boundaryHandling == BoundaryConditionHandling::ZeroPadding && _shape[1] % stripeSize != 0)
        {
            // To avoid an odd repeated edge case, enforce that the number of cache columns is a multiple of the stripe size
            // So the base 3D cache view can represent the full cache
            throw InputException(InputExceptionErrors::invalidSize, "The number of cache columns must be a multiple of the cache stripe size");
        }

        // Cache structure:
        // Lift the 2D submatrix into a 3D array to set up the cache simply
        // The first dimension identifies which cached column block to use
        // The second two dimensions identify the element inside of that cached submatrix block
        // Index mapping: input ( i, j ) -> cache ( j / stripeSize, i, j % stripeSize )
        //                cache ( i, j, k ) -> input ( j, i * stripeSize + k )

        // Boundary handling
        // There are 4 boundary scenarios (possibly all 4 can happen in a single input matrix + cache size combination
        // while iterating over the matrix):
        //     |-------N-------|
        //     |----N'---|
        // _ _ *---------------*
        // | | |         |     |
        // | M'|    1    |  2  |
        // | | |         |     |
        // M _ |_________|_____|
        // |   |    3    |  4  |
        // |   |         |     |
        // _   *---------------*

        // 1 : The cache has exactly as many rows and columns as the input matrix chunk
        //     - This is the simple case, leave the cache as { M' x N' }
        // 2 : The cache has more columns than the input matrix but fewer rows
        //     - re-view the cache to be { M' x remainingColumns }
        // 3 : The cache has more rows than the input matrix but fewer columns
        //     - re-view the cache to be { remainingRows x N' }
        // 4 : The cache has more rows and columns than the input matrix
        //     - re-view the cache to be { remainingRows x remainingColumns }
        // Note: it is assumed that the input matrix is stepped over in splits based on the
        //     cache size given, so the cache can never be smaller than the input matrix chunk

        // Since the matrix and cache sizes are known ahead of time, we can compute all of the boundary
        // condition layouts that are needed:
        // remainingRows = M % M'
        // remainingColumns = N % N'

        auto inputMatrix = Matrix(_value);
        int inputRows = inputMatrix.Rows();
        int inputCols = inputMatrix.Columns();
        int remainingRows = inputRows % _shape[0];
        int remainingCols = inputCols % _shape[1];
        // Round the boundary column count up so the lifted 3D layout still holds whole stripes;
        // the fill kernel zero-initializes the cache, so the rounded-up tail stays zero-padded
        int roundedRemainingCols = RoundUpToMultiple(remainingCols, stripeSize);
        // we don't need to round up remainingRows since stripe size only applies to columns in BLASTCopy

        // Lifted 3D layout: { number of column stripes, rows, stripe width }, canonical dimension order
        auto generateTCOPYCacheLayout = [stripeSize](int rows, int cols) {
            auto cacheDimOrder = DimensionOrder{ 0, 1, 2 };
            auto liftedShape = MemoryShape{ cols / stripeSize, rows, stripeSize };
            auto cacheLayout = MemoryLayout{ liftedShape, cacheDimOrder };
            return cacheLayout;
        };
        // 2D view of a single stripe: { rows, stripeSize } row-major.
        // NOTE: `cols` is intentionally unused here — a stripe view is always stripeSize wide;
        // the parameter is kept for call-site symmetry with generateTCOPYCacheLayout
        auto generateTCOPYCacheViewLayout = [stripeSize](int rows, int cols) {
            auto cacheViewLayout = MemoryLayout{ { rows, stripeSize }, RowMajorMatrixOrder };
            return cacheViewLayout;
        };

        auto baseCacheLayout = generateTCOPYCacheLayout(_shape[0], _shape[1]); // The non-boundary-case 3D lifted shape
        auto baseCacheViewLayout = generateTCOPYCacheViewLayout(_shape[0], _shape[1]);

        // "Boundary" condition 1 is the general case (i.e. non-boundary case)
        auto boundaryConditionCacheLayout1 = baseCacheLayout;
        auto cacheViewLayout1 = baseCacheViewLayout;

        // Boundary condition 2, re-view to M' x remainingColumns
        auto boundaryConditionCacheLayout2 = generateTCOPYCacheLayout(_shape[0], roundedRemainingCols);
        auto cacheViewLayout2 = generateTCOPYCacheViewLayout(_shape[0], roundedRemainingCols);

        // Boundary condition 3, re-view to remainingRows x N'
        auto boundaryConditionCacheLayout3 = generateTCOPYCacheLayout(remainingRows, _shape[1]);
        auto cacheViewLayout3 = generateTCOPYCacheViewLayout(remainingRows, _shape[1]);

        // Boundary condition 4, re-view to remainingRows x remainingColumns
        auto boundaryConditionCacheLayout4 = generateTCOPYCacheLayout(remainingRows, roundedRemainingCols);
        auto cacheViewLayout4 = generateTCOPYCacheViewLayout(remainingRows, roundedRemainingCols);

        // Allocate the persistent cache buffer with the full (non-boundary) lifted layout;
        // boundary cases re-view this same storage with the smaller layouts above
        auto cacheName = UniqueName("BLASTCopyCache");
        _rawCache = StaticAllocate(cacheName, _value.GetBaseType(), baseCacheLayout);
        Array liftedCache(_rawCache);

        // cacheRef is the 2D { rows x stripeSize } row-major view that inner kernels will index
        // through; the view-init kernel below repoints it at the current stripe each iteration
        auto cacheRef = _rawCache.Reference();
        cacheRef.SetLayout(baseCacheViewLayout);
        cacheRef.SetName(cacheName + "_Ref");

        auto cacheFillKernel = loopnests::Kernel(cacheName + "_Fill_Cache_Kernel")
                                   .Inputs(_value, liftedCache)
                                   .Indices(_kernelIndices)
                                   .Define([remainingRows, remainingCols, stripeSize, shape = _shape, inputRows, inputCols, boundaryConditionCacheLayout1, boundaryConditionCacheLayout2, boundaryConditionCacheLayout3, boundaryConditionCacheLayout4](value::Matrix input, value::Array cache, value::Scalar i, value::Scalar j) {
                                       // We may need to re-view the cache to a smaller layout if we have less
                                       // data to cache than we have available space in the cache.
                                       // If we re-view the cache then we can keep the smaller cached data
                                       // physically contiguous while still using the same looping APIs
                                       Scalar kernelRemainingRows = inputRows - i;
                                       Scalar kernelRemainingCols = inputCols - j;
                                       Scalar notEnoughRows = shape[0] > kernelRemainingRows;
                                       Scalar notEnoughCols = shape[1] > kernelRemainingCols;
                                       // Zero the whole cache up front so any region a boundary case doesn't
                                       // overwrite is left as zero padding
                                       ZeroMemory(cache);

                                       // Generate the cache fill loop in a parameterized lambda so we can emit the different layout versions independently
                                       auto cacheFillLoop = [&](MemoryLayout cacheFillLayout, int rows, int cols) {
                                           auto cacheFillView = cache.GetValue();
                                           cacheFillView.SetLayout(cacheFillLayout);
                                           auto reViewedCache = Array(cacheFillView);

                                           // Copy all full stripeSize-wide column chunks:
                                           // input (i+row, j+chunk*stripeSize+col) -> cache (chunk, row, col)
                                           ForRange(Scalar{ 0 }, Scalar{ cols / stripeSize }, [&](Scalar stripeColumnChunk) {
                                               ForRange(Scalar{ 0 }, Scalar{ rows }, [&](Scalar row) {
                                                   ForRange(Scalar{ 0 }, Scalar{ stripeSize }, [&](Scalar stripeColumn) {
                                                       reViewedCache({ stripeColumnChunk, row, stripeColumn }) = input(i + row, j + stripeColumnChunk * stripeSize + stripeColumn);
                                                   });
                                               });
                                           });
                                           // Copy the final partial chunk when cols is not a multiple of stripeSize
                                           // (loop bound cols % stripeSize is 0 otherwise, so this emits nothing extra);
                                           // the unfilled tail of the chunk stays zero from ZeroMemory above
                                           auto finalColumnChunk = Scalar{ cols / stripeSize };
                                           ForRange(Scalar{ 0 }, Scalar{ rows }, [&](Scalar row) {
                                               ForRange(Scalar{ 0 }, Scalar{ cols % stripeSize }, [&](Scalar stripeColumn) {
                                                   reViewedCache({ finalColumnChunk, row, stripeColumn }) = input(i + row, j + finalColumnChunk * stripeSize + stripeColumn);
                                               });
                                           });
                                       };

                                       // Emit all of the different loops individually since the cache layouts are set at emit-time
                                       If(notEnoughRows,
                                          [&]() {
                                              If(notEnoughCols,
                                                 [&]() {
                                                     // Boundary condition 4
                                                     cacheFillLoop(boundaryConditionCacheLayout4, remainingRows, remainingCols);
                                                 })
                                                  .Else(
                                                      [&]() {
                                                          // Boundary condition 3
                                                          cacheFillLoop(boundaryConditionCacheLayout3, remainingRows, shape[1]);
                                                      });
                                          })
                                           .ElseIf(notEnoughCols,
                                                   [&]() {
                                                       // Boundary condition 2
                                                       cacheFillLoop(boundaryConditionCacheLayout2, shape[0], remainingCols);
                                                   })
                                           .Else(
                                               [&]() {
                                                   // Boundary condition 1
                                                   cacheFillLoop(boundaryConditionCacheLayout1, shape[0], shape[1]);
                                               });
                                   });

        // Run the fill kernel as a prologue of the outer cache-block indices so the cache is
        // populated before any inner kernel reads it
        auto& underlyingNest = nest.GetUnderlyingLoopNest();
        underlyingNest.AddKernel(cacheFillKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, _atIndices, {} });

        // The view-init kernel additionally needs the stripe-split index to know which stripe
        // of the cache the upcoming inner kernels will read
        std::vector<Index> viewInitKernelIndices;
        viewInitKernelIndices.assign(_kernelIndices.begin(), _kernelIndices.end());
        viewInitKernelIndices.push_back(stripeSplitIndex);
        auto viewInitKernel = loopnests::Kernel(cacheName + "_View_Init_Kernel")
                                  .Inputs(liftedCache, cacheRef)
                                  .Indices(viewInitKernelIndices)
                                  .Define([shape = _shape, stripeSize, inputRows, inputCols, cacheViewLayout1, cacheViewLayout2, cacheViewLayout3, cacheViewLayout4, boundaryConditionCacheLayout1, boundaryConditionCacheLayout2, boundaryConditionCacheLayout3, boundaryConditionCacheLayout4](value::Array cache, value::Value cacheRef, value::Scalar i, value::Scalar j, value::Scalar jStripe) {
                                      // To set up the view for the kernel to use, we need to set up the cacheRef reference
                                      // so that a kernel indexing with (i, j) winds up in the right spot, pointing into the
                                      // cached row-major submatrix that is the (j / stripeSize, ALL, ALL) slice of the cache array

                                      // We may need to re-view the cache view to a smaller layout if we are in one of the boundary conditions
                                      Scalar remainingRows = inputRows - i;
                                      Scalar remainingCols = inputCols - j;
                                      Scalar notEnoughRows = shape[0] > remainingRows;
                                      Scalar notEnoughCols = shape[1] > remainingCols;

                                      auto cacheViewFn = [&](MemoryLayout cacheLayout, MemoryLayout viewLayout) {
                                          // Re-View the cache so we can index into the correct cached stripe
                                          auto cacheView = cache.GetValue();
                                          cacheView.SetLayout(cacheLayout);
                                          auto cacheStripe = jStripe % shape[1]; // If N > N', make sure we index into the re-initialized cache position
                                          auto indexedCacheView = cacheView.Offset({ cacheStripe / stripeSize, 0, 0 });

                                          // Re-View the indexed cache as a 2-D matrix so we can position the offset pointer for use in the inner kernels
                                          indexedCacheView.SetLayout(viewLayout);
                                          // Offset backwards by (i, j) so that inner kernels indexing with the
                                          // global (i, j) land back at the start of this cached stripe
                                          auto offsetIndexedCacheView = indexedCacheView.Offset({ -1 * i, -1 * j });
                                          offsetIndexedCacheView.SetLayout(viewLayout);
                                          cacheRef.SetLayout(viewLayout);
                                          cacheRef = offsetIndexedCacheView.Reference();
                                      };

                                      // Emit all of the views and offsets individually since the cache layouts are set at emit-time
                                      If(notEnoughRows,
                                         [&]() {
                                             If(notEnoughCols,
                                                [&]() {
                                                    // Boundary condition 4
                                                    cacheViewFn(boundaryConditionCacheLayout4, cacheViewLayout4);
                                                })
                                                 .Else(
                                                     [&]() {
                                                         // Boundary condition 3
                                                         cacheViewFn(boundaryConditionCacheLayout3, cacheViewLayout3);
                                                     });
                                         })
                                          .ElseIf(notEnoughCols,
                                                  [&]() {
                                                      // Boundary condition 2
                                                      cacheViewFn(boundaryConditionCacheLayout2, cacheViewLayout2);
                                                  })
                                          .Else(
                                              [&]() {
                                                  // Boundary condition 1
                                                  cacheViewFn(boundaryConditionCacheLayout1, cacheViewLayout1);
                                              });
                                  });
        // Run the view-init kernel as a prologue of each stripe split, then redirect the
        // listed kernels' reads of the original matrix to go through the cache reference
        underlyingNest.AddKernel(viewInitKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, { stripeSplitIndex }, {} });
        underlyingNest.RenameVariable(_value, cacheRef, _atIndices, { cacheFillKernel, viewInitKernel });
    }