in bench/EmbeddingSpMDM8BitBenchmark.cc [58:300]
/// Benchmarks the JIT-generated 8-bit fused EmbeddingSpMDM kernel against the
/// reference implementation and prints throughput numbers.
///
/// The function builds a synthetic fused embedding table, random offsets,
/// per-bag unique indices and random weights, runs EmbeddingSpMDM_ref once per
/// configuration to obtain the expected output, and then times the generated
/// kernel via measureWithWarmup. Results are compared against the reference
/// for the non-flushed run only.
///
/// NOTE(review): the OpenMP barriers below suggest this is meant to be called
/// by every thread of an enclosing parallel region -- confirm with the caller.
///
/// @param batch_size             Number of bags (output rows).
/// @param num_rows               Number of rows in the embedding table.
/// @param embedding_dim          Number of quantized bytes per table row.
/// @param average_len            Controls the bag-length distribution: lengths
///                               are uniform in
///                               [1, min(2 * average_len + 1, num_rows)].
/// @param normalize_by_lengths   Forwarded to the reference and JIT kernels.
/// @param use_32_bit_indices     Use the int32_t index path instead of int64_t.
/// @param prefetch               Passes 16 instead of 0 as the last argument of
///                               GenerateEmbeddingSpMDM (presumably a prefetch
///                               distance -- TODO confirm).
/// @param stress_multi_threading Runs 2^20 iterations instead of 10 and skips
///                               the weighted and cache-flushed variants.
/// @return Always 0.
int run_benchmark(
    int batch_size,
    int num_rows,
    int embedding_dim,
    int average_len,
    bool normalize_by_lengths,
    bool use_32_bit_indices = false,
    bool prefetch = false,
    bool stress_multi_threading = false) {
  // Create embedding table
  default_random_engine generator;
  normal_distribution<float> embedding_distribution;

  // Every fused row holds embedding_dim bytes followed by two floats
  // (presumably scale then bias, going by the original variable name).
  const size_t fused_row_bytes = embedding_dim + 2 * sizeof(float);
  vector<uint8_t> fused_embedding_table(num_rows * fused_row_bytes);
  const float scale_bias[2] = {2.0f, 1.0f};
  const uint8_t* scale_bias_bytes =
      reinterpret_cast<const uint8_t*>(scale_bias);
  for (int i = 0; i < num_rows; i++) {
    uint8_t* row = fused_embedding_table.data() + i * fused_row_bytes;
    std::fill_n(row, embedding_dim, static_cast<uint8_t>(2));
    // Copy the floats byte-wise: the destination is only float-aligned when
    // embedding_dim is a multiple of 4, so a direct float store would be a
    // misaligned access.
    std::copy(
        scale_bias_bytes,
        scale_bias_bytes + sizeof(scale_bias),
        row + embedding_dim);
  }
  // print_fused_table(num_rows, embedding_dim, fused_embedding_table);

  // Generate lengths (stored as a prefix sum in offsets)
  uniform_int_distribution<int> length_distribution(
      1, std::min(2 * average_len + 1, num_rows));
  vector<int> offsets(batch_size + 1);
  offsets[0] = 0;
  for (int i = 0; i < batch_size; ++i) {
    offsets[i + 1] = offsets[i] + length_distribution(generator);
  }

  // Compute the number of indices
  const int lengths_sum = offsets[batch_size];
  if (fbgemm_get_thread_num() == 0) {
    cout << "lengths_sum " << lengths_sum << endl;
  }

  // Generate indices
  vector<int64_t> indices;
  vector<int32_t> indices_32;
  indices.reserve(lengths_sum);
  indices_32.reserve(lengths_sum);
  vector<int> container(num_rows);
  // please note we generate unique indices (within each bag)
  for (int i = 0; i < batch_size; ++i) {
    iota(container.begin(), container.end(), 0);
    // Shuffle with the function-local engine rather than random_shuffle:
    // random_shuffle was removed in C++17 and typically draws from the shared
    // rand() state, which is not reproducible when several threads run this
    // function concurrently.
    std::shuffle(container.begin(), container.end(), generator);
    copy(
        container.begin(),
        container.begin() + (offsets[i + 1] - offsets[i]),
        back_inserter(indices));
  }
  copy(begin(indices), end(indices), back_inserter(indices_32));

  // Generate weights
  vector<float> weights(lengths_sum);
  for (int i = 0; i < lengths_sum; ++i) {
    weights[i] = embedding_distribution(generator);
  }

  // Widen before multiplying so large batch_size * embedding_dim products do
  // not overflow int.
  const size_t output_size = static_cast<size_t>(batch_size) * embedding_dim;
  vector<float> output_sls_ref(output_size);
  vector<float> output_slws_ref(output_size);
  vector<float> output_sls(output_size);
  vector<float> output_slws(output_size);

  constexpr int NUM_WARMUP = 4;
  const int NUM_ITER = stress_multi_threading ? 1 << 20 : 10;

  // Bytes touched per run: one fused row plus one index per lookup, plus one
  // int per bag.
  double bytes = lengths_sum *
          (embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) +
           (use_32_bit_indices ? 4 : 8)) +
      batch_size * sizeof(int);
  // Same as above, but with each fused row rounded up to a multiple of 64
  // bytes.
  double bytes_padded = lengths_sum *
          ((embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) + 63) / 64 *
               64 +
           (use_32_bit_indices ? 4 : 8)) +
      batch_size * sizeof(int);

  vector<bool> has_weight_options;
  has_weight_options.push_back(false);
  if (!stress_multi_threading) {
    has_weight_options.push_back(true);
  }

  for (bool has_weight : has_weight_options) {
    vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;

    // Reference result for this configuration.
    bool success = false, success_ref = false;
    if (use_32_bit_indices) {
      success_ref = EmbeddingSpMDM_ref(
          embedding_dim,
          batch_size,
          lengths_sum,
          num_rows,
          fused_embedding_table.data(),
          indices_32.data(),
          offsets.data(),
          has_weight ? weights.data() : nullptr,
          normalize_by_lengths,
          output_ref.data());
    } else {
      success_ref = EmbeddingSpMDM_ref(
          embedding_dim,
          batch_size,
          lengths_sum,
          num_rows,
          fused_embedding_table.data(),
          indices.data(),
          offsets.data(),
          has_weight ? weights.data() : nullptr,
          normalize_by_lengths,
          output_ref.data());
    }

    vector<float>& output = has_weight ? output_slws : output_sls;

    vector<bool> flush_cache_options;
    flush_cache_options.push_back(false);
    if (!stress_multi_threading) {
      flush_cache_options.push_back(true);
    }

    auto kernel_32 = GenerateEmbeddingSpMDM<uint8_t, int32_t>(
        embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
    auto kernel_64 = GenerateEmbeddingSpMDM<uint8_t, int64_t>(
        embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);

#ifdef _OPENMP
#pragma omp barrier
#endif
    for (bool flush_cache : flush_cache_options) {
      benchmarkTimes[fbgemm_get_thread_num()] = measureWithWarmup(
          [&]() {
            if (use_32_bit_indices) {
              success = kernel_32(
                  batch_size,
                  lengths_sum,
                  num_rows,
                  fused_embedding_table.data(),
                  indices_32.data(),
                  offsets.data(),
                  has_weight ? weights.data() : nullptr,
                  output.data());
            } else {
              success = kernel_64(
                  batch_size,
                  lengths_sum,
                  num_rows,
                  fused_embedding_table.data(),
                  indices.data(),
                  offsets.data(),
                  has_weight ? weights.data() : nullptr,
                  output.data());
            }
          },
          NUM_WARMUP,
          NUM_ITER,
          [&]() {
            if (flush_cache) {
              cache_evict(fused_embedding_table);
              cache_evict(indices);
              cache_evict(indices_32);
              cache_evict(offsets);
              cache_evict(weights);
              cache_evict(output);
            }
          });

      // printMatrix(
      //     matrix_op_t::NoTranspose,
      //     output.data(),
      //     batch_size,
      //     embedding_dim,
      //     embedding_dim,
      //     "");
      // printMatrix(
      //     matrix_op_t::NoTranspose,
      //     output_ref.data(),
      //     batch_size,
      //     embedding_dim,
      //     embedding_dim,
      //     "");

      // Check correctness (only on the non-flushed run)
      if (!flush_cache) {
        if (success != success_ref) {
          assert(
              false && "ERROR: refernce impl and JIT imp did not both succeed");
        } else if (success) {
          for (size_t i = 0; i < output.size(); ++i) {
            const bool within_tolerance =
                fabs(output[i] - output_ref[i]) < 1e-3;
            // Report the mismatch before asserting so the diagnostic is
            // visible in debug builds as well as NDEBUG builds.
            if (!within_tolerance) {
              cout << i << " " << output[i] << " " << output_ref[i] << endl;
            }
            assert(within_tolerance);
          }
        }
      }

#ifdef _OPENMP
#pragma omp barrier
#endif
      // Only thread 0 reports, using the times of all threads.
      if (fbgemm_get_thread_num() == 0) {
        if (has_weight) {
          cout << setw(16) << "SLW(WEIGHTED) ";
        } else {
          cout << setw(16) << "SLS ";
        }
        if (flush_cache) {
          cout << setw(20) << "cache flushed";
        } else {
          cout << setw(20) << "cache not flushed";
        }
        if (prefetch) {
          cout << setw(16) << "prefetch on";
        } else {
          cout << setw(16) << "prefetch off";
        }

        const int num_threads = fbgemm_get_num_threads();
        double max_time = *std::max_element(
            benchmarkTimes.begin(), benchmarkTimes.begin() + num_threads);
        double avg_time =
            std::accumulate(
                benchmarkTimes.begin(),
                benchmarkTimes.begin() + num_threads,
                0.0) /
            num_threads;
        // Relative gap between the slowest thread and the average thread.
        double load_imbalance = (max_time - avg_time) / avg_time;

        cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / max_time
             << " GB/s" << setw(20) << "effective b/w: " << setw(16)
             << bytes_padded / 1e9 / max_time << "GB/s" << setw(8) << " time "
             << setw(16) << max_time << " load_imbalance " << load_imbalance
             << endl;
      }
    } // flush_cache
  } // has_weight
  return 0;
}