fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
using Tensor = at::Tensor;
namespace {
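// CPU implementation of the bounds_check_indices operator.
//
//   rows_per_table    : [T] int64 tensor holding the number of rows in each
//                       embedding table.
//   indices / offsets : flat index data with segment offsets; offsets has
//                       T * B + 1 entries delimiting each (table, bag) range.
//   bounds_check_mode_: BoundsCheckMode cast to int64_t. FATAL aborts on the
//                       first violation, WARNING logs once and repairs the
//                       data in place, IGNORE repairs silently.
//   warning           : int64 tensor whose first element counts violations in
//                       WARNING mode.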
void bounds_check_indices_cpu(
Tensor rows_per_table,
Tensor indices,
Tensor offsets,
int64_t bounds_check_mode_,
Tensor warning) {
auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
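  // In WARNING mode, reset the violation counter at the start of each call.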
if (bounds_check_mode == BoundsCheckMode::WARNING) {
warning.zero_();
}
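  // T = number of embedding tables, B = batch size (bags per table).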
int32_t T = rows_per_table.size(0);
int32_t B = (offsets.size(0) - 1) / T;
const auto rows_per_table_acc = rows_per_table.accessor<int64_t, 1>();
auto warning_acc = warning.data_ptr<int64_t>();
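  // Dispatch on the integer dtype (int32 or int64) of indices/offsets.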
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
auto offsets_acc = offsets.accessor<index_t, 1>();
auto indices_acc = indices.accessor<index_t, 1>();
auto num_indices = indices.numel();
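    // For every (table, bag) pair: first validate (and possibly repair) the
    // offsets range, then validate each index in the bag.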
for (auto t = 0; t < T; ++t) {
auto num_rows = rows_per_table_acc[t];
for (auto b = 0; b < B; ++b) {
auto indices_start = offsets_acc[t * B + b];
auto indices_end = offsets_acc[t * B + b + 1];
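        // The offsets range must satisfy 0 <= indices_start <= indices_end
        // <= num_indices. FATAL aborts, WARNING logs once and clamps,
        // IGNORE clamps silently.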
if (bounds_check_mode == BoundsCheckMode::FATAL) {
          TORCH_CHECK(indices_start >= 0, "indices_start must be >= 0");
          TORCH_CHECK(
              indices_start <= indices_end,
              "indices_start must be <= indices_end");
          TORCH_CHECK(
              indices_end <= num_indices, "indices_end must be <= num_indices");
} else if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (indices_start < 0 || indices_start > indices_end ||
indices_end > num_indices) {
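            // Atomically bump the shared warning counter; only the first
            // observed violation is logged to avoid flooding the log.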
if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) {
LOG(ERROR)
<< "(at least one) Out of bounds access for batch: " << b
<< ", table: " << t << ", indices_start: " << indices_start
<< ", indices_end: " << indices_end
<< ", num_indices: " << num_indices
<< ". Setting indices_start and indices_end within the range";
}
            indices_start = std::max(
                static_cast<int64_t>(0),
                std::min(static_cast<int64_t>(indices_start), num_indices));
indices_end = std::max(
static_cast<int64_t>(indices_start),
std::min(static_cast<int64_t>(indices_end), num_indices));
offsets_acc[t * B + b] = indices_start;
offsets_acc[t * B + b + 1] = indices_end;
}
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
          indices_start = std::max(
              static_cast<int64_t>(0),
              std::min(static_cast<int64_t>(indices_start), num_indices));
indices_end = std::max(
static_cast<int64_t>(indices_start),
std::min(static_cast<int64_t>(indices_end), num_indices));
offsets_acc[t * B + b] = indices_start;
offsets_acc[t * B + b + 1] = indices_end;
}
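        // Each index in the bag must fall inside [0, num_rows) of its table;
        // -1 marks a pruned row and is skipped.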
auto L = indices_end - indices_start;
for (auto l = 0; l < L; ++l) {
auto idx = indices_acc[indices_start + l];
if (idx == -1) {
// -1 indicates pruned rows.
continue;
}
if (bounds_check_mode == BoundsCheckMode::FATAL) {
            TORCH_CHECK(idx >= 0, "Embedding index must be >= 0");
            TORCH_CHECK(idx < num_rows, "Embedding index must be < num_rows");
} else if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (idx < 0 || idx >= num_rows) {
if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) {
LOG(ERROR) << "(at least one) Out of bounds access for batch: "
<< b << ", table: " << t << ", bag element: " << l
<< ", idx: " << idx << ", num_rows: " << num_rows
<< ". Setting idx to zero.";
}
indices_acc[indices_start + l] = 0;
}
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
if (idx < 0 || idx >= num_rows) {
indices_acc[indices_start + l] = 0;
}
}
}
}
}
});
}
} // namespace
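// Register the operator schema under the fbgemm namespace and bind the CPU
// kernel to it.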
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  // The (a!) annotations mark arguments that are mutated in place; they tell
  // PyTorch this is an impure operation, so it cannot be CSE'd, DCE'd, etc.
m.def(
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(a!) offsets, int bounds_check_mode, Tensor(a!) warning) -> ()");
DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
}
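// Usage sketch (an assumption for illustration, not part of this file): once
// the fbgemm_gpu extension is loaded on the Python side, the operator is
// reachable through the dispatcher as torch.ops.fbgemm.bounds_check_indices:
//
//   torch.ops.fbgemm.bounds_check_indices(
//       rows_per_table, indices, offsets, bounds_check_mode, warning)
//
// The call mutates indices, offsets, and warning in place according to the
// selected BoundsCheckMode.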