int run_benchmark()

in bench/EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc [58:319]


int run_benchmark(
    int bit_rate,
    int batch_size,
    int num_rows,
    int embedding_dim,
    int average_len,
    bool normalize_by_lengths,
    bool use_32_bit_indices = false,
    bool prefetch = false) {
  // Generate mapping table
  default_random_engine generator;
  constexpr float sparsity = 0.7;
  vector<int32_t> mapping_table(num_rows);
  bernoulli_distribution row_prune_dist(sparsity);
  int num_compressed_rows = 0;
  for (int i = 0; i < num_rows; ++i) {
    if (row_prune_dist(generator)) {
      // pruned
      mapping_table[i] = -1;
    } else {
      mapping_table[i] = num_compressed_rows;
      ++num_compressed_rows;
    }
  }

  // Create embedding table
  int num_elem_per_byte = 8 / bit_rate;
  int fused_embedding_dim =
      (embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof(float16);
  normal_distribution<float> embedding_distribution;

  vector<uint8_t> fused_embedding_table(
      num_compressed_rows * fused_embedding_dim);
  for (int i = 0; i < num_compressed_rows; i++) {
    for (int ii = 0;
         ii < (embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte;
         ii++) {
      fused_embedding_table[i * fused_embedding_dim + ii] = 2;
    }
    float16* scale_bias = reinterpret_cast<float16*>(
        &fused_embedding_table[i * fused_embedding_dim] +
        (embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte);
    float scale = 2.0f;
    float bias = 1.0f;
    FloatToFloat16_ref(&scale, scale_bias, 1, true /* clip */);
    FloatToFloat16_ref(&bias, scale_bias + 1, 1, true /* clip */);
  }

  // print_fused_table(num_rows, embedding_dim, fused_embedding_table);

  // Generate lengths
  uniform_int_distribution<int> length_distribution(
      1, std::min(2 * average_len + 1, num_rows));
  vector<int> offsets(batch_size + 1);
  offsets[0] = 0;
  for (int i = 0; i < batch_size; ++i) {
    offsets[i + 1] = offsets[i] + length_distribution(generator);
  }

  // Compute the number of indices
  int lengths_sum = offsets[batch_size];

  // Generate indices
  vector<int64_t> indices;
  vector<int32_t> indices_32;

  vector<int> container(num_rows);
  map<int64_t, set<int>> dedup_map; // index -> set(output index)

  // please note we generate unique indices
  for (int i = 0; i < batch_size; ++i) {
    iota(container.begin(), container.end(), 0);
    random_shuffle(container.begin(), container.end());
    copy(
        container.begin(),
        container.begin() + (offsets[i + 1] - offsets[i]),
        back_inserter(indices));
  }
  copy(begin(indices), end(indices), back_inserter(indices_32));

  // Compute the number of valid indices
  int num_valid_indices = 0;
  for (int index : indices) {
    if (mapping_table[index] != -1) {
      ++num_valid_indices;
    }
  }
  cout << "lengths_sum " << lengths_sum << " num_valid_indices "
       << num_valid_indices << endl;

  // Generate weights
  vector<float> weights(lengths_sum);
  for (int i = 0; i < lengths_sum; ++i) {
    weights[i] = embedding_distribution(generator);
  }

  vector<float> output_sls_ref(batch_size * embedding_dim);
  vector<float> output_slws_ref(output_sls_ref.size()),
      output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());

  constexpr int NUM_WARMUP = 4;
  constexpr int NUM_ITER = 10;
  // Only counts the number of bytes for reading embedding table and ignore
  // others. Should be good enough as long as embdding_dim is big enough.
  constexpr int CACHE_LINE_SIZE = 64;
  double bytes = lengths_sum * 2 *
          (use_32_bit_indices ? sizeof(int32_t) : sizeof(int64_t)) +
      num_valid_indices * fused_embedding_dim;
  double bytes_padded = lengths_sum *
          ((use_32_bit_indices ? sizeof(int32_t) : sizeof(int64_t)) +
           CACHE_LINE_SIZE) +
      num_valid_indices * CACHE_LINE_SIZE *
          static_cast<int>(
              (fused_embedding_dim + CACHE_LINE_SIZE - 1) / CACHE_LINE_SIZE);

  for (bool has_weight : {false, true}) {
    vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;

    bool success = false, success_ref = false;

    for (int i = 0; i < NUM_WARMUP + NUM_ITER; ++i) {
      if (use_32_bit_indices) {
        success_ref = EmbeddingSpMDMNBitRowWiseSparse_ref(
            bit_rate,
            embedding_dim,
            batch_size,
            lengths_sum,
            num_rows,
            fused_embedding_table.data(),
            indices_32.data(),
            mapping_table.data(),
            offsets.data(),
            has_weight ? weights.data() : nullptr,
            normalize_by_lengths,
            output_ref.data());

      } else {
        success_ref = EmbeddingSpMDMNBitRowWiseSparse_ref(
            bit_rate,
            embedding_dim,
            batch_size,
            lengths_sum,
            num_rows,
            fused_embedding_table.data(),
            indices.data(),
            mapping_table.data(),
            offsets.data(),
            has_weight ? weights.data() : nullptr,
            normalize_by_lengths,
            output_ref.data());
      }
    }

    vector<float>& output = has_weight ? output_slws : output_sls;

    auto kernel_32 = GenerateEmbeddingSpMDMNBitRowWiseSparse<int32_t>(
        bit_rate,
        embedding_dim,
        has_weight,
        normalize_by_lengths,
        prefetch ? 16 : 0);
    auto kernel_64 = GenerateEmbeddingSpMDMNBitRowWiseSparse<int64_t>(
        bit_rate,
        embedding_dim,
        has_weight,
        normalize_by_lengths,
        prefetch ? 16 : 0);

    for (bool flush_cache : {false, true}) {
      double t = measureWithWarmup(
          [&]() {
            if (use_32_bit_indices) {
              success = kernel_32(
                  batch_size,
                  lengths_sum,
                  num_rows,
                  fused_embedding_table.data(),
                  indices_32.data(),
                  offsets.data(),
                  has_weight ? weights.data() : nullptr,
                  output.data(),
                  mapping_table.data());
            } else {
              success = kernel_64(
                  batch_size,
                  lengths_sum,
                  num_rows,
                  fused_embedding_table.data(),
                  indices.data(),
                  offsets.data(),
                  has_weight ? weights.data() : nullptr,
                  output.data(),
                  mapping_table.data());
            }
          },
          NUM_WARMUP,
          NUM_ITER,
          [&]() {
            if (flush_cache) {
              cache_evict(fused_embedding_table);
              cache_evict(indices);
              cache_evict(indices_32);
              cache_evict(offsets);
              cache_evict(weights);
              cache_evict(output);
            }
          });

      // printMatrix(
      //     matrix_op_t::NoTranspose,
      //     output.data(),
      //     batch_size,
      //     embedding_dim,
      //     embedding_dim,
      //     "");
      // printMatrix(
      //     matrix_op_t::NoTranspose,
      //     output_ref.data(),
      //     batch_size,
      //     embedding_dim,
      //     embedding_dim,
      //     "");
      // Check correctness
      if (!flush_cache) {
        if (success != success_ref) {
          assert(
              false && "ERROR: refernce impl and JIT imp did not both succeed");
        } else if (success) {
          for (size_t i = 0; i < output.size(); ++i) {
            assert(fabs(output[i] - output_ref[i]) < 1e-3);
            if (fabs(output[i] - output_ref[i]) >= 1e-3) {
              cout << i << " " << output[i] << " " << output_ref[i] << endl;
            }
          }
        }
      }

      if (has_weight) {
        cout << setw(16) << "SLW(WEIGHTED) ";
      } else {
        cout << setw(16) << "SLS ";
      }
      if (flush_cache) {
        cout << setw(20) << "cache flushed";
      } else {
        cout << setw(20) << "cache not flushed";
      }
      if (prefetch) {
        cout << setw(16) << "prefetch on";
      } else {
        cout << setw(16) << "prefetch off";
      }

      cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s"
           << setw(20) << "effective b/w: " << setw(16)
           << bytes_padded / 1e9 / t << "GB/s" << setw(8) << " time "
           << setw(16) << t << endl;
    } // flush_cache
  } // has_weight
  return 0;
}