bench/EmbeddingSpMDMBenchmark.cc (293 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include <immintrin.h> #include <algorithm> #include <cassert> #include <chrono> #include <cmath> #include <cstdint> #include <iomanip> #include <iostream> #include <map> #include <numeric> #include <random> #include <set> #include <vector> #include "./BenchUtils.h" #include "fbgemm/Fbgemm.h" #include "fbgemm/FbgemmConvert.h" #include "src/RefImplementations.h" using namespace std; using namespace fbgemm; static vector<vector<int>> GetInputs_() { vector<vector<int>> input_dims = { // batch size, number of rows of table, emb dim , avg lengthl // TODO: Add more inputs // Use these -- but they are slow. {10, 4000000, 32, 100}, {10, 4000000, 64, 100}, {10, 4000000, 128, 100}, {10, 4000000, 256, 100}, // Use these for debugging // {2, 16, 128, 10}, // {10, 4000, 128, 100}, // {10, 4000, 128, 100}, // {10, 4000, 128, 100}, }; return input_dims; } void run_benchmark( int batch_size, int num_rows, int embedding_dim, int average_len, bool normalize_by_lengths, bool use_fp16_inputs = false, bool use_32_bit_indices = false, bool prefetch = false) { // Create embedding table default_random_engine generator; vector<float> embedding_table(num_rows * embedding_dim); normal_distribution<float> embedding_distribution; for (size_t i = 0; i < embedding_table.size(); ++i) { embedding_table[i] = embedding_distribution(generator); } vector<float16> embedding_table_fp16; if (use_fp16_inputs) { embedding_table_fp16.resize(embedding_table.size()); FloatToFloat16_simd( embedding_table.data(), embedding_table_fp16.data(), embedding_table.size()); } // Generate lengths uniform_int_distribution<int> length_distribution( 1, std::min(2 * average_len + 1, num_rows)); vector<int> offsets(batch_size + 1); offsets[0] = 0; for (int i = 0; i < batch_size; ++i) { offsets[i + 1] = offsets[i] + length_distribution(generator); } // Compute the number of indices int lengths_sum = offsets[batch_size]; cout << "lengths_sum " << lengths_sum << endl; // Generate indices vector<int64_t> indices; vector<int32_t> indices_32; vector<int> container(num_rows); // please note we generate unique indices for (int i = 0; i < batch_size; ++i) { iota(container.begin(), container.end(), 0); random_shuffle(container.begin(), container.end()); copy( container.begin(), container.begin() + (offsets[i + 1] - offsets[i]), back_inserter(indices)); } copy(begin(indices), end(indices), back_inserter(indices_32)); // Generate weights vector<float> weights(lengths_sum); for (int i = 0; i < lengths_sum; ++i) { weights[i] = embedding_distribution(generator); } vector<float> output_sls_ref(batch_size * embedding_dim); vector<float> output_slws_ref(output_sls_ref.size()), output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size()); constexpr int NUM_WARMUP = 4; constexpr int NUM_ITER = 10; int elem_bytes = use_fp16_inputs ? sizeof(float16) : sizeof(float); double bytes = lengths_sum * (embedding_dim * elem_bytes + (use_32_bit_indices ? 4 : 8)) + batch_size * sizeof(int); double bytes_padded = lengths_sum * ((embedding_dim * elem_bytes + 63) / 64 * 64 + (use_32_bit_indices ? 4 : 8)) + batch_size * sizeof(int); for (bool has_weight : {false, true}) { vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref; bool success = false, success_ref = false; if (use_fp16_inputs) { if (use_32_bit_indices) { success_ref = EmbeddingSpMDM_ref( embedding_dim, batch_size, lengths_sum, num_rows, embedding_table_fp16.data(), indices_32.data(), offsets.data(), has_weight ? weights.data() : nullptr, normalize_by_lengths, output_ref.data()); } else { success_ref = EmbeddingSpMDM_ref( embedding_dim, batch_size, lengths_sum, num_rows, embedding_table_fp16.data(), indices.data(), offsets.data(), has_weight ? weights.data() : nullptr, normalize_by_lengths, output_ref.data()); } } else { if (use_32_bit_indices) { success_ref = EmbeddingSpMDM_ref( embedding_dim, batch_size, lengths_sum, num_rows, embedding_table.data(), indices_32.data(), offsets.data(), has_weight ? weights.data() : nullptr, normalize_by_lengths, output_ref.data()); } else { success_ref = EmbeddingSpMDM_ref( embedding_dim, batch_size, lengths_sum, num_rows, embedding_table.data(), indices.data(), offsets.data(), has_weight ? weights.data() : nullptr, normalize_by_lengths, output_ref.data()); } } auto kernel_fp32_i32 = GenerateEmbeddingSpMDM<float, int32_t>( embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); auto kernel_fp32_i64 = GenerateEmbeddingSpMDM<float, int64_t>( embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); auto kernel_fp16_i32 = GenerateEmbeddingSpMDM<float16, int32_t>( embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); auto kernel_fp16_i64 = GenerateEmbeddingSpMDM<float16, int64_t>( embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); vector<float>& output = has_weight ? output_slws : output_sls; for (bool flush_cache : {false, true}) { double t = measureWithWarmup( [&]() { if (use_fp16_inputs) { if (use_32_bit_indices) { success = kernel_fp16_i32( batch_size, lengths_sum, num_rows, embedding_table_fp16.data(), indices_32.data(), offsets.data(), has_weight ? weights.data() : nullptr, output.data()); } else { success = kernel_fp16_i64( batch_size, lengths_sum, num_rows, embedding_table_fp16.data(), indices.data(), offsets.data(), has_weight ? weights.data() : nullptr, output.data()); } } else { if (use_32_bit_indices) { success = kernel_fp32_i32( batch_size, lengths_sum, num_rows, embedding_table.data(), indices_32.data(), offsets.data(), has_weight ? weights.data() : nullptr, output.data()); } else { success = kernel_fp32_i64( batch_size, lengths_sum, num_rows, embedding_table.data(), indices.data(), offsets.data(), has_weight ? weights.data() : nullptr, output.data()); } } }, NUM_WARMUP, NUM_ITER, [&]() { if (flush_cache) { cache_evict(embedding_table); cache_evict(indices); cache_evict(indices_32); cache_evict(offsets); cache_evict(weights); cache_evict(output); } }); // printMatrix( // matrix_op_t::NoTranspose, // output.data(), // batch_size, // embedding_dim, // embedding_dim, // ""); // cout << "reference data\n"; // printMatrix( // matrix_op_t::NoTranspose, // output_ref.data(), // batch_size, // embedding_dim, // embedding_dim, // ""); // Check correctness if (!flush_cache) { if (success != success_ref) { assert( false && "ERROR: refernce impl and JIT imp did not both succeed"); } else if (success) { for (size_t i = 0; i < output.size(); ++i) { assert(output[i] == output_ref[i]); if (output[i] != output_ref[i]) { cout << i << " " << output[i] << " " << output_ref[i] << endl; } } } } if (has_weight) { cout << setw(16) << "SLW(WEIGHTED) "; } else { cout << setw(16) << "SLS "; } if (flush_cache) { cout << setw(20) << "cache flushed"; } else { cout << setw(20) << "cache not flushed"; } if (prefetch) { cout << setw(16) << "prefetch on"; } else { cout << setw(16) << "prefetch off"; } cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s" << setw(20) << "effective b/w: " << setw(16) << bytes_padded / 1e9 / t << "GB/s" << setw(8) << " time " << setw(16) << t << endl; } // flush_cache } // has_weight } int main() { vector<vector<int>> inputs(GetInputs_()); for (auto& input : inputs) { assert(input.size() > 3); int batch_size = input[0]; int num_rows = input[1]; int embedding_dim = input[2]; int average_len = input[3]; cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows" << setw(16) << num_rows << setw(10) << "emb dim" << setw(6) << embedding_dim << setw(16) << "avg length" << setw(6) << average_len << endl; for (bool normalize_by_lengths : {false, true}) { for (bool use_fp16_inputs : {false, true}) { for (bool use_32_bit_indices : {false, true}) { for (bool prefetch : {false, true}) { // args: batch sz, num rows, emb dim, avg len, normalize, use 32b, // prefetch if (normalize_by_lengths) { cout << "Mean"; } if (use_fp16_inputs) { cout << "fp16 inputs"; } cout << (use_32_bit_indices ? " 32" : " 64") << " bit indices"; if (prefetch) { cout << " with prefetching"; } cout << ", "; run_benchmark( batch_size, num_rows, embedding_dim, average_len, normalize_by_lengths, use_fp16_inputs, use_32_bit_indices, prefetch); } // prefetch } // use_32_bit_indices } // use_fp16_inputs } // normalize_by_length } // for each input return 0; }