// bench/RowwiseAdagradBenchmark.cc
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <set>
#include <vector>
#include "./BenchUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"
using namespace std;
using namespace fbgemm;
// Returns the benchmark problem shapes; each entry is
// {number of rows, embedding dimension}.
static std::vector<std::vector<int>> GetInputs_() {
  constexpr int kNumRows = 1500000;
  std::vector<std::vector<int>> input_dims;
  for (const int emb_dim : {16, 24, 32, 72, 128}) {
    input_dims.push_back({kNumRows, emb_dim});
  }
  return input_dims;
}
// Prefetch distances (in rows) swept by the benchmark loop in main().
// File-local (static) and const: this is read-only configuration data
// and should not have external linkage or be mutable.
static const std::vector<int> prefetch_distances{16};
// Benchmarks the JIT-generated rowwise sparse AdaGrad kernel, validates its
// results against the scalar reference implementation, and prints the time
// taken plus an effective-bandwidth estimate.
//
// @param num_rows            number of rows updated per kernel invocation
// @param block_size          number of parameters per row (embedding dim)
// @param param_size          total number of parameters in the tables
// @param isIndex64b          benchmark the int64_t-index kernel when true,
//                            otherwise the int32_t-index kernel
// @param prefetch            software prefetch distance passed to the
//                            kernel generator
// @param adjust_weight_decay enable frequency-counter-based weight-decay
//                            adjustment (passes a counter array when true)
void run_benchmark(
    const int num_rows, // number of rows reading
    const int block_size, // number of parameters per row
    const uint64_t param_size, // total number of parameters
    const bool isIndex64b,
    const int prefetch,
    const bool adjust_weight_decay) {
  // Scratch buffer larger than the LLC; flushing it between timed
  // iterations yields cold-cache measurements.
  // (Fix: fill value was the double literal 1.0 for a vector<char>.)
  vector<char> llc(64L * 1024L * 1024L, 1);
  vector<float> g(num_rows * block_size); // gradients
  vector<float> h(param_size); // input momentums
  vector<float> w(param_size); // input params
  vector<float> h_ref(param_size);
  vector<float> w_ref(param_size);
  default_random_engine generator;
  // normal_distribution<float> h_w_distribution;
  // Deterministic fills so the JIT and reference kernels see identical
  // inputs and their outputs can be compared exactly below.
  // TODO: check appropriate vals for g,h,w
  for (size_t i = 0; i < g.size(); ++i) {
    g[i] = 4 + i; // h_w_distribution(generator);
  }
  for (size_t i = 0; i < h.size(); ++i) {
    h_ref[i] = h[i] = 2 + i; // h_w_distribution(generator);
  }
  for (size_t i = 0; i < w.size(); ++i) {
    w_ref[i] = w[i] = 3 + i; // h_w_distribution(generator);
  }
  vector<int64_t> indices(num_rows);
  vector<int32_t> indices_32(num_rows);
  vector<double> counter(num_rows);
  float epsilon = 1e-5;
  float lr = 0.5;
  float weight_decay = adjust_weight_decay ? 1e-7 : 0.0f;
  constexpr int64_t counter_halflife = 1e6;
  // Random row indices (with possible repeats) and per-row counters.
  uniform_int_distribution<int64_t> length_distribution(0, num_rows - 1);
  uniform_int_distribution<int64_t> counter_distribution(
      0, 2 * counter_halflife);
  for (int i = 0; i < num_rows; ++i) {
    indices_32[i] = indices[i] = length_distribution(generator);
    counter[i] = counter_distribution(generator);
  }
  double t = 0.0;
  constexpr int NUM_WARMUP = 4;
  constexpr int NUM_ITER = 10;
  // Approximate bytes touched per row: 3 float streams (w, g, h) of
  // block_size elements, plus ~2 cache lines of per-row overhead.
  double data_moved = num_rows * (3 * sizeof(float) * block_size + 2 * 64);
  if (isIndex64b) {
    auto fn_indices_64 = GenerateSparseAdaGrad<int64_t>(
        block_size, /*rowwise=*/true, prefetch, adjust_weight_decay);
    t = measureWithWarmup(
        [&]() {
          fn_indices_64(
              num_rows, // number of rows reading
              param_size, // total number of parameters
              w.data(), // input parameters
              g.data(), // input gradients
              h.data(), // input momentums
              indices.data(), // indices of each row
              epsilon,
              lr,
              weight_decay, // weight_decay
              adjust_weight_decay ? counter.data() : nullptr, // counters
              counter_halflife); // counter_halflife
        },
        NUM_WARMUP,
        NUM_ITER,
        [&]() { llc_flush(llc); });
    // Run the reference the same total number of times so w_ref/h_ref end
    // up in the same state as w/h after warmup + timed iterations.
    for (int i = 0; i < NUM_WARMUP + NUM_ITER; ++i) {
      rowwise_sparse_adagrad_ref(
          num_rows, // number of rows reading
          block_size, // number of parameters per row
          param_size, // total number of parameters
          w_ref.data(), // input parameters
          g.data(), // input gradients
          h_ref.data(), // input momentums
          indices.data(), // indices of each row
          epsilon,
          lr,
          weight_decay, // weight decay (lambda)
          adjust_weight_decay ? counter.data()
                              : nullptr, // feature ID frequency counter data
          counter_halflife); // counter halflife value for adjustments
    }
  } else {
    auto fn_indices_32 = GenerateSparseAdaGrad<int32_t>(
        block_size, /*rowwise=*/true, prefetch, adjust_weight_decay);
    t = measureWithWarmup(
        [&]() {
          fn_indices_32(
              num_rows, // number of rows reading
              param_size, // total number of parameters
              w.data(), // input parameters
              g.data(), // input gradients
              h.data(), // input momentums
              indices_32.data(), // indices of each row
              epsilon,
              lr,
              weight_decay, // weight_decay
              adjust_weight_decay ? counter.data() : nullptr, // counters
              counter_halflife); // counter_halflife
        },
        NUM_WARMUP,
        NUM_ITER,
        [&]() { llc_flush(llc); });
    for (int i = 0; i < NUM_WARMUP + NUM_ITER; ++i) {
      rowwise_sparse_adagrad_ref(
          num_rows, // number of rows reading
          block_size, // number of parameters per row
          param_size, // total number of parameters
          w_ref.data(), // input parameters
          g.data(), // input gradients
          h_ref.data(), // input momentums
          indices_32.data(), // indices of each row
          epsilon,
          lr,
          weight_decay, // weight decay (lambda)
          adjust_weight_decay ? counter.data()
                              : nullptr, // feature ID frequency counter data
          counter_halflife); // counter halflife value for adjustments
    }
  }
  // Validate the JIT results against the reference. Print any mismatch
  // BEFORE asserting so debug builds still report the offending index
  // (previously the assert aborted before the fprintf could run), and use
  // %zu since i is a size_t (the old %ld was wrong on LLP64 targets).
  for (size_t i = 0; i < w.size(); ++i) {
    if (fabs(w[i] - w_ref[i]) >= 1e-5) {
      fprintf(stderr, "%zu %f %f\n", i, w[i], w_ref[i]);
    }
    assert(fabs(w[i] - w_ref[i]) < 1e-5);
  }
  for (size_t i = 0; i < h.size(); ++i) {
    if (fabs(h[i] - h_ref[i]) >= 1e-5) {
      fprintf(stderr, "%zu %f %f\n", i, h[i], h_ref[i]);
    }
    assert(fabs(h[i] - h_ref[i]) < 1e-5);
  }
  cout << "indices: " << (isIndex64b ? " 64bits " : " 32bits ") << " | ";
  cout << "weight_decay: " << setw(1) << adjust_weight_decay << " | ";
  cout << "num_rows: " << setw(8) << num_rows << " block_size: " << setw(4)
       << block_size << " | ";
  cout << "time taken by jit code(secs): " << setw(10) << fixed
       << setprecision(6) << t << " | ";
  cout << "bandwidth fbgemm (GB/s) " << setw(10) << fixed << setprecision(6)
       << data_moved / t / 1e9 << endl;
}
int main() {
int num_rows;
int block_size;
uint64_t param_size;
vector<vector<int>> inputs(GetInputs_());
for (auto isIndex64b : vector<bool>{true, false}) {
for (auto adjust_weight_decay : vector<bool>{true, false}) {
for (auto prefetch : prefetch_distances) {
for (auto& input : inputs) {
assert(input.size() >= 2);
num_rows = input[0];
block_size = input[1];
param_size = num_rows * block_size;
run_benchmark(
num_rows,
block_size,
param_size,
isIndex64b,
prefetch,
adjust_weight_decay);
}
}
}
}
return 0;
}