def _create_slot_var()

in tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py


# Module aliases assumed from the surrounding file (tf_patch.py):
from tensorflow.python.ops import resource_variable_ops as rvo
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
# `ev` is the module providing EmbeddingVariable (available in TensorFlow
# forks such as DeepRec); its exact import path is not shown in this excerpt.


def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from `val` instead of a callable initializer, the shape
  # is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
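  # Match the slot's variable type to the primary's (resource vs. ref
  # variable); use_resource=None lets get_variable fall back to the
  # enclosing scope's default.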
  if rvo.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None
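  # For EmbeddingVariable primaries, create the slot through ev.get_variable;
  # shape[1:] drops the leading dimension so embedding_dim covers only the
  # trailing embedding dimensions.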
  if isinstance(primary, ev.EmbeddingVariable):
    slot = ev.get_variable(scope,
                           embedding_dim=shape[1:],
                           initializer=val,
                           trainable=False,
                           key_dtype=primary._ktype,
                           value_dtype=primary.dtype)
  else:
    slot = variable_scope.get_variable(scope,
                                       initializer=val,
                                       trainable=False,
                                       use_resource=use_resource,
                                       shape=shape,
                                       dtype=dtype,
                                       validate_shape=validate_shape)
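  # Restore the caller's partitioner now that the slot has been created.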
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want 'Adam' as real_slot_name, so we
    # remove "'linear//weights' + '/'" and ':0'.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # Support a slot shape that differs from the primary's shape.
    # Example: if the primary's shape is [10, 20, 30], the slot's shape may be
    # None, [], [10], [10, 20] or [10, 20, 30].
    # If the slot's shape is None or [10, 20, 30], give the slot the same
    # slice_info as the primary; if it is [], set no slice_info; if it is
    # [10] or [10, 20], truncate the slice_info to the slot's ndims.
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from primary.
  if copy_xla_sharding:
    try:
      from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
      slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
    except ImportError:
      tf_logging.warn("xla_sharding could not be imported; it may be "
                      "unavailable in TensorFlow versions < 2.5.")
  return slot
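
Since this helper lives in tf_patch.py, it is meant to stand in for TensorFlow's
built-in slot creator. A minimal sketch of how such a patch could be applied is
below; the slot_creator module path is real, but wiring the override this way is
an illustration, not necessarily how tf_patch.py installs it:

from tensorflow.python.training import slot_creator

# Assumption: replace TF's private helper so optimizers route slot creation
# for EmbeddingVariable primaries through the patched function above.
slot_creator._create_slot_var = _create_slot_var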