def _get_single_variable()

in tensorflow_recommenders_addons/embedding_variable/python/ops/embedding_variable.py [0:0]


  def _get_single_variable(self,
                           var_store,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           ktype=dtypes.int64,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single EmbeddingVariable (a shard or a whole variable).

    Follows the `get_variable` contract, ignoring partitioning. If `name` is
    already present in `var_store._vars`, that variable is returned after its
    shape and dtype are checked for compatibility. Otherwise a new
    `ev.EmbeddingVariable` is created. It is registered in the store and, when
    requested, in collections, and a regularization loss may be attached.

    Args:
      var_store: store object; its `_vars` mapping holds variables by name and
        its `_store_eager_variables` flag decides whether variables created
        eagerly are kept in that mapping and in collections.
      name: see get_variable.
      shape: see get_variable. Converted with `tensor_shape.as_shape` and
        forwarded to the new variable as `embedding_dim`.
      dtype: see get_variable. Forwarded to the new variable as `vtype`.
      ktype: key dtype forwarded to the new variable as `ktype`.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object. Only bound into a shape-specific
        initializer that is not itself forwarded (see the note below).
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable. Not used by this implementation.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      The already-registered variable named `name`, or the newly created one.

    Raises:
      ValueError: if a constant initializer is given together with `shape`;
        if the variable exists and `reuse is False`; if the variable exists
        with an incompatible shape or dtype; if it does not exist and
        `reuse is True`; or if the initializer cannot be used with `shape`.
    """
    # Anything non-callable passed as `initializer` is taken as the value.
    initializing_from_value = (initializer is not None and
                               not callable(initializer))
    if initializing_from_value and shape is not None:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    # ---- Reuse path: hand back the variable already in the store. ----
    if name in var_store._vars:
      found_var = var_store._vars[name]
      if reuse is False:
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables have no associated op, hence no creation traceback.
        if isinstance(found_var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        # Newest frames first; skip TF-internal frames and keep at most five.
        user_frames = [
            frame for frame in found_var.op.traceback[::-1]
            if "tensorflow/python" not in frame[0]
        ][:5]
        raise ValueError("%s Originally defined at:\n\n%s" %
                         (err_msg,
                          "".join(traceback.format_list(user_frames))))

      found_shape = found_var.get_shape()
      if not shape.is_compatible_with(found_shape):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." % (name, shape, found_shape))
      if not dtype.is_compatible_with(found_var.dtype):
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype.name, found_var.dtype.name))
      return found_var

    # ---- Creation path from here on. ----
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)

    # Resolve/validate the initializer inside an init scope.
    # NOTE(review): the shape-bound `init_val` built here is never passed to
    # EmbeddingVariable (the raw `initializer` is), so this section only
    # instantiates initializer classes and rejects unusable initializers.
    # Confirm whether forwarding `init_val` / `partition_info` was intended.
    with ops.init_scope():
      if not initializing_from_value:
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          bind_kwargs = {"dtype": dtype}
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            bind_kwargs["partition_info"] = partition_info
          init_val = functools.partial(initializer, shape.as_list(),
                                       **bind_kwargs)
        elif not _needs_no_arguments(initializer):
          raise ValueError("The initializer passed is not valid. It should "
                           "be a callable with no arguments and the "
                           "shape should not be provided or an instance of "
                           "`tf.keras.initializers.*' and `shape` should be "
                           "fully defined.")

    # `validate_shape` is intentionally not forwarded to EmbeddingVariable.
    v = ev.EmbeddingVariable(
        embedding_dim=shape,
        initializer=initializer,
        trainable=trainable,
        collections=collections,
        name=name,
        ktype=ktype,
        vtype=dtype,
        caching_device=caching_device,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation)

    # Graph mode always keeps a reference in the store. Eager mode does so only
    # when the store opts in, because holding a reference prevents the
    # variable's memory from being released.
    if not context.executing_eagerly():
      var_store._vars[name] = v
    elif var_store._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
      var_store._vars[name] = v

    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Attach a lazily-evaluated regularization loss when one is requested.
    if regularizer:

      def make_regularizer_op():
        with ops.colocate_with(v), ops.name_scope(name + "/Regularizer/"):
          return regularizer(v)

      # NOTE(review): the regularizer is called once here, outside the
      # colocation/name scope, only to test for None; the recorded loss comes
      # from a second, lazy call inside make_regularizer_op.
      if regularizer(v) is not None:
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                              _LazyEvalTensor(make_regularizer_op))

    return v