in tensorflow_addons/custom_ops/text/cc/kernels/skip_gram_kernels.cc [40:125]
// Generates (token, label) skip-gram pairs from a rank-1 input tensor.
// For each token i in [start, end), draws a random skip width in
// [min_skips, max_skips] and emits every in-window neighbor as a label.
// Outputs two equally-sized rank-1 tensors: "tokens" and "labels".
void Compute(OpKernelContext* context) override {
  const Tensor* input_tensor;
  OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
  // Validate the shape before doing anything else with the input.
  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor->shape()),
              errors::InvalidArgument("input_tensor must be of rank 1"));
  const auto input = input_tensor->flat<T>();

  const Tensor* min_skips_tensor;
  OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
  const int min_skips = *(min_skips_tensor->scalar<int>().data());
  const Tensor* max_skips_tensor;
  OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
  const int max_skips = *(max_skips_tensor->scalar<int>().data());
  OP_REQUIRES(
      context, min_skips >= 0 && max_skips >= 0,
      errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
  OP_REQUIRES(context, min_skips <= max_skips,
              errors::InvalidArgument("min_skips must be <= max_skips."));

  const Tensor* start_tensor;
  OP_REQUIRES_OK(context, context->input("start", &start_tensor));
  const int start = *(start_tensor->scalar<int>().data());
  // A negative start would index input() out of bounds below; a start past
  // the end would make the sample-reservation count negative. Reject both.
  OP_REQUIRES(
      context, start >= 0 && start <= static_cast<int>(input.size()),
      errors::InvalidArgument("start must be in [0, input size]"));
  const Tensor* limit_tensor;
  OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
  const int limit = *(limit_tensor->scalar<int>().data());
  // limit < 0 means "through the end of the input". Compute start + limit in
  // 64 bits so the sum cannot overflow int before being clamped.
  const int end =
      limit < 0 ? static_cast<int>(input.size())
                : static_cast<int>(
                      std::min<int64_t>(static_cast<int64_t>(start) + limit,
                                        static_cast<int64_t>(input.size())));

  const Tensor* emit_self_tensor;
  OP_REQUIRES_OK(context,
                 context->input("emit_self_as_target", &emit_self_tensor));
  const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());

  std::vector<T> tokens;
  std::vector<T> labels;

  // Reserve the number of random numbers we will use - we use one for each
  // token between start and end.
  random::PhiloxRandom local_gen =
      generator_.ReserveSamples32(end - start + 1);
  random::SimplePhilox rng(&local_gen);

  // For each token in the sentence, pick a random skip, then generates
  // (token, label) pairs for all labels whose distances from the token are
  // within the range [-skip, skip].
  for (int i = start; i < end; ++i) {
    const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
    for (int j = -skips; j <= skips; ++j) {
      // Skip out-of-window neighbors and (optionally) the token itself.
      if ((i + j < start) || (i + j >= end) ||
          (j == 0 && !emit_self_as_target)) {
        continue;
      }
      tokens.push_back(input(i));
      labels.push_back(input(i + j));
    }
  }

  Tensor* tokens_output = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(
                     "tokens", TensorShape({static_cast<int>(tokens.size())}),
                     &tokens_output));
  Tensor* labels_output = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(
                     "labels", TensorShape({static_cast<int>(labels.size())}),
                     &labels_output));
  // tokens and labels are pushed in lockstep, so this should always hold.
  OP_REQUIRES(
      context, tokens_output->IsSameSize(*labels_output),
      errors::Internal(strings::StrCat(
          "Mismatch between tokens_output shape of ",
          tokens_output->shape().DebugString(),
          " and labels_output shape of ",
          labels_output->shape().DebugString(),
          ". This should never happen - contact ami-team@ if it does.")));

  // Copies results to output tensors.
  for (typename std::vector<T>::size_type i = 0; i < tokens.size(); ++i) {
    tokens_output->vec<T>()(i) = tokens[i];
    labels_output->vec<T>()(i) = labels[i];
  }
}