# tensorflow_recommenders_addons/embedding_variable/python/ops/embedding_variable.py
def _get_single_variable(self,
                         var_store,
                         name,
                         shape=None,
                         dtype=dtypes.float32,
                         ktype=dtypes.int64,
                         initializer=None,
                         regularizer=None,
                         partition_info=None,
                         reuse=None,
                         trainable=None,
                         collections=None,
                         caching_device=None,
                         validate_shape=True,
                         constraint=None,
                         synchronization=VariableSynchronization.AUTO,
                         aggregation=VariableAggregation.NONE):
"""Get or create a single Variable (e.g.
a shard or entire variable).
See the documentation of get_variable above (ignore partitioning components)
for details.
Args:
name: see get_variable.
shape: see get_variable.
dtype: see get_variable.
initializer: see get_variable.
regularizer: see get_variable.
partition_info: _PartitionInfo object.
reuse: see get_variable.
trainable: see get_variable.
collections: see get_variable.
caching_device: see get_variable.
validate_shape: see get_variable.
constraint: see get_variable.
synchronization: see get_variable.
aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
Raises:
ValueError: See documentation of get_variable above.
"""
  # Set to true if initializer is a constant.
  initializing_from_value = False
  if initializer is not None and not callable(initializer):
    initializing_from_value = True
  if shape is not None and initializing_from_value:
    raise ValueError("If initializer is a constant, do not specify shape.")

  dtype = dtypes.as_dtype(dtype)
  shape = tensor_shape.as_shape(shape)
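
  # Either return an existing variable under `name` (when `reuse` permits)
  # or fall through to create a new EmbeddingVariable below.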
  if name in var_store._vars:
    # Here we handle the case when returning an existing variable.
    if reuse is False:
      var = var_store._vars[name]
      err_msg = ("Variable %s already exists, disallowed."
                 " Did you mean to set reuse=True or "
                 "reuse=tf.AUTO_REUSE in VarScope?" % name)
      # ResourceVariables don't have an op associated with them, so no
      # traceback is available.
      if isinstance(var, resource_variable_ops.ResourceVariable):
        raise ValueError(err_msg)
      tb = var.op.traceback[::-1]
      # Throw away internal tf entries and only take a few lines. In some
      # cases the traceback can be longer (e.g. if someone uses factory
      # functions to create variables) so we take more than needed in the
      # default case.
      tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
      raise ValueError("%s Originally defined at:\n\n%s" %
                       (err_msg, "".join(traceback.format_list(tb))))
    found_var = var_store._vars[name]
    if not shape.is_compatible_with(found_var.get_shape()):
      raise ValueError("Trying to share variable %s, but specified shape %s"
                       " and found shape %s." %
                       (name, shape, found_var.get_shape()))
    if not dtype.is_compatible_with(found_var.dtype):
      dtype_str = dtype.name
      found_type_str = found_var.dtype.name
      raise ValueError("Trying to share variable %s, but specified dtype %s"
                       " and found dtype %s." %
                       (name, dtype_str, found_type_str))
    return found_var
  # The code below handles only the case of creating a new variable.
  if reuse is True:
    raise ValueError("Variable %s does not exist, or was not created with "
                     "tf.get_variable(). Did you mean to set "
                     "reuse=tf.AUTO_REUSE in VarScope?" % name)

  # Create the tensor to initialize the variable with default value.
  if initializer is None:
    initializer, initializing_from_value = self._get_default_initializer(
        name=name, shape=shape, dtype=dtype)
  # Enter an init scope when creating the initializer.
  with ops.init_scope():
    if initializing_from_value:
      init_val = initializer
      variable_dtype = None
    else:
      # Instantiate initializer if provided initializer is a type object.
      if tf_inspect.isclass(initializer):
        initializer = initializer()
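      # Bind shape and dtype into the initializer now so it can later be
      # called with no arguments; partition_info is forwarded only to
      # initializers whose signature accepts it.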
      if shape.is_fully_defined():
        if "partition_info" in tf_inspect.getargspec(initializer).args:
          init_val = functools.partial(initializer,
                                       shape.as_list(),
                                       dtype=dtype,
                                       partition_info=partition_info)
        else:
          init_val = functools.partial(initializer,
                                       shape.as_list(),
                                       dtype=dtype)
        variable_dtype = dtype.base_dtype
      elif _needs_no_arguments(initializer):
        init_val = initializer
        variable_dtype = None
      else:
        raise ValueError("The initializer passed is not valid. It should "
                         "be a callable with no arguments and the "
                         "shape should not be provided or an instance of "
                         "`tf.keras.initializers.*` and `shape` should be "
                         "fully defined.")
  # Create the variable.
  v = ev.EmbeddingVariable(
      embedding_dim=shape,
      initializer=initializer,
      trainable=trainable,
      collections=collections,
      name=name,
      ktype=ktype,
      vtype=dtype,
      #variable_def=None,
      #import_scope=None,
      #distribute_strategy=None,
      caching_device=caching_device,
      #validate_shape=validate_shape,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation)
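
  # When the store keeps eager variables, mirror graph-mode collection
  # bookkeeping so the variable can still be discovered through
  # GLOBAL_VARIABLES / TRAINABLE_VARIABLES.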
  if context.executing_eagerly() and var_store._store_eager_variables:
    if collections:
      ops.add_to_collections(collections, v)
    else:
      ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
    if trainable:
      ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

  if not context.executing_eagerly() or var_store._store_eager_variables:
    # In eager mode we do not want to keep default references to Variable
    # objects as this will prevent their memory from being released.
    var_store._vars[name] = v
  logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
               format(shape), initializer)
  # Run the regularizer if requested and save the resulting loss.
  if regularizer:

    def make_regularizer_op():
      with ops.colocate_with(v):
        with ops.name_scope(name + "/Regularizer/"):
          return regularizer(v)

    if regularizer(v) is not None:
      lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
      ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                            lazy_eval_tensor)

  return v
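

# A minimal usage sketch, kept in comments because the entry-point name is
# hypothetical: this method is normally reached through a TFRA
# get_variable-style wrapper rather than called directly. `shape` is the
# per-key embedding dimension and `ktype` is the key dtype of the dynamic
# table.
#
#   import tensorflow as tf
#
#   with tf.compat.v1.variable_scope("ranking",
#                                    reuse=tf.compat.v1.AUTO_REUSE):
#     emb = get_embedding_variable(  # hypothetical wrapper, not this module's API
#         "user_emb",
#         shape=[64],
#         ktype=tf.int64,
#         dtype=tf.float32,
#         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))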