in src/batchsizefinder.h [44:236]
int find(
    torch::Device device, Prepare&& prepare, Forward&& forward, int minBatchSize, int maxBatchsize, float maxTimeMs,
    Score&& scoreFunction) {
  // Searches for the batch size with the lowest score.
  //
  //   device        - device the forward pass runs on; on CUDA a pooled stream is made current
  //                   and synchronized after every forward call so timings cover the whole call.
  //   prepare(n)    - called with a batch size; its result is handed to forward().
  //   forward(x)    - the operation being timed.
  //   minBatchSize,
  //   maxBatchsize  - a measurement outside [minBatchSize, maxBatchsize] counts as "bad".
  //   maxTimeMs     - a measurement slower than this (in ms) counts as "bad".
  //   scoreFunction - called as scoreFunction(averageLatencyMs, batchSize); LOWER is better.
  //
  // Returns the best batch size found, or 0 if no measured batch size was acceptable.
  // Progress and results are printed with fmt::printf.
  torch::NoGradGuard ng;
  bool isCuda = device.is_cuda();
  std::optional<c10::cuda::CUDAStreamGuard> g;
  if (isCuda) {
    // First argument is false; presumably the "high priority" flag - confirm against c10 docs.
    g.emplace(c10::cuda::getStreamFromPool(false, device.index()));
  } else {
    // throw std::runtime_error("findBatchSize on non-cuda device is not meaningful");
  }
  auto input = prepare(1);
  // One timed unit of work: a forward pass plus (on CUDA) a stream synchronize.
  auto call = [&]() {
    forward(input);
    if (isCuda) {
      g->current_stream().synchronize();
    }
  };
  fmt::printf("Finding batch size\n");
  // warm up
  for (int i = 0; i != 10; ++i) {
    call();
  }
  Timer t;
  for (int i = 0; i != 10; ++i) {
    call();
  }
  // Timer::elapsed() is treated as seconds here (scaled by 1000 to get ms).
  float call1 = t.elapsed() / 10.0f * 1000.0f;
  fmt::printf("Base latency: %gms\n", call1);
  float maxms = maxTimeMs;
  int maxbs = maxBatchsize;
  // Accumulated measurements for one batch size.
  struct I {
    float latency = 0.0f; // sum of measured latencies, in ms
    int size = 0;         // batch size
    int n = 0;            // number of measurements summed into `latency`
    bool isBad = false;   // every sample of the most recent eval violated a constraint
  };
  auto scorex = [&](auto& x) { return scoreFunction(x.latency / x.n, x.size); };
  std::map<int, I> li;
  int best = 0;
  float bestScore = std::numeric_limits<float>::infinity();
  // Measures batch size i (2 warm-up calls, 2 timed calls), folds the result into li[i] and
  // refreshes `best`. Returns false when every timed call was "bad".
  auto eval = [&](int i) {
    input = prepare(i);
    int badcount = 0;
    float latency = 0.0f;
    int n = 2;
    for (int j = 0; j != n; ++j) {
      call();
    }
    for (int j = 0; j != n; ++j) {
      t.reset();
      call();
      float ms = t.elapsed() * 1000;
      latency += ms;
      if (ms > maxms || i > maxbs || i < minBatchSize) {
        ++badcount;
      }
    }
    auto& x = li[i];
    x.size = i;
    x.latency += latency;
    x.n += n;
    x.isBad = badcount >= n;
    // Re-derive the best candidate from the accumulated averages instead of only ever lowering
    // bestScore: a re-measured entry may have become worse (or bad), and a stale minimum would
    // otherwise keep it as "best" forever.
    best = 0;
    bestScore = std::numeric_limits<float>::infinity();
    for (auto& [size, info] : li) {
      if (info.isBad) {
        continue;
      }
      float score = scorex(info);
      if (score < bestScore) {
        bestScore = score;
        best = size;
      }
    }
    return badcount < n;
  };
  // Coarse geometric sweep (step grows by ~25%) until the first bad batch size.
  for (int i = std::max(minBatchSize, 1);; i += (i + 3) / 4) {
    if (!eval(i)) {
      break;
    }
  }
  std::minstd_rand rng(std::random_device{}());
  // Evaluates unmeasured batch sizes close to k (at most 3 below / 6 above, and never past the
  // neighbouring measured sizes). Returns the number of new evaluations performed.
  auto expandNear = [&](int k) {
    int r = 0;
    auto center = li.find(k);
    if (center != li.end()) {
      auto search = [&](auto begin, auto end) {
        int b = begin->first;
        int e;
        if (end == li.end()) {
          e = std::prev(end)->first;
        } else {
          e = end->first;
        }
        // Clamp the [b, e) interval to the window around k. The upper bound must be narrowed
        // with min(); using max() here would discard `e` and always scan up to k + 6.
        b = std::max(b, k - 3);
        e = std::min(e, k + 6);
        for (int size = b; size < e; ++size) {
          if (li.find(size) != li.end()) {
            continue;
          }
          ++r;
          if (!eval(size)) {
            break;
          }
        }
      };
      // std::map iterators stay valid across the insertions done by eval().
      search(center, std::next(center));
      if (center != li.begin()) {
        search(std::prev(center), center);
      }
    }
    return r;
  };
  for (int j = 0; j != 4; ++j) {
    int expands = 12;
    for (int k = 0; k != 12; ++k) {
      float sum = 0.0f;
      // (weight, from, to): unmeasured gap [from, to) between two measured sizes.
      std::vector<std::tuple<float, int, int>> list;
      float minweight = std::numeric_limits<float>::infinity();
      for (auto& [size, v] : li) {
        if (!v.isBad) {
          minweight = std::min(minweight, scorex(v));
        }
      }
      for (auto i = li.begin();;) {
        auto next = std::next(i);
        if (next == li.end()) {
          break;
        }
        // A gap bounded by two bad sizes is not worth sampling.
        if (i->second.isBad && next->second.isBad) {
          i = next;
          continue;
        }
        int from = i->first + 1;
        int to = next->first;
        if (to - from > 0) {
          // Gaps next to well-scoring sizes get exponentially more weight; wider gaps too.
          // NOTE(review): a bad neighbour scoring far below minweight makes exp() underflow and
          // the weight infinite - consider a lower clamp if such score functions are used.
          float weight = std::min(scorex(i->second), scorex(next->second)) - minweight;
          weight = 1.0f / std::min(std::exp(weight * 4), 1e9f);
          weight *= to - from;
          list.emplace_back(weight, from, to);
          sum += weight;
        }
        i = next;
      }
      if (list.size() > 0 && sum > 0.0f) {
        // Weighted random pick of a gap, then a uniform pick of a size inside it.
        float val = std::uniform_real_distribution<float>(0.0f, sum)(rng);
        for (auto& [weight, from, to] : list) {
          val -= weight;
          if (val <= 0) {
            int pick = std::uniform_int_distribution<int>(from, to - 1)(rng);
            eval(pick);
            if (expands > 0) {
              expands -= expandNear(pick);
            }
            break;
          }
        }
      }
    }
    if (best) {
      expandNear(best);
    }
    // Re-measure the ten best candidates that have fewer than 8 samples to reduce noise.
    std::vector<std::tuple<float, int>> sorted;
    for (auto& [size, v] : li) {
      if (!v.isBad) {
        sorted.emplace_back(scorex(v), size);
      }
    }
    std::sort(sorted.begin(), sorted.end());
    for (size_t i = 0; i != sorted.size() && i < 10; ++i) {
      int size = std::get<1>(sorted[i]);
      if (li[size].n < 8) {
        eval(size);
      }
    }
  }
  for (auto& [k, v] : li) {
    fmt::printf(
        "Batch size %d, evals %d latency %fms throughput %g score %g\n", k, v.n, v.latency / v.n,
        v.size / (v.latency / v.n), scorex(v));
  }
  // best stays 0 when no measured size was acceptable; do not index li[0], which would insert
  // an empty entry and print 0/0.
  auto bestIt = li.find(best);
  if (bestIt == li.end()) {
    fmt::printf("No batch size satisfied the constraints\n");
    return best;
  }
  auto& bestInfo = bestIt->second;
  fmt::printf(
      "Found best batch size of %d with evals %d latency %fms "
      "throughput %g score %g\n",
      best, bestInfo.n, bestInfo.latency / bestInfo.n, bestInfo.size / (bestInfo.latency / bestInfo.n),
      scorex(bestInfo));
  return best;
}