void GeneralCachingStrategy::HandleCachingImpl()

in libraries/value/src/CachingStrategies.cpp [981:1940]


    void GeneralCachingStrategy::HandleCachingImpl(LoopNest& nest)
    {
        // General caching strategy:
        // Given:
        //     - input value
        //     - top level indices that the input uses
        //     - name for the cache
        //     - size of the cache to use in # of elements
        //     - # of elements to cache at a time ( < size of cache for progressive caching, > size of cache is an error)
        //     - Input / InputOutput / Output designation
        //     - Reduce function operating on individual Scalars
        //
        // Set up 3-4 kernels:
        //     - Cache flushing kernel
        //     - Cache filling kernel if Input/InputOutput
        //     - Cache viewing kernel (based on the shape of the input value)
        //     - Cache reduce kernel if InputOutput/Output
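        //
        // Illustrative (hypothetical) parameter set: an InputOutput cache named "cacheA" holding at
        // most 1024 elements, progressively filled 256 elements at a time, with a SumReduce-style
        // reduce function (sumReduceFunc is a placeholder name) and no accumulation:
        //     { value::ArgumentType::InputOutput, "cacheA", 1024, 256, sumReduceFunc, false }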

        auto extraParams = std::any_cast<std::tuple<value::ArgumentType,
                                                    std::string,
                                                    size_t,
                                                    size_t,
                                                    std::function<ReduceFunctionType>,
                                                    bool>>(_extra);
        value::ArgumentType argType;
        std::string baseName;
        size_t maxCacheElts;
        size_t fillThreshold; // fillThreshold <= maxCacheElts
        std::function<ReduceFunctionType> reduceFunction;
        bool accumulateReduce;
        std::tie(argType,
                 baseName,
                 maxCacheElts,
                 fillThreshold,
                 reduceFunction,
                 accumulateReduce) = extraParams;

        // Read target machine characteristics for number of SIMD registers and the size of the registers
        RegisterCharacteristics registerCharacteristics = GetRegisterCharacteristics(_value.GetBaseType());

        // Determine kernels needed
        bool useFillKernel = (argType == value::ArgumentType::Input || argType == value::ArgumentType::InputOutput);
        bool useViewKernel = true; // always include view kernel for simplicity for now, even if the re-viewing winds up being redundant
        bool useReduceKernel = (argType == value::ArgumentType::Output || argType == value::ArgumentType::InputOutput);

        size_t bufferAlignment = 16 * sizeof(float);

        InvokeForContext<CppEmitterContext>([&] {
            // TODO : Support buffer alignment in CppEmitterContext
            bufferAlignment = 0;
        });

        auto inputArray = Array(_value);
        int logicalDimensionCount = _value.GetLayout().NumDimensions();
        int compositeIndexCount = _kernelIndices.size();
        auto& underlyingNest = nest.GetUnderlyingLoopNest();

        const auto& loopSequence = underlyingNest.GetLoopSequence();
        std::vector<Index> orderedIndices;
        for (const auto& index : loopSequence)
        {
            const auto& dimensionIndex = underlyingNest.GetDimensionRange(index).GetDimensionIndex();
            auto indexIter = std::find(_kernelIndices.begin(), _kernelIndices.end(), dimensionIndex);
            if (indexIter != _kernelIndices.end())
            {
                orderedIndices.push_back(index);
            }
        }

        // Ensure we have some indices
        if (orderedIndices.empty())
        {
            throw InputException(InputExceptionErrors::invalidSize, "Don't have any indices relevant to this input for this loop nest");
        }

        // If there are no _atIndices specified, default to the outermost orderedIndices index
        if (_atIndices.empty())
        {
            _atIndices.push_back(orderedIndices.front());
        }

        // Compute the mapping between the orderedIndices list and the logical input dimensions
        std::vector<int> logicalDimensionMapping;
        logicalDimensionMapping.reserve(orderedIndices.size());

        // Determine the size for each split for each logical dimension
        // We only care about the split indices that are passed in as part of
        // orderedIndices, so instead of recording the sizes of those indices,
        // instead record the size of the full index range followed by the increments
        // of the each of the orderedIndices
        std::map<int, std::vector<int>> logicalDimensionSplitSizes;
        for (int logicalDimension = 0; logicalDimension < logicalDimensionCount; ++logicalDimension)
        {
            logicalDimensionSplitSizes[logicalDimension].push_back(_value.GetLayout().GetActiveSize(logicalDimension));
        }

        // Determine the increments for each split index in the orderedIndices
        // The cache dimensions all operate with logical increments of 1, so when we are mapping between input space and cache space
        // we need to scale appropriately by the split index increments for each split index
        std::vector<int> orderedIndexIncrements;
        orderedIndexIncrements.reserve(orderedIndices.size());

        for (const auto& index : orderedIndices)
        {
            // Compute the logical dimension mapping
            const auto& dimensionIndex = underlyingNest.GetDimensionRange(index).GetDimensionIndex();
            auto indexIter = std::find(_kernelIndices.begin(), _kernelIndices.end(), dimensionIndex);
            // Here we assume:
        //  - _kernelIndices is a vector or similar, so (iterator - begin) == the index of the iterator
            //  - _kernelIndices is arranged in logical dimension ordering for this input
            int logicalDimension = indexIter - _kernelIndices.begin();
            logicalDimensionMapping.push_back(logicalDimension);

            // Find the index increment for this index to use for scaling index values to
            // convert between cache dimensions and input indices
            // Also use this for the logical dimension split sizes
            auto increment = underlyingNest.GetIndexRange(index).Increment();
            orderedIndexIncrements.push_back(increment);
            logicalDimensionSplitSizes[logicalDimension].push_back(increment);
        }
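
        // e.g. (hypothetical sizes) if logical dimension d has an active size of 64 and the cache uses
        // split indices with increments 32, 4, and 1 in that dimension, then after this loop
        // logicalDimensionSplitSizes[d] == { 64, 32, 4, 1 }, and those indices contribute
        // 32, 4, and 1 to orderedIndexIncrements in their orderedIndices positions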

        // Compute the memory shape for the cache based on the index sizes in each logical
        // dimension. Each MemoryShape dimension counts the number of shards of the cache
        // that dimension indexes over, so the size of each MemoryShape dimension ought to be
        // the size of the index divided by the size of the next split index in the same
        // logical input dimension.
        // e.g. if Index i ranges over [0,64), and is split by 32, then by 16, then by 4
        //      we will have split indices [0,64,32), [0,32,16), [0,16,4), and [0,4,1),
        //      but suppose a cache doesn't use the second index, i.e. it only uses
        //      [0,64,32), [0,16,4), and [0,4,1), then the memory shape (for split dimensions
        //      in the i logical dimension) should be { 4, 4, 4 } since the outer index
        //      ranging from 0 to 64 accounts for 4 shards of 16
        //      and the next index ranging from 0 to 16 accounts for 4 shards of 4
        //      and the next index ranging from 0 to 4 accounts for 4 shards of 1
        //
        // Now that we have the base dimension size and all the increments for the indices we're using
        // we can compute the shard sizes for each logical dimension by dividing each dimension split
        // size we accumulated above by the size that comes after it, indicating how many instances of
        // the next shard occur within the current shard
        std::map<int, std::queue<int>> logicalIndexToShardSizes;
        std::map<int, std::queue<int>> logicalIndexToSizes; // Full element counts, not shard counts
        for (int logicalDimension = 0; logicalDimension < logicalDimensionCount; ++logicalDimension)
        {
            const auto& splitSizes = logicalDimensionSplitSizes[logicalDimension];
            for (unsigned splitIdx = 0; splitIdx < splitSizes.size() - 1; ++splitIdx)
            {
                int currentSize = splitSizes[splitIdx];
                int nextSize = splitSizes[splitIdx + 1];
                int shardSize = currentSize / nextSize;
                if (currentSize % nextSize != 0)
                {
                    // Round up to account for partial shards
                    shardSize++;
                }
                logicalIndexToShardSizes[logicalDimension].push(shardSize);
                logicalIndexToSizes[logicalDimension].push(currentSize);
            }
        }
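
        // e.g. continuing the hypothetical sizes above (no skipped splits): with
        // logicalDimensionSplitSizes[d] == { 64, 32, 4, 1 }, the shard sizes pushed for
        // dimension d are 64/32 = 2, 32/4 = 8, and 4/1 = 4, with corresponding full
        // sizes of 64, 32, and 4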

        // Now that we have the shard sizes grouped by logical dimension, arrange them to match
        // the orderedIndices
        std::vector<int> orderedIndexShardSizes;
        std::vector<int> orderedIndexSizes; // Full element counts, not shard counts
        orderedIndexShardSizes.reserve(orderedIndices.size());
        orderedIndexSizes.reserve(orderedIndices.size());
        for (unsigned idx = 0; idx < logicalDimensionMapping.size(); ++idx)
        {
            int logicalDimension = logicalDimensionMapping[idx];

            orderedIndexShardSizes.push_back(logicalIndexToShardSizes[logicalDimension].front());
            logicalIndexToShardSizes[logicalDimension].pop();

            orderedIndexSizes.push_back(logicalIndexToSizes[logicalDimension].front());
            logicalIndexToSizes[logicalDimension].pop();
        }

        // Create a MemoryShape for the cache based on the shard counts
        // This isn't the final cache shape and layout yet - we may need to shrink it to fit the number
        // of elements requested in the cache
        MemoryShape fullInputShape = { orderedIndexShardSizes };
        MemoryLayout fullInputLayout = { fullInputShape };

        // Physical Cache
        // Determine how large the physical cache ought to be by trying to cover complete view
        // dimensions without exceeding maxCacheElts elements in size.
        // e.g. if the full view has 5 dimensions, and our maxCacheElts only covers the innermost two dimensions,
        //      then the cache size is set to that size and we create our "fill" and "reduce" kernels accordingly
        // To achieve this, start from the base full cache layout and slice off physical dimensions going from the
        // outermost to the innermost until the full extent has no more than maxCacheElts elements
        MemoryLayout cacheLayout = fullInputLayout;
        unsigned cacheThresholdIdx = 0;
        while (cacheLayout.GetMemorySize() > maxCacheElts)
        {
            cacheLayout = cacheLayout.GetSliceLayout(0);
            cacheThresholdIdx++;
        }
        if (cacheLayout.NumElements() == 0)
        {
            throw InputException(InputExceptionErrors::invalidSize, "Specified cache size isn't large enough to cover the smallest dimension of the cache layout");
        }
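
        // e.g. (hypothetical sizes, assuming a dense layout so GetMemorySize() is just the product of
        // the dimension sizes) if fullInputShape is { 2, 8, 4 } (64 elements) and maxCacheElts is 40,
        // one slice leaves { 8, 4 } (32 elements), so cacheThresholdIdx == 1 and the physical cache
        // covers the inner two dimensions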
        std::vector<int> cacheOrderedIndexSizes(orderedIndexSizes.begin() + cacheThresholdIdx, orderedIndexSizes.end());
        std::vector<int> cacheLogicalDimensionMapping(logicalDimensionMapping.begin() + cacheThresholdIdx, logicalDimensionMapping.end());
        std::vector<int> cacheOrderedIndexIncrements(orderedIndexIncrements.begin() + cacheThresholdIdx, orderedIndexIncrements.end());
        auto cacheName = UniqueName(baseName);
        _rawCache = StaticAllocate(cacheName, _value.GetBaseType(), cacheLayout);

        // Progressive Caching
        // To enable progressive caching, where a subset of the full physical cache is
        // filled and used, then later the next chunk of the physical cache is filled
        // and used, we need to find the dimension split at which the fillThreshold element
        // count is surpassed and set up a fill kernel at that point
        // If fillThreshold == maxCacheElts or they are both exceeded in the same
        // split, then ensure that the fill kernel occurs after the cache emptying kernel
        if (fillThreshold > maxCacheElts)
        {
            throw InputException(InputExceptionErrors::invalidArgument, "Fill threshold can't be larger than the max cache size");
        }
        unsigned cacheFillThresholdIdx = cacheThresholdIdx;
        MemoryLayout cacheFillLayout = cacheLayout;
        while (cacheFillLayout.GetMemorySize() > fillThreshold)
        {
            cacheFillLayout = cacheFillLayout.GetSliceLayout(0);
            cacheFillThresholdIdx++;
        }
        if (cacheFillLayout.NumElements() == 0)
        {
            throw InputException(InputExceptionErrors::invalidSize, "Specified cache fill threshold size isn't large enough to cover the smallest dimension of the cache layout");
        }
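
        // e.g. continuing the hypothetical sizes above: with cacheLayout == { 8, 4 } (32 elements) and
        // fillThreshold == 16, one more slice leaves { 4 } (4 elements), so cacheFillThresholdIdx == 2
        // and the cache is filled 4 elements at a time; with fillThreshold == maxCacheElts no extra
        // slicing occurs and the whole cache is filled at once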
        std::vector<int> cacheFillOrderedIndexSizes(orderedIndexSizes.begin() + cacheFillThresholdIdx, orderedIndexSizes.end());
        std::vector<int> cacheFillLogicalDimensionMapping(logicalDimensionMapping.begin() + cacheFillThresholdIdx, logicalDimensionMapping.end());
        std::vector<int> cacheFillOrderedIndexIncrements(orderedIndexIncrements.begin() + cacheFillThresholdIdx, orderedIndexIncrements.end());

        // Cache View
        // The cache view needs to have the same number of dimensions as the input value
        // but cover an area that is a subset of the full cache and represents one cache
        // dimension per logical input dimension.
        // This may mean that for some of the logical input dimensions, the cache view
        // size is 1. E.g. suppose a 3-D input is cached and the inner 3 dimensions of
        // the cache only operate over two of the input's logical dimensions, with the
        // two innermost of those cache dimensions mapping to the two distinct logical
        // dimensions. In that case the cache view would cover the inner two cache
        // dimensions and use a size of 1 for the remaining logical dimension.
        // In general, the cache view needs to cover an area of the cache that can be
        // contiguously represented like the logical input value.

        // To build up the cache view layout, start from the innermost dimension of the
        // cache layout and accumulate dimensions going outward until either all of the
        // logical input dimensions are accounted for or one of the logical input dimensions
        // repeats. However, when a single dimension is repeated multiple times in a row,
        // those repeats can be collapsed into a single visiting of that dimension. These
        // can be collapsed because the logical behavior is the same regardless of whether
        // the split that produced the repeated dimension was made or not.
        // E.g. suppose your dimensions are { 0, 0, 1, 1, 1, 0, 0 }, then the first pair of
        //      0's can be collapsed and treated like a single visiting of that dimension,
        //      the set of 3 1's can be collapsed, and the final pair of 0's can be collapsed,
        //      producing a collapsed dimension ordering of { 0, 1, 0 }. With a collapsed
        //      dimension ordering of { 0, 1, 0 }, the cache view needs to break at the inner
        //      { 1, 0 }, because after that a dimension (the 0 dimension) will repeat.
        MemoryLayout baseCacheViewLayout;
        unsigned cacheViewThresholdIdxOffset;
        std::tie(baseCacheViewLayout, cacheViewThresholdIdxOffset) = ComputeCacheView(cacheFillLayout,
                                                                                      cacheFillLogicalDimensionMapping,
                                                                                      logicalDimensionCount);
        unsigned cacheViewThresholdIdx = cacheFillThresholdIdx + cacheViewThresholdIdxOffset;

        auto cacheRef = _rawCache.Reference();
        cacheRef.SetLayout(baseCacheViewLayout);

        // Boundary Conditions
        // Boundary conditions occur when the region of the input value that we want
        // to cache does not fill the physical cache,
        // e.g. for a matrix cache there are 4 cases, 3 of which are considered boundary condition cases:
        //      Suppose the matrix is M x N and the physical cache is sized to hold M' x N' elements,
        //      where M / 2 < M' < M, N / 2 < N' < N
        //     |-------N-------|
        //     |----N'---|----N'---|
        // _ _ *---------------*
        // | | |         |     |
        // | M'|    1    |  2  |
        // | | |         |     |
        // M _ |_________|_____|
        // | | |    3    |  4  |
        // | M'|         |     |
        // _ | *---------------*
        //   _
        // 1 : The cache has exactly as many rows and columns as the input matrix chunk
        // 2 : The cache has more columns than the matrix chunk but just as many rows
        // 3 : The cache has more rows than the matrix chunk but just as many columns
        // 4 : The cache has more rows and columns than the matrix chunk
        //
        // One possible solution is to zero-pad the cache and keep the layout as-is. This would certainly work
        //
        // However, in order to maximize data locality in the cache (which is the purpose of the cache),
        // we would prefer it if the cache were reshaped such that the input value chunk
        // fills the cache from the beginning until the end of the chunk without any gaps.
        // This reshape amounts to shrinking the cache sizes in some dimensions, however to preserve
        // vectorization behavior we avoid shrinking the innermost dimension and instead zero-pad
        // that dimension
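        //
        // e.g. (illustratively) for a 7 x 10 matrix chunked by a 4 x 6 cache, the case-4 boundary chunk
        // is 3 x 4: the outer (row) dimension of the boundary cache layout shrinks to 3, while the
        // innermost (column) dimension stays at 6 and the 2 unused columns are left zero-padded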
        unsigned cacheFillThresholdIdxOffset = cacheFillThresholdIdx - cacheThresholdIdx;
        unsigned cacheViewThresholdIdxCacheOffset = cacheViewThresholdIdxOffset + cacheFillThresholdIdxOffset;
        BoundaryConditionMemoryLayoutHelper boundaryConditionCacheHelper(_value.GetLayout().GetActiveSize(), cacheOrderedIndexSizes, cacheLogicalDimensionMapping, cacheOrderedIndexIncrements, cacheFillThresholdIdxOffset, cacheViewThresholdIdxCacheOffset);

        std::vector<loopnests::Kernel> cachingKernels;

        {
            // Flush the cache to implicitly zero-pad any regions of the cache we don't fill later
            std::vector<Index> cacheFlushPosition(orderedIndices.begin(), orderedIndices.begin() + cacheThresholdIdx);
            auto cacheEmptyKernel = loopnests::Kernel(cacheName + "_Empty_Cache_Kernel")
                                        .Inputs(_rawCache)
                                        .Indices()
                                        .Define([](Value cache) {
                                            // TODO : determine if a vectorized approach is worthwhile here
                                            ZeroMemory(cache);
                                        });

            underlyingNest.AddKernel(cacheEmptyKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, cacheFlushPosition, {} });
            cachingKernels.push_back(cacheEmptyKernel);
        }
        if (useFillKernel)
        {
            std::vector<Index> cacheFillPosition(orderedIndices.begin(), orderedIndices.begin() + cacheFillThresholdIdx);
            std::vector<Index> cacheFillIndices(_kernelIndices.begin(), _kernelIndices.end());
            cacheFillIndices.insert(cacheFillIndices.end(), cacheFillPosition.begin(), cacheFillPosition.end());

            auto cacheFillKernel = loopnests::Kernel(cacheName + "_Fill_Cache_Kernel")
                                       .Inputs(_value, _rawCache)
                                       .Indices(cacheFillIndices)
                                       .DefineEx([=](std::vector<Value> values, std::vector<Scalar> indices) {
                                           auto& input = values[0];
                                           auto& cache = values[1];
                                           std::vector<Scalar> compositeIndexValues(indices.begin(), indices.begin() + compositeIndexCount);
                                           std::vector<Scalar> splitIndexValues(indices.begin() + compositeIndexCount, indices.end());

                                           auto offsetInput = input.Offset(compositeIndexValues);
                                           offsetInput.SetLayout(input.GetLayout());
                                           auto offsetInputArrayView = Array(offsetInput);

                                           boundaryConditionCacheHelper.EmitBoundarySwitches(compositeIndexValues, [=](MemoryLayout inputRegionShape, MemoryLayout inputRegionFillShape, MemoryLayout boundaryCacheLayout, MemoryLayout boundaryCacheFillLayout) {
                                                // Offset the cache write head based on where we are in the progressive caching.
                                                // Since fillThreshold <= maxCacheElts, this kernel may run multiple times, filling
                                                // different portions of the cache, so we look at the indices between the
                                                // cacheThresholdIdx and the cacheFillThresholdIdx to find the position we need to
                                                // offset to.
                                                // These indices map, in order, to the dimensions that are in the cache but outside
                                                // the fill region, since the cache memory ordering is based on these indices in this order

                                               auto cacheView = cache;
                                               cacheView.SetLayout(boundaryCacheLayout);
                                               std::vector<Scalar> cacheOffsetIndices;
                                               cacheOffsetIndices.reserve(boundaryCacheLayout.NumDimensions());

                                                // Note: if cacheThresholdIdx == cacheFillThresholdIdx (i.e. if there is no progressive caching)
                                                // then the first loop is skipped and no offsetting occurs, so the cache is filled from
                                                // the beginning every time this kernel is run
                                               for (unsigned idx = cacheThresholdIdx; idx < cacheFillThresholdIdx; ++idx)
                                               {
                                                   // Mapping loopnest indices (input space) -> cache offsets (cache space) so divide by split index increment
                                                   cacheOffsetIndices.push_back(splitIndexValues[idx] / orderedIndexIncrements[idx]);
                                               }
                                               for (unsigned idx = cacheFillThresholdIdx; idx < static_cast<unsigned>(fullInputLayout.NumDimensions()); ++idx)
                                               {
                                                   cacheOffsetIndices.push_back(Scalar{ 0 });
                                               }
                                               auto offsetCache = cacheView.Offset(cacheOffsetIndices);
                                               offsetCache.SetLayout(boundaryCacheFillLayout);
                                               auto cacheFillArrayView = Array(offsetCache);

                                               // Prefer input-oriented loops to maximize locality as the input
                                               // is likely to be larger than the cache in most cases
                                               // Based on the element size and counts in different dimensions,
                                               // we will split and unroll some of the inner loops in order to maximize
                                               // vectorization.
                                               // In order to get appropriate utilization of all the SIMD
                                               // registers, we will need to use a temporary buffer (which we expect
                                               // the compiler to optimize away) with a size equal to the total number
                                               // of elements that can be held in all of the SIMD registers.
                                               // The filling of this temporary buffer from the input needs to be an
                                               // unrolled operation and the filling of the cache from the temporary
                                               // buffer also needs to be an unrolled operation that happens after
                                               // the full temporary buffer has been filled.
                                               // Therefore, we need multiple levels of loopnests so that the area
                                               // outside of the temporary buffer's addressable region can be looped
                                               // over, and the area inside the temporary buffer region can have two
                                               // sequential fully unrolled loopnests.
                                               // new loopnest (outer):
                                               // For ... {
                                               //   For ... {
                                               //       // start of outer loopnest prologue kernel
                                               //       // Fill temp buf
                                               //       new loopnest (inner #1):
                                               //       For ... (unroll) {
                                               //           For ... (unroll) {
                                               //               ... {
                                               //                   // start of inner loopnest #1 kernel
                                               //                   tempBuf(tempBufIndices) = input(inputIndices)
                                               //                   // end of inner loopnest #1 kernel
                                               //               }
                                               //               ...
                                               //           }
                                               //       }
                                               //       // Fill cache
                                               //       new loopnest (inner #2):
                                               //       For ... (unroll) {
                                               //           For ... (unroll) {
                                               //               ... {
                                               //                   // start of inner loopnest #2 kernel
                                               //                   cache(cacheIndices) = tempBuf(tempBufIndices)
                                               //                   // end of inner loopnest #2 kernel
                                               //               }
                                               //               ...
                                               //           }
                                               //       }
                                               //       // end of outer loopnest kernel
                                               //   }
                                               // }

                                               std::vector<loopnests::Index> fillIndices;
                                               fillIndices.reserve(inputRegionFillShape.NumDimensions());
                                               for (int idx = 0; idx < inputRegionFillShape.NumDimensions(); ++idx)
                                               {
                                                   fillIndices.push_back(loopnests::Index("fillIdx_" + std::to_string(idx)));
                                               }

                                               // Define LoopNest
                                               auto fillNest = Using({ offsetInputArrayView }, ArgumentType::Input)
                                                                   .Using({ cacheFillArrayView }, ArgumentType::Output);
                                               for (int idx = 0; idx < inputRegionFillShape.NumDimensions(); ++idx)
                                               {
                                                   fillNest.ForAll(fillIndices[idx], 0, inputRegionFillShape.GetActiveSize(idx));
                                               }

                                               const int VectorizationSize = registerCharacteristics.NumberOfElementsPerSIMDRegister;
                                               int maximumElementsInTempBuf = registerCharacteristics.NumberOfSIMDRegisters * VectorizationSize;
                                               std::vector<int> indexSplitSizes(fillIndices.size());
                                               std::vector<int> tmpBufDimensionMapping(indexSplitSizes.size());

                                               // Handle the innermost input dimension differently since we'll be counting elements there instead of shards of a memory layout
                                               int shardSize = VectorizationSize;
                                               int totalElementsPerShard = VectorizationSize;
                                                // Walk the dimensions from innermost to outermost; the loop ends when the unsigned
                                                // index wraps around past zero and the condition (fillIndices.size() > idx) fails
                                                for (unsigned idx = fillIndices.size() - 1; fillIndices.size() > idx; --idx)
                                               {
                                                   int availableShardsInTmpBuf = maximumElementsInTempBuf / totalElementsPerShard;
                                                   int inputDimAvailableShards = inputRegionFillShape.GetActiveSize(idx) / shardSize;
                                                   int numShards = std::min(availableShardsInTmpBuf, inputDimAvailableShards);
                                                   tmpBufDimensionMapping[idx] = inputRegionFillShape.GetLogicalDimension(idx);
                                                   if (numShards > 1)
                                                   {
                                                       indexSplitSizes[idx] = numShards * shardSize;
                                                       shardSize = 1; // After the initial vectorization size, we target units of entire memory layout shards
                                                       totalElementsPerShard *= numShards; // The number of elements represented by a target scales with the number of inner targets it represents
                                                   }
                                                   else
                                                   {
                                                       indexSplitSizes[idx] = 1;
                                                   }
                                               }
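                                                // e.g. (hypothetical register characteristics) with 8 SIMD registers of 4 elements each
                                                // (a 32-element temp buffer) and a fill region of { 16, 8 }:
                                                //   innermost dim: min(32/4, 8/4) = 2 shards of 4 elements fit, so indexSplitSizes[1] = 8
                                                //   next dim out:  min(32/8, 16/1) = 4 shards fit, so indexSplitSizes[0] = 4
                                                // i.e. the temp buffer covers a 4 x 8 = 32 element tile of the fill region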
                                               // The index split sizes are measured in input-space, so no scaling is needed
                                               std::vector<int> tmpBufScaleFactors(indexSplitSizes.size(), 1);

                                               BoundaryConditionMemoryLayoutHelper fillKernelBoundaryHelper(inputRegionFillShape.GetActiveSize(),
                                                                                                            indexSplitSizes,
                                                                                                            tmpBufDimensionMapping,
                                                                                                            tmpBufScaleFactors,
                                                                                                            0, // Fill index doesn't matter for this usage
                                                                                                            tmpBufDimensionMapping.size()); // Shrink any index split sizes needed since we don't have a "view" to worry about

                                               auto cacheFillInternalKernel = loopnests::Kernel("Internal_Fill_Cache_Outer_Kernel")
                                                                                  .Inputs(offsetInputArrayView, cacheFillArrayView)
                                                                                  .Indices(fillIndices)
                                                                                  .DefineEx([=](std::vector<Value> values, std::vector<Scalar> innerIndices) {
                                                                                      Array offsetInput = values[0];
                                                                                      Array cacheFillView = values[1];

                                                                                      Value offsetInputInnerVal = offsetInput.GetValue().Offset(innerIndices);
                                                                                      offsetInputInnerVal.SetLayout(offsetInput.GetValue().GetLayout());
                                                                                      Array offsetInputInner = offsetInputInnerVal;

                                                                                      std::vector<Scalar> cacheIndices;
                                                                                      cacheIndices.reserve(boundaryCacheFillLayout.NumDimensions());
                                                                                      for (int cacheDimIdx = 0; cacheDimIdx < boundaryCacheFillLayout.NumDimensions(); ++cacheDimIdx)
                                                                                      {
                                                                                          unsigned baseDimIdx = cacheFillThresholdIdx + cacheDimIdx;
                                                                                          int logicalDimension = logicalDimensionMapping[baseDimIdx];
                                                                                          // Mapping loopnest indices (input space) -> cache indices (cache space) so divide by split index increment
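                                                                                          // e.g. (hypothetical values) if innerIndices[logicalDimension] == 12, the split increment
                                                                                          // is 4, and the fill dimension size is 8, the cache index in this dimension is (12 / 4) % 8 == 3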
                                                                                          cacheIndices.push_back((innerIndices[logicalDimension] / orderedIndexIncrements[baseDimIdx]) % boundaryCacheFillLayout.GetActiveSize(cacheDimIdx));
                                                                                      }
                                                                                      Value offsetCacheInnerVal = cacheFillView.GetValue().Offset(cacheIndices);
                                                                                      offsetCacheInnerVal.SetLayout(cacheFillView.GetValue().GetLayout());
                                                                                      Array offsetCacheInner = offsetCacheInnerVal;

                                                                                      fillKernelBoundaryHelper.EmitBoundarySwitches(innerIndices, [=](MemoryLayout fillRegionShape, MemoryLayout, MemoryLayout boundaryTempBufLayout, MemoryLayout) {
                                                                                          Array tmpBuf = Allocate(offsetInput.Type(), boundaryTempBufLayout, bufferAlignment);

                                                                                          std::vector<loopnests::Index> tmpBufInputIndices;

                                                                                          tmpBufInputIndices.reserve(fillRegionShape.NumDimensions());
                                                                                          for (int idx = 0; idx < fillRegionShape.NumDimensions(); ++idx)
                                                                                          {
                                                                                              tmpBufInputIndices.push_back(loopnests::Index("tmpBuf_FillIdx_" + std::to_string(idx)));
                                                                                          }

                                                                                          auto tmpBufFillNest = Using({ offsetInputInner }, ArgumentType::Input)
                                                                                                                    .Using({ tmpBuf }, ArgumentType::Output);
                                                                                          for (int idx = 0; idx < fillRegionShape.NumDimensions(); ++idx)
                                                                                          {
                                                                                              tmpBufFillNest.ForAll(tmpBufInputIndices[idx], 0, fillRegionShape.GetActiveSize(idx));
                                                                                          }

                                                                                          auto tmpBufFill = loopnests::Kernel("Internal_TmpBuf_FillTmpBuf_Kernel")
                                                                                                                .Inputs(offsetInputInner, tmpBuf)
                                                                                                                .Indices(tmpBufInputIndices)
                                                                                                                .DefineEx([=](std::vector<Value> tmpBufValues, std::vector<Scalar> tmpBufInputIndices) {
                                                                                                                    Array offsetInputInner = tmpBufValues[0];
                                                                                                                    Array tmpBuf = tmpBufValues[1];

                                                                                                                    tmpBuf(tmpBufInputIndices) = offsetInputInner(tmpBufInputIndices);
                                                                                                                });
                                                                                          tmpBufFillNest.Do(tmpBufFill);
                                                                                          auto& tmpBufFillSchedule = tmpBufFillNest.GetSchedule();
                                                                                          // unroll everything
                                                                                          for (unsigned idx = 0; idx < tmpBufInputIndices.size(); ++idx)
                                                                                          {
                                                                                              tmpBufFillSchedule.Unroll(tmpBufInputIndices[idx]);
                                                                                          }
                                                                                          tmpBufFillNest.Run();

                                                                                          // Cache fill from tmp buf
                                                                                          auto cacheFillNest = Using({ tmpBuf }, ArgumentType::Input)
                                                                                                                   .Using({ offsetCacheInner }, ArgumentType::Output);
                                                                                          for (int idx = 0; idx < tmpBuf.GetValue().GetLayout().NumDimensions(); ++idx)
                                                                                          {
                                                                                              cacheFillNest.ForAll(tmpBufInputIndices[idx], 0, tmpBuf.GetValue().GetLayout().GetActiveSize(idx));
                                                                                          }

                                                                                          auto cacheFill = loopnests::Kernel("Internal_TmpBuf_FillCache_Kernel")
                                                                                                               .Inputs(tmpBuf, offsetCacheInner)
                                                                                                               .Indices(tmpBufInputIndices)
                                                                                                               .DefineEx([=](std::vector<Value> tmpBufValues, std::vector<Scalar> tmpBufIndices) {
                                                                                                                   Array tmpBuf = tmpBufValues[0];
                                                                                                                   Array offsetCacheInner = tmpBufValues[1];

                                                                                                                   int cacheDimensions = offsetCacheInner.GetValue().GetLayout().NumDimensions();
                                                                                                                   std::vector<Scalar> cacheIndices;
                                                                                                                   cacheIndices.reserve(cacheDimensions);
                                                                                                                   for (int cacheDimIdx = 0; cacheDimIdx < cacheDimensions; ++cacheDimIdx)
                                                                                                                   {
                                                                                                                       unsigned baseDimIdx = cacheFillThresholdIdx + cacheDimIdx;
                                                                                                                       int logicalDimension = logicalDimensionMapping[baseDimIdx];
                                                                                                                       // Mapping loopnest indices (input space) -> cache indices (cache space) so divide by split index increment
                                                                                                                       cacheIndices.push_back((tmpBufIndices[logicalDimension] / orderedIndexIncrements[baseDimIdx]) % boundaryCacheFillLayout.GetActiveSize(cacheDimIdx));
                                                                                                                   }
                                                                                                                   offsetCacheInner(cacheIndices) = tmpBuf(tmpBufIndices);
                                                                                                               });
                                                                                          cacheFillNest.Do(cacheFill);
                                                                                          auto& cacheFillSchedule = cacheFillNest.GetSchedule();
                                                                                          for (unsigned idx = 0; idx < tmpBufInputIndices.size(); ++idx)
                                                                                          {
                                                                                              cacheFillSchedule.Unroll(tmpBufInputIndices[idx]);
                                                                                          }
                                                                                          cacheFillNest.Run();
                                                                                      });
                                                                                  });

                                               auto& schedule = fillNest.GetSchedule();
                                               std::vector<loopnests::Index> splitOuterIndices;
                                               for (unsigned idx = 0; idx < fillIndices.size(); ++idx)
                                               {
                                                   if (indexSplitSizes[idx] > 1)
                                                   {
                                                       splitOuterIndices.push_back(schedule.Split(fillIndices[idx], indexSplitSizes[idx]));
                                                   }
                                                   else
                                                   {
                                                       splitOuterIndices.push_back(fillIndices[idx]);
                                                   }
                                               }

                                               fillNest.Do(cacheFillInternalKernel, splitOuterIndices);

                                               fillNest.Run();
                                           });
                                       });

            underlyingNest.AddKernel(cacheFillKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, cacheFillPosition, {} });
            cachingKernels.push_back(cacheFillKernel);
        }

        if (useViewKernel)
        {
            // The cache view indices are all of the indices that occur before the cacheViewThresholdIdx
            std::vector<Index> cacheViewPosition(orderedIndices.begin(), orderedIndices.begin() + cacheViewThresholdIdx);
            std::vector<Index> cacheViewIndices(_kernelIndices.begin(), _kernelIndices.end());
            cacheViewIndices.insert(cacheViewIndices.end(), cacheViewPosition.begin(), cacheViewPosition.end());

            auto cacheViewKernel = loopnests::Kernel(cacheName + "_View_Cache_Kernel")
                                       .Inputs(_rawCache, cacheRef)
                                       .Indices(cacheViewIndices)
                                       .DefineEx([boundaryConditionCacheHelper, compositeIndexCount, fullInputLayout, cacheLayout, baseCacheViewLayout, cacheLogicalDimensionMapping, logicalDimensionMapping, orderedIndices, orderedIndexIncrements, cacheThresholdIdx, cacheViewThresholdIdx, logicalDimensionCount](std::vector<Value> values, std::vector<Scalar> indices) {
                                           auto& cache = values[0];
                                           auto& cacheRef = values[1];
                                           std::vector<Scalar> compositeIndexValues(indices.begin(), indices.begin() + compositeIndexCount);
                                           std::vector<Scalar> splitIndexValues(indices.begin() + compositeIndexCount, indices.end());

                                           boundaryConditionCacheHelper.EmitBoundarySwitches(compositeIndexValues, [&](MemoryLayout inputRegionShape, MemoryLayout inputRegionFillShape, MemoryLayout boundaryCacheLayout, MemoryLayout boundaryCacheFillLayout) {
                                               // Find the view slice in the cache for this offset
                                               // The indices in [cacheThresholdIdx, cacheViewThresholdIdx) determine which slice to use
                                               std::vector<Scalar> cacheOffsetIndices;
                                               cacheOffsetIndices.reserve(cacheLayout.NumDimensions());

                                               // Note: if cacheThresholdIdx == cacheViewThresholdIdx (i.e. if there is no repeated re-viewing of the cache)
                                               // then the first loop is skipped and no offsetting occurs
                                               auto cacheView = cache;
                                               for (unsigned idx = cacheThresholdIdx; idx < cacheViewThresholdIdx; ++idx)
                                               {
                                                   // Mapping loopnest indices (input space) -> cache offsets (cache space) so divide by split index increment
                                                   cacheOffsetIndices.push_back(splitIndexValues[idx] / orderedIndexIncrements[idx]);
                                               }
                                               for (unsigned idx = cacheViewThresholdIdx; idx < static_cast<unsigned>(fullInputLayout.NumDimensions()); ++idx)
                                               {
                                                   cacheOffsetIndices.push_back(Scalar{ 0 });
                                               }

                                               cacheView.SetLayout(boundaryCacheLayout);
                                               auto offsetCache = cacheView.Offset(cacheOffsetIndices);
                                               offsetCache.SetLayout(baseCacheViewLayout);

                                               // Offset the cache ref from the base cache such that indexing it with the current
                                               // loop index values yields a pointer to the beginning of this view of the cache
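                                               // e.g. (illustratively) for a 2-D input, if compositeIndexValues == { i0, j0 } then
                                               // offsetIndices ends up as { -i0, -j0 }, so later indexing the cache ref with
                                               // (i0 + di, j0 + dj) lands on element (di, dj) of this view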
                                               std::vector<Scalar> offsetIndices(logicalDimensionCount);
                                               for (int idx = 0; idx < logicalDimensionCount; ++idx)
                                               {
                                                   offsetIndices[idx] -= compositeIndexValues[idx];
                                               }

                                               auto offsetCacheView = offsetCache.Offset(offsetIndices);
                                               offsetCacheView.SetLayout(baseCacheViewLayout);
                                               cacheRef.SetLayout(baseCacheViewLayout);
                                               cacheRef = offsetCacheView.Reference();
                                           });
                                       });

            underlyingNest.AddKernel(cacheViewKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::prologue, cacheViewPosition, {} });
            cachingKernels.push_back(cacheViewKernel);
        }

        if (useReduceKernel)
        {
            // The cache reduce indices are all of the indices that occur before the cacheThresholdIdx,
            // because the reduce is symmetric with the cache's non-progressive fill / flush level of the loop nest
            std::vector<Index> cacheReducePosition(orderedIndices.begin(), orderedIndices.begin() + cacheThresholdIdx);
            std::vector<Index> cacheReduceIndices(_kernelIndices.begin(), _kernelIndices.end());
            cacheReduceIndices.insert(cacheReduceIndices.end(), cacheReducePosition.begin(), cacheReducePosition.end());

            auto cacheReduceKernel = loopnests::Kernel(cacheName + "_Reduce_Kernel")
                                         .Inputs(_value, _rawCache)
                                         .Indices(cacheReduceIndices)
                                         .DefineEx([=](std::vector<Value> values, std::vector<Scalar> indices) {
                                             auto& input = values[0];
                                             auto& cache = values[1];
                                             std::vector<Scalar> compositeIndexValues(indices.begin(), indices.begin() + compositeIndexCount);
                                             std::vector<Scalar> splitIndexValues(indices.begin() + compositeIndexCount, indices.end());

                                             auto offsetInput = input.Offset(compositeIndexValues);
                                             offsetInput.SetLayout(input.GetLayout());
                                             auto offsetInputArrayView = Array(offsetInput);

                                             boundaryConditionCacheHelper.EmitBoundarySwitches(compositeIndexValues, [=](MemoryLayout inputRegionShape, MemoryLayout, MemoryLayout boundaryCacheLayout, MemoryLayout) {
                                                 auto cacheArrayView = Array(cache);

                                                 // Prefer input-oriented loops to maximize locality as the input
                                                 // is likely to be larger than the cache in most cases
                                                 // Based on the element size and counts in different dimensions,
                                                 // we will split and unroll some of the inner loops in order to maximize
                                                 // vectorization.
                                                 // In order to get appropriate utilization of all the SIMD
                                                 // registers, we will need to use a temporary buffer (which we expect
                                                 // the compiler to optimize away) with a size equal to the total number
                                                 // of elements that can be held in all of the SIMD registers.
                                                 // The filling of this temporary buffer from the cache needs to be an
                                                 // unrolled operation and the reducing of the output from the temporary
                                                 // buffer also needs to be an unrolled operation that happens after
                                                 // the full temporary buffer has been filled.
                                                  // If the reduce operation is a SumReduce operation, then we need
                                                  // an additional loop in the middle which accumulates the current value
                                                  // from the output into the temporary buffer, and then the final
                                                  // loop copies the temporary buffer back to the output
                                                 // Therefore, we need multiple levels of loopnests so that the area
                                                 // outside of the temporary buffer's addressable region can be looped
                                                 // over, and the area inside the temporary buffer region can have two
                                                 // or three sequential fully unrolled loopnests.
                                                 // new loopnest (outer):
                                                 // For ... {
                                                 //   For ... {
                                                 //       // start of outer loopnest prologue kernel
                                                 //       // Fill temp buf with cache data
                                                 //       new loopnest (inner #1):
                                                 //       For ... (unroll) {
                                                 //           For ... (unroll) {
                                                 //               ... {
                                                 //                   // start of inner loopnest #1 kernel
                                                 //                   tempBuf(tempBufIndices) = cache(cacheIndices)
                                                 //                   // end of inner loopnest #1 kernel
                                                 //               }
                                                 //               ...
                                                 //           }
                                                 //       }
                                                 //       // if reduceFunction == SumReduce
                                                 //       // Apply the reduce function to reduce elements of the output into the temp buf
                                                 //       new loopnest (inner #2):
                                                 //       For ... (unroll) {
                                                 //           For ... (unroll) {
                                                 //               ... {
                                                 //                   // start of inner loopnest #2 kernel
                                                 //                   tempBuf(tempBufIndices) += input(inputIndices)
                                                 //                   // end of inner loopnest #2 kernel
                                                 //               }
                                                 //               ...
                                                 //           }
                                                 //       }
                                                 //       // Copy temp buf to output
                                                 //       new loopnest (inner #3):
                                                 //       For ... (unroll) {
                                                 //           For ... (unroll) {
                                                 //               ... {
                                                 //                   // start of inner loopnest #3 kernel
                                                 //                   input(inputIndices) = tempBuf(tempBufIndices)
                                                 //                   // end of inner loopnest #3 kernel
                                                 //               }
                                                 //               ...
                                                 //           }
                                                 //       }
                                                 //       // end of outer loopnest kernel
                                                 //   }
                                                 // }
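                                                  //
                                                  // For a rough sense of scale (illustrative numbers only; the real
                                                  // values come from registerCharacteristics below): with 16 SIMD
                                                  // registers holding 8 floats each, the temporary buffer would hold
                                                  // 16 * 8 = 128 elements, and each fully unrolled inner loopnest
                                                  // would touch exactly that 128-element tile per outer iteration.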

                                                 std::vector<loopnests::Index> reduceIndices;
                                                 reduceIndices.reserve(inputRegionShape.NumDimensions());
                                                 for (int idx = 0; idx < inputRegionShape.NumDimensions(); ++idx)
                                                 {
                                                     reduceIndices.push_back(loopnests::Index("reduceIdx_" + std::to_string(idx)));
                                                 }

                                                 // Define LoopNest
                                                 auto reduceNest = Using({ offsetInputArrayView }, ArgumentType::Input)
                                                                       .Using({ cacheArrayView }, ArgumentType::Output);
                                                 for (int idx = 0; idx < inputRegionShape.NumDimensions(); ++idx)
                                                 {
                                                     reduceNest.ForAll(reduceIndices[idx], 0, inputRegionShape.GetActiveSize(idx));
                                                 }

                                                 const int VectorizationSize = registerCharacteristics.NumberOfElementsPerSIMDRegister;
                                                 int maximumElementsInTempBuf = registerCharacteristics.NumberOfSIMDRegisters * VectorizationSize;
                                                 std::vector<int> indexSplitSizes(reduceIndices.size());
                                                 std::vector<int> tmpBufDimensionMapping(indexSplitSizes.size());

                                                 // Handle the innermost input dimension differently since we'll be counting elements there instead of shards of a memory layout
                                                 int shardSize = VectorizationSize;
                                                 int totalElementsPerShard = VectorizationSize;
                                                  // Walk the dimensions from innermost to outermost (the unsigned wrap-around after --idx at 0 terminates the loop)
                                                  for (unsigned idx = reduceIndices.size() - 1; idx < reduceIndices.size(); --idx)
                                                 {
                                                     int availableShardsInTmpBuf = maximumElementsInTempBuf / totalElementsPerShard;
                                                     int inputDimAvailableShards = inputRegionShape.GetActiveSize(idx) / shardSize;
                                                     int numShards = std::min(availableShardsInTmpBuf, inputDimAvailableShards);
                                                     tmpBufDimensionMapping[idx] = inputRegionShape.GetLogicalDimension(idx);
                                                     if (numShards > 1)
                                                     {
                                                         indexSplitSizes[idx] = numShards * shardSize;
                                                          shardSize = 1; // After the innermost (vectorization-sized) dimension, we count in units of whole memory-layout shards
                                                          totalElementsPerShard *= numShards; // The number of elements represented by an outer shard scales with the number of inner shards it contains
                                                     }
                                                     else
                                                     {
                                                         indexSplitSizes[idx] = 1;
                                                     }
                                                 }
                                                 // The index split sizes are measured in input-space, so no scaling is needed
                                                 std::vector<int> tmpBufScaleFactors(indexSplitSizes.size(), 1);
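                                                  // Illustrative example of the split computation above (hypothetical
                                                  // numbers, not tied to any particular target): with 8 elements per
                                                  // SIMD register and 16 registers, maximumElementsInTempBuf = 128.
                                                  // For a 4 x 64 input region, the innermost dimension allows
                                                  // min(128 / 8, 64 / 8) = 8 shards of 8 elements, so
                                                  // indexSplitSizes[1] = 64 with 64 elements per outer shard; the
                                                  // next dimension then allows min(128 / 64, 4 / 1) = 2 shards, so
                                                  // indexSplitSizes[0] = 2, giving a 2 x 64 = 128-element temp buffer tile.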

                                                 BoundaryConditionMemoryLayoutHelper reduceKernelBoundaryHelper(inputRegionShape.GetActiveSize(),
                                                                                                                indexSplitSizes,
                                                                                                                tmpBufDimensionMapping,
                                                                                                                tmpBufScaleFactors,
                                                                                                                0, // Fill index doesn't matter for this usage
                                                                                                                 tmpBufDimensionMapping.size()); // Shrink any index split sizes as needed, since there is no "view" to account for here

                                                 auto cacheReduceInternalKernel = loopnests::Kernel("Internal_Reduce_Cache_Outer_Kernel")
                                                                                      .Inputs(offsetInputArrayView, cacheArrayView)
                                                                                      .Indices(reduceIndices)
                                                                                      .DefineEx([=](std::vector<Value> values, std::vector<Scalar> innerIndices) {
                                                                                          Array offsetInput = values[0];
                                                                                          Array cacheView = values[1];

                                                                                          Value offsetInputInnerVal = offsetInput.GetValue().Offset(innerIndices);
                                                                                          offsetInputInnerVal.SetLayout(offsetInput.GetValue().GetLayout());
                                                                                          Array offsetInputInner = offsetInputInnerVal;

                                                                                          std::vector<Scalar> cacheIndices;
                                                                                          cacheIndices.reserve(boundaryCacheLayout.NumDimensions());
                                                                                          for (int cacheDimIdx = 0; cacheDimIdx < boundaryCacheLayout.NumDimensions(); ++cacheDimIdx)
                                                                                          {
                                                                                              unsigned baseDimIdx = cacheThresholdIdx + cacheDimIdx;
                                                                                              int logicalDimension = logicalDimensionMapping[baseDimIdx];
                                                                                              // Mapping loopnest indices (input space) -> cache indices (cache space) so divide by split index increment
                                                                                              cacheIndices.push_back((innerIndices[logicalDimension] / orderedIndexIncrements[baseDimIdx]) % boundaryCacheLayout.GetActiveSize(cacheDimIdx));
                                                                                          }
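                                                                                          // Worked example of the mapping above (hypothetical values): with a split
                                                                                          // index increment of 4 and a cache dimension of active size 8, an input-space
                                                                                          // index of 37 maps to (37 / 4) % 8 = 1 in that cache dimension.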
                                                                                          Value offsetCacheInnerVal = cacheView.GetValue().Offset(cacheIndices);
                                                                                          offsetCacheInnerVal.SetLayout(cacheView.GetValue().GetLayout());
                                                                                          Array offsetCacheInner = offsetCacheInnerVal;

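                                                                                          // EmitBoundarySwitches is expected to emit one code path per boundary
                                                                                          // condition, handing this lambda the (possibly shrunken) region and temp
                                                                                          // buffer layouts to use for edge tiles that don't fill a whole split.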
                                                                                          reduceKernelBoundaryHelper.EmitBoundarySwitches(innerIndices, [=](MemoryLayout reduceRegionShape, MemoryLayout, MemoryLayout boundaryTempBufLayout, MemoryLayout) {
                                                                                              Array tmpBuf = Allocate(offsetInput.Type(), boundaryTempBufLayout, bufferAlignment);

                                                                                              std::vector<loopnests::Index> tmpBufInputIndices;

                                                                                              tmpBufInputIndices.reserve(reduceRegionShape.NumDimensions());
                                                                                              for (int idx = 0; idx < reduceRegionShape.NumDimensions(); ++idx)
                                                                                              {
                                                                                                  tmpBufInputIndices.push_back(loopnests::Index("tmpBuf_ReduceIdx_" + std::to_string(idx)));
                                                                                              }

                                                                                              auto tmpBufFillFromCacheNest = Using({ offsetCacheInner }, ArgumentType::Input)
                                                                                                                                 .Using({ tmpBuf }, ArgumentType::Output);
                                                                                              for (int idx = 0; idx < reduceRegionShape.NumDimensions(); ++idx)
                                                                                              {
                                                                                                  tmpBufFillFromCacheNest.ForAll(tmpBufInputIndices[idx], 0, reduceRegionShape.GetActiveSize(idx));
                                                                                              }

                                                                                              // Fill tmp buf from cache
                                                                                              auto tmpBufFillFromCache = loopnests::Kernel("Internal_TmpBuf_FillTmpBuf_Kernel")
                                                                                                                             .Inputs(offsetCacheInner, tmpBuf)
                                                                                                                             .Indices(tmpBufInputIndices)
                                                                                                                             .DefineEx([=](std::vector<Value> tmpBufValues, std::vector<Scalar> tmpBufInputIndices) {
                                                                                                                                 Array offsetCacheInner = tmpBufValues[0];
                                                                                                                                 Array tmpBuf = tmpBufValues[1];

                                                                                                                                 int cacheDimensions = offsetCacheInner.GetValue().GetLayout().NumDimensions();
                                                                                                                                 std::vector<Scalar> cacheIndices;
                                                                                                                                 cacheIndices.reserve(cacheDimensions);
                                                                                                                                 for (int cacheDimIdx = 0; cacheDimIdx < cacheDimensions; ++cacheDimIdx)
                                                                                                                                 {
                                                                                                                                     unsigned baseDimIdx = cacheFillThresholdIdx + cacheDimIdx;
                                                                                                                                     int logicalDimension = logicalDimensionMapping[baseDimIdx];
                                                                                                                                     // Mapping loopnest indices (input space) -> cache indices (cache space) so divide by split index increment
                                                                                                                                     cacheIndices.push_back((tmpBufInputIndices[logicalDimension] / orderedIndexIncrements[baseDimIdx]) % boundaryCacheLayout.GetActiveSize(cacheDimIdx));
                                                                                                                                 }
                                                                                                                                 tmpBuf(tmpBufInputIndices) = offsetCacheInner(cacheIndices);
                                                                                                                             });
                                                                                              tmpBufFillFromCacheNest.Do(tmpBufFillFromCache);
                                                                                              auto& tmpBufFillSchedule = tmpBufFillFromCacheNest.GetSchedule();
                                                                                              // Unroll every loop so the temp buffer accesses become straight-line code the compiler can keep in registers
                                                                                              for (unsigned idx = 0; idx < tmpBufInputIndices.size(); ++idx)
                                                                                              {
                                                                                                  tmpBufFillSchedule.Unroll(tmpBufInputIndices[idx]);
                                                                                              }
                                                                                              tmpBufFillFromCacheNest.Run();

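                                                                                              // Two variants follow: when accumulateReduce is true, the existing output
                                                                                              // values are first accumulated into the temp buffer and the temp buffer is
                                                                                              // then copied back out; otherwise the temp buffer is reduced directly into
                                                                                              // the output via reduceFunction.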
                                                                                              if (accumulateReduce)
                                                                                              {
                                                                                                  // Reduce the current input/output contents into the temp buffer
                                                                                                  auto tmpBufReduceNest = Using({ offsetInputInner }, ArgumentType::Input)
                                                                                                                              .Using({ tmpBuf }, ArgumentType::Output);
                                                                                                  for (int idx = 0; idx < tmpBuf.GetValue().GetLayout().NumDimensions(); ++idx)
                                                                                                  {
                                                                                                      tmpBufReduceNest.ForAll(tmpBufInputIndices[idx], 0, tmpBuf.GetValue().GetLayout().GetActiveSize(idx));
                                                                                                  }

                                                                                                  auto tmpBufReduce = loopnests::Kernel("Internal_TmpBuf_ReduceOutput_Kernel")
                                                                                                                          .Inputs(tmpBuf, offsetInputInner)
                                                                                                                          .Indices(tmpBufInputIndices)
                                                                                                                          .DefineEx([=](std::vector<Value> tmpBufValues, std::vector<Scalar> tmpBufInputIndices) {
                                                                                                                              Array tmpBuf = tmpBufValues[0];
                                                                                                                              Array offsetInputInner = tmpBufValues[1];

                                                                                                                              reduceFunction(tmpBuf(tmpBufInputIndices), offsetInputInner(tmpBufInputIndices));
                                                                                                                          });
                                                                                                  tmpBufReduceNest.Do(tmpBufReduce);
                                                                                                  auto& tmpBufReduceSchedule = tmpBufReduceNest.GetSchedule();
                                                                                                  for (unsigned idx = 0; idx < tmpBufInputIndices.size(); ++idx)
                                                                                                  {
                                                                                                      tmpBufReduceSchedule.Unroll(tmpBufInputIndices[idx]);
                                                                                                  }
                                                                                                  tmpBufReduceNest.Run();

                                                                                                  // Copy temp buffer contents to input/output
                                                                                                  auto storeOutNest = Using({ tmpBuf }, ArgumentType::Input)
                                                                                                                          .Using({ offsetInputInner }, ArgumentType::Output);
                                                                                                  for (int idx = 0; idx < tmpBuf.GetValue().GetLayout().NumDimensions(); ++idx)
                                                                                                  {
                                                                                                      storeOutNest.ForAll(tmpBufInputIndices[idx], 0, tmpBuf.GetValue().GetLayout().GetActiveSize(idx));
                                                                                                  }

                                                                                                  auto storeOut = loopnests::Kernel("Internal_TmpBuf_CopyOutput_Kernel")
                                                                                                                      .Inputs(tmpBuf, offsetInputInner)
                                                                                                                      .Indices(tmpBufInputIndices)
                                                                                                                      .DefineEx([=](std::vector<Value> tmpBufValues, std::vector<Scalar> tmpBufInputIndices) {
                                                                                                                          Array tmpBuf = tmpBufValues[0];
                                                                                                                          Array offsetInputInner = tmpBufValues[1];

                                                                                                                          offsetInputInner(tmpBufInputIndices) = tmpBuf(tmpBufInputIndices);
                                                                                                                      });
                                                                                                  storeOutNest.Do(storeOut);
                                                                                                  auto& storeOutSchedule = storeOutNest.GetSchedule();
                                                                                                  for (unsigned idx = 0; idx < tmpBufInputIndices.size(); ++idx)
                                                                                                  {
                                                                                                      storeOutSchedule.Unroll(tmpBufInputIndices[idx]);
                                                                                                  }
                                                                                                  storeOutNest.Run();
                                                                                              }
                                                                                              else
                                                                                              {
                                                                                                  // Reduce the temp buffer into input/output
                                                                                                  auto outputReduceNest = Using({ tmpBuf }, ArgumentType::Input)
                                                                                                                              .Using({ offsetInputInner }, ArgumentType::Output);
                                                                                                  for (int idx = 0; idx < tmpBuf.GetValue().GetLayout().NumDimensions(); ++idx)
                                                                                                  {
                                                                                                      outputReduceNest.ForAll(tmpBufInputIndices[idx], 0, tmpBuf.GetValue().GetLayout().GetActiveSize(idx));
                                                                                                  }

                                                                                                  auto outputReduce = loopnests::Kernel("Internal_TmpBuf_ReduceOutput_Kernel")
                                                                                                                          .Inputs(tmpBuf, offsetInputInner)
                                                                                                                          .Indices(tmpBufInputIndices)
                                                                                                                          .DefineEx([=](std::vector<Value> tmpBufValues, std::vector<Scalar> tmpBufInputIndices) {
                                                                                                                              Array tmpBuf = tmpBufValues[0];
                                                                                                                              Array offsetInputInner = tmpBufValues[1];

                                                                                                                              reduceFunction(offsetInputInner(tmpBufInputIndices), tmpBuf(tmpBufInputIndices));
                                                                                                                          });
                                                                                                  outputReduceNest.Do(outputReduce);
                                                                                                  auto& outputReduceSchedule = outputReduceNest.GetSchedule();
                                                                                                  for (unsigned idx = 0; idx < tmpBufInputIndices.size(); ++idx)
                                                                                                  {
                                                                                                      outputReduceSchedule.Unroll(tmpBufInputIndices[idx]);
                                                                                                  }
                                                                                                  outputReduceNest.Run();
                                                                                              }
                                                                                          });
                                                                                      });

                                                 auto& schedule = reduceNest.GetSchedule();
                                                 std::vector<loopnests::Index> splitOuterIndices;
                                                 for (unsigned idx = 0; idx < reduceIndices.size(); ++idx)
                                                 {
                                                     if (indexSplitSizes[idx] > 1)
                                                     {
                                                         splitOuterIndices.push_back(schedule.Split(reduceIndices[idx], indexSplitSizes[idx]));
                                                     }
                                                 }
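                                                  // The splits above tile the reduce space by indexSplitSizes; invoking the
                                                  // kernel at the split outer indices below should make the fully unrolled
                                                  // inner loopnests above run once per temp-buffer-sized tile, matching the
                                                  // outer/inner structure sketched in the comment at the top of this block.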

                                                 reduceNest.Do(cacheReduceInternalKernel, splitOuterIndices);

                                                 reduceNest.Run();
                                             });
                                         });

            underlyingNest.AddKernel(cacheReduceKernel, loopnests::CodePositionConstraints{ loopnests::LoopFragmentType::epilogue, cacheReducePosition, {} });
            cachingKernels.push_back(cacheReduceKernel);
        }

        underlyingNest.RenameVariable(_value, cacheRef, _atIndices, cachingKernels);
    }