poplar::Tensor decompressPacked4BitTensor()

in optimum/graphcore/custom_ops/group_quantize_decompress/group_quantize_decompressx.cpp [63:130]


poplar::Tensor decompressPacked4BitTensor(poplar::Graph &graph,
                                          poplar::Tensor &x,
                                          poplar::Tensor &groupScale,
                                          poplar::Tensor &groupBias,
                                          poplar::program::Sequence &prog) {

  // Each packed input element decodes to four float16 values, hence the
  // factor of 4 in the unpacked shape and in the flattened output shape.
  std::vector<size_t> unp_shape{x.shape()[0], x.shape()[1], x.shape()[2] * 4};
  std::vector<size_t> out_shape{x.shape()[0], x.shape()[1] * x.shape()[2] * 4};

  poplar::DebugContext debugContext;

  // The quantized/compressed tensor must be tile-mapped with a minimum grain
  // size: the decoding vertex consumes 64 bits at a time, so every contiguous
  // interval of packed elements must be a multiple of grain_size.
  const unsigned grain_size = 4;
  poputil::mapTensorLinearly(graph, x, 0, grain_size);

  auto x_unpacked =
      graph.addVariable(poplar::HALF, unp_shape, {debugContext, "x_unpacked"});
  auto computeSet = graph.addComputeSet({debugContext, "unpack4bit"});
  auto mapping = graph.getTileMapping(x);
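  // Each IPU tile runs 6 hardware worker threads; split every interval across them.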
  auto numWorkers = 6;

  for (auto tile = 0u; tile < mapping.size(); ++tile) {
    for (auto i : mapping[tile]) {
      // Colocate each unpacked slice with its packed input slice on the same tile.
      graph.setTileMapping(
          x_unpacked.flatten().slice(i.begin() * 4, i.end() * 4), tile);
      // Compute slice bounds for splitting this interval across the 6 worker
      // threads; each per-worker slice must also respect the minimum grain size.
      const auto interval = (i.end() - i.begin());
      const auto interval_blocks = interval / grain_size;
      const auto numElmsPerWorkerNoRemainder = (interval_blocks / numWorkers) * grain_size;
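      // The last worker also takes the remainder, so the whole interval is
      // covered even when it does not split evenly into grain-sized chunks.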
      const auto numElmsRemainder = interval - numElmsPerWorkerNoRemainder * (numWorkers - 1);
      int slice_bounds[7] = {0};
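      // slice_bounds[w]..slice_bounds[w + 1] is worker w's range of packed elements.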
      slice_bounds[0] = i.begin();

      for (auto wid = 0; wid < numWorkers; ++wid) {
        // Determine slice bounds for thread worker
        if (wid < numWorkers - 1) {
          slice_bounds[wid + 1] = slice_bounds[wid] + numElmsPerWorkerNoRemainder;
        }
        else {
          slice_bounds[wid + 1] = slice_bounds[wid] + numElmsRemainder;
        }
        // Add a decompression vertex for this worker's slice of the interval.
        auto vertex = graph.addVertex(
            computeSet, "DecompressPacked4BitTensorV1",
            {{"input",
              x.flatten().slice(slice_bounds[wid], slice_bounds[wid + 1])},
             {"output",
              x_unpacked.flatten().slice(slice_bounds[wid] * 4,
                                         slice_bounds[wid + 1] * 4)}});
        graph.setTileMapping(vertex, tile);
        graph.setPerfEstimate(
            vertex,
            100 + 30 * i.size()); // guess from godbolt per int32 + overhead
      }
    }
  }
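  // Run all decompression vertices as a single compute-set step.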
  prog.add(poplar::program::Execute(computeSet));

  // Apply the per-group affine dequantization x_unpacked * groupScale + groupBias
  // in float16 (TODO: optimize with scaledAddTo or fuse into the vertex).
  auto x_scaled = popops::map(graph, pe::_1 * pe::_2 + pe::_3,
                              {x_unpacked, groupScale, groupBias}, prog);
  auto x_out = x_scaled.reshape(out_shape);

  return x_out;
}
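
The compute set above references a custom codelet named "DecompressPacked4BitTensorV1" whose source is not part of this excerpt. As a rough illustration only, the sketch below shows the kind of work such a vertex has to do: read packed words, extract four 4-bit codes from each, and write them out as float16. The class name, the element type (16-bit words, low nibble first), and the loop structure are assumptions made for this sketch, not the repository's actual codelet, which presumably decodes 64 bits per step (hence the grain size of 4 packed elements enforced by the host code).

#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>

// Hypothetical, simplified 4-bit unpacking vertex (illustrative only).
class DecompressPacked4BitSketch : public poplar::Vertex {
public:
  poplar::Input<poplar::Vector<unsigned short>> input; // 4 codes per 16-bit word (assumed)
  poplar::Output<poplar::Vector<half>> output;         // 4x as many half values

  bool compute() {
    for (unsigned i = 0; i < input.size(); ++i) {
      const unsigned short w = input[i];
      for (unsigned j = 0; j < 4; ++j) {
        // Extract the j-th nibble (low-nibble-first order is an assumption) and
        // emit the raw code; scale and bias are applied afterwards on the
        // decompressed tensor via popops::map.
        output[4 * i + j] =
            static_cast<half>(static_cast<float>((w >> (4 * j)) & 0xFu));
      }
    }
    return true;
  }
};

Calling the helper only requires the packed tensor and the per-group scale/bias tensors to exist in the graph (the helper re-maps the packed tensor itself). The snippet below is a minimal usage sketch with made-up shapes and an assumed 16-bit packed dtype; it also assumes groupScale and groupBias carry a trailing dimension of 1 so that they broadcast over each decompressed group in the popops::map expression.

// Hypothetical shapes: 8 groups of 16 packed words per row -> 8 x 64 half values per row.
auto x = graph.addVariable(poplar::UNSIGNED_SHORT, {2, 8, 16}, "x_packed");
auto scale = graph.addVariable(poplar::HALF, {2, 8, 1}, "group_scale");
auto bias = graph.addVariable(poplar::HALF, {2, 8, 1}, "group_bias");
poputil::mapTensorLinearly(graph, scale);
poputil::mapTensorLinearly(graph, bias);

poplar::program::Sequence prog;
// Returns a float16 tensor of shape {2, 8 * 16 * 4} = {2, 512}.
auto decompressed = decompressPacked4BitTensor(graph, x, scale, bias, prog);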