void Compute()

in tensorflow_recommenders_addons/embedding_variable/core/kernels/ev_ops.cc [410:518]
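
This kernel applies a sparse Adam update to a TFRA EmbeddingVariable and its two slot variables (the first and second moments). It validates the scalar hyperparameters and input shapes, precomputes the bias-corrected step size, and shards the per-index row updates across the CPU intra-op thread pool.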


  void Compute(OpKernelContext* ctx) {
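    // Serialize access to inputs 0-2 (the variable and both Adam slots); the
    // helper acquires their mutexes in a consistent order to avoid deadlock.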
    auto locks = MaybeLockEmbeddingVariableInputMutexesInOrder<TKey, TValue>(
        ctx, use_exclusive_lock_, {0, 1, 2});
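    // Resolve the EmbeddingVariable resources: the trainable table (var) and
    // its first-moment (m) and second-moment (v) Adam slots.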
    EmbeddingVar<TKey, TValue>* var = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
    // LookupResource returns an owned reference; release it on every exit
    // path so the resource is not leaked.
    core::ScopedUnref unref_var(var);

    EmbeddingVar<TKey, TValue>* m = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &m));
    core::ScopedUnref unref_m(m);

    EmbeddingVar<TKey, TValue>* v = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &v));
    core::ScopedUnref unref_v(v);

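    // The Adam hyperparameters, the sparse gradient, and the training step
    // arrive as dense inputs 3..11.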
    const Tensor& beta1_power = ctx->input(3);
    const Tensor& beta2_power = ctx->input(4);
    const Tensor& lr = ctx->input(5);
    const Tensor& beta1 = ctx->input(6);
    const Tensor& beta2 = ctx->input(7);
    const Tensor& epsilon = ctx->input(8);
    const Tensor& grad = ctx->input(9);
    const Tensor& indices = ctx->input(10);
    const Tensor& global_step = ctx->input(11);

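    // Every hyperparameter must be a scalar and indices a 1-D vector.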
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
                errors::InvalidArgument("beta2_power is not a scalar: ",
                                        beta2_power.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

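    // grad must match the variable's per-key value shape past dimension 0;
    // inner_dim accumulates the flattened width of one embedding row.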
    int64 inner_dim = 1;
    TensorShape var_shape({var->Size(), var->ValueLen()});
    for (int d = 1; d < var_shape.dims(); d++) {
      OP_REQUIRES(ctx, var_shape.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(
                      "var and grad must match in dimension ", d));
      inner_dim *= grad.dim_size(d);
    }
    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
                errors::InvalidArgument("global_step is not a scalar: ",
                                        global_step.shape().DebugString()));

    const int64 N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    if (N > 0) {
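      // Materialize the scalar hyperparameters once, then fold both bias
      // corrections into a single effective step size:
      //   alpha = lr * sqrt(1 - beta2^t) / (1 - beta1^t)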
      TValue beta1_power_scalar = beta1_power.scalar<TValue>()();
      TValue beta2_power_scalar = beta2_power.scalar<TValue>()();
      TValue lr_scalar = lr.scalar<TValue>()();
      TValue beta1_scalar = beta1.scalar<TValue>()();
      TValue beta2_scalar = beta2.scalar<TValue>()();
      TValue epsilon_scalar = epsilon.scalar<TValue>()();
      const TValue alpha =
          lr_scalar *
          Eigen::numext::sqrt(static_cast<TValue>(1) - beta2_power_scalar) /
          (static_cast<TValue>(1) - beta1_power_scalar);

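      // Per-shard worker: applies the Adam update in place to every row
      // addressed by indices in [start_i, limit_i).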
      auto DoWork = [inner_dim, &var, &m, &v, &grad, &indices, &beta1_scalar,
                     &beta2_scalar, &epsilon_scalar, &alpha,
                     &global_step](int64 start_i, int64 limit_i) {
        if (inner_dim > 0) {
          auto grad_flat = grad.flat_outer_dims<TValue>();
          auto indices_vec = indices.vec<TKey>();

          TStep gs = global_step.scalar<TStep>()();

          for (int64 i = start_i; i < limit_i; i++) {
            const TKey index = indices_vec(i);

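            // Flat views of the embedding row and its two moment rows for
            // this key; the global step is forwarded for the variable's
            // step-aware lookup.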
            auto var_i = var->flat(index, gs);
            auto m_a = m->flat(index, gs);
            auto v_a = v->flat(index, gs);

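            // Standard Adam recurrences, evaluated element-wise in place:
            //   m   <- beta1 * m + (1 - beta1) * g
            //   v   <- beta2 * v + (1 - beta2) * g^2
            //   var <- var - alpha * m / (sqrt(v) + epsilon)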
            auto g = grad_flat.template chip<0>(i);
            m_a += (g - m_a) * (static_cast<TValue>(1) - beta1_scalar);
            v_a += (g.square() - v_a) * (static_cast<TValue>(1) - beta2_scalar);
            var_i -= (m_a * alpha) / (v_a.sqrt() + epsilon_scalar);
          }
        }
      };

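      // Shard the N per-index updates across the intra-op thread pool; cost
      // is a rough estimate of the work per index, used to size the shards.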
      const int64 cost = 1000;
      auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
      Shard(worker_threads.num_threads, worker_threads.workers, N, cost,
            DoWork);
    }
  }
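
The per-row arithmetic above is the textbook bias-corrected Adam step in fused form. As a point of reference, here is a minimal sketch of the same update on plain arrays, independent of TensorFlow; the function name is illustrative and TValue is assumed to be float.

// Minimal sketch of the fused Adam row update (assumes TValue = float; the
// name AdamRowUpdate is illustrative, not part of the kernel).
#include <cmath>
#include <cstddef>

// One bias-corrected Adam step for a single embedding row of width `dim`,
// mirroring the arithmetic inside DoWork above.
void AdamRowUpdate(float* var, float* m, float* v, const float* g,
                   std::size_t dim, float lr, float beta1, float beta2,
                   float epsilon, float beta1_power, float beta2_power) {
  // The same fused step size the kernel precomputes as `alpha`.
  const float alpha =
      lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (std::size_t j = 0; j < dim; ++j) {
    m[j] += (g[j] - m[j]) * (1.0f - beta1);         // m = b1*m + (1-b1)*g
    v[j] += (g[j] * g[j] - v[j]) * (1.0f - beta2);  // v = b2*v + (1-b2)*g^2
    var[j] -= m[j] * alpha / (std::sqrt(v[j]) + epsilon);
  }
}

Note that epsilon is added to the uncorrected sqrt(v) after the bias corrections are folded into alpha, the same convention as TensorFlow's dense ApplyAdam kernels (the paper's "epsilon hat").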