in tensorflow_recommenders_addons/embedding_variable/core/kernels/ev_ops.cc [410:518]
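  // Applies a sparse Adam update to an EmbeddingVar: only the embedding rows
  // addressed by `indices` have their m/v slots and values updated.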
  void Compute(OpKernelContext* ctx) {
    auto locks = MaybeLockEmbeddingVariableInputMutexesInOrder<TKey, TValue>(
        ctx, use_exclusive_lock_, {0, 1, 2});
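    // Inputs 0..2 are the resource handles for the primary embedding
    // variable and the Adam first/second moment accumulators (m and v).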
    // LookupResource adds a reference to each resource; unref on exit so the
    // kernel does not leak them.
    EmbeddingVar<TKey, TValue>* var = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
    core::ScopedUnref unref_var(var);
    EmbeddingVar<TKey, TValue>* m = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &m));
    core::ScopedUnref unref_m(m);
    EmbeddingVar<TKey, TValue>* v = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &v));
    core::ScopedUnref unref_v(v);
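    // Remaining inputs: Adam hyperparameters, the sparse gradient, the rows
    // it applies to, and the current global step.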
    const Tensor& beta1_power = ctx->input(3);
    const Tensor& beta2_power = ctx->input(4);
    const Tensor& lr = ctx->input(5);
    const Tensor& beta1 = ctx->input(6);
    const Tensor& beta2 = ctx->input(7);
    const Tensor& epsilon = ctx->input(8);
    const Tensor& grad = ctx->input(9);
    const Tensor& indices = ctx->input(10);
    const Tensor& global_step = ctx->input(11);
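    // Every hyperparameter must be a scalar and `indices` a 1-D vector.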
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
                errors::InvalidArgument("beta2_power is not a scalar: ",
                                        beta2_power.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));
    int64 inner_dim = 1;
    TensorShape var_shape({var->Size(), var->ValueLen()});
    for (int d = 1; d < var_shape.dims(); d++) {
      OP_REQUIRES(ctx, var_shape.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(
                      "var and grad must match in dimension ", d));
      inner_dim *= grad.dim_size(d);
    }
    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
                errors::InvalidArgument("global_step is not a scalar: ",
                                        global_step.shape().DebugString()));
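    // The gradient must supply exactly one row per index.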
    const int64 N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));
    if (N > 0) {
      TValue beta1_power_scalar = beta1_power.scalar<TValue>()();
      TValue beta2_power_scalar = beta2_power.scalar<TValue>()();
      TValue lr_scalar = lr.scalar<TValue>()();
      TValue beta1_scalar = beta1.scalar<TValue>()();
      TValue beta2_scalar = beta2.scalar<TValue>()();
      TValue epsilon_scalar = epsilon.scalar<TValue>()();
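      // Fold Adam's bias correction into a single step size:
      //   alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power)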
      const TValue alpha =
          lr_scalar *
          Eigen::numext::sqrt(static_cast<TValue>(1) - beta2_power_scalar) /
          (static_cast<TValue>(1) - beta1_power_scalar);
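      // Per-shard worker: applies the Adam update to the rows in
      // [start_i, limit_i) handed out by Shard below.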
      auto DoWork = [inner_dim, &var, &m, &v, &grad, &indices, &beta1_scalar,
                     &beta2_scalar, &epsilon_scalar, &alpha,
                     &global_step](int64 start_i, int64 limit_i) {
        if (inner_dim > 0) {
          auto grad_flat = grad.flat_outer_dims<TValue>();
          auto indices_vec = indices.vec<TKey>();
          TStep gs = global_step.scalar<TStep>()();
          for (int64 i = start_i; i < limit_i; i++) {
            const TKey index = indices_vec(i);
            auto var_i = var->flat(index, gs);
            auto m_a = m->flat(index, gs);
            auto v_a = v->flat(index, gs);
            auto g = grad_flat.template chip<0>(i);
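            // Standard Adam updates, written incrementally over the slots:
            //   m <- beta1 * m + (1 - beta1) * g
            //   v <- beta2 * v + (1 - beta2) * g^2
            //   var <- var - alpha * m / (sqrt(v) + epsilon)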
            m_a += (g - m_a) * (static_cast<TValue>(1) - beta1_scalar);
            v_a += (g.square() - v_a) * (static_cast<TValue>(1) - beta2_scalar);
            var_i -= (m_a * alpha) / (v_a.sqrt() + epsilon_scalar);
          }
        }
      };
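      // `cost` is a rough per-row work estimate that Shard uses to decide
      // how many worker threads to spread the N rows across.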
      const int64 cost = 1000;
      auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
      Shard(worker_threads.num_threads, worker_threads.workers, N, cost,
            DoWork);
    }
  }