in tensorflow_recommenders_addons/embedding_variable/python/ops/embedding_variable.py [0:0]
def get_variable(
    name,  # unique
    embedding_dim,
    key_dtype=dtypes.int64,
    value_dtype=dtypes.float32,
    initializer=None,
    regularizer=None,
    reuse=None,
    trainable=True,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    constraint=None):
  """Gets or creates an embedding variable in the current variable scope.

  Arguments left as `None` are filled in from the current variable scope
  (regularizer, caching device, partitioner, constraint, value dtype and, in
  graph mode, the reuse flag) before the request is delegated to an
  `_EmbeddingVariableStore` built on the default variable store.

  Args:
    name: Name of the variable; it is prefixed with the current variable
      scope name (separated by "/") when that name is non-empty.
    embedding_dim: Forwarded to the store as the `shape` argument.
    key_dtype: Dtype of the keys. Only int64 and int32 are accepted.
    value_dtype: Dtype of the values. When `None`, the scope's default dtype
      is used.
    initializer: Initializer for the values. Defaults to
      `init_ops.truncated_normal_initializer()` when `None`. A non-callable
      initializer must convert to a tensor whose base dtype equals
      `value_dtype`.
    regularizer: Regularizer; defaults to the scope's regularizer.
    reuse: Reuse flag. In graph mode `None` falls back to the scope's reuse
      setting; in eager mode the value is always replaced by `AUTO_REUSE`.
    trainable: Forwarded unchanged to the store.
    collections: Forwarded unchanged to the store.
    caching_device: Caching device; defaults to the scope's caching device.
    partitioner: Partitioner; defaults to the scope's partitioner.
    validate_shape: Forwarded unchanged to the store.
    constraint: Constraint; defaults to the scope's constraint.

  Returns:
    Whatever `_EmbeddingVariableStore.get_variable` returns for the resolved
    arguments.

  Raises:
    ValueError: If `key_dtype` is neither int64 nor int32, or if a
      non-callable `initializer` has a dtype different from `value_dtype`.
  """
  if not (key_dtype == dtypes.int64 or key_dtype == dtypes.int32):
    raise ValueError("Not support key_dtype: %s, only support int64/int32" %
                     key_dtype)
  # The explicit default is applied before the scope is consulted, so the
  # scope's own initializer is never used by this function.
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer()

  scope = variable_scope.get_variable_scope()
  scope_store = variable_scope._get_default_variable_store()

  if regularizer is None:
    regularizer = scope._regularizer
  if caching_device is None:
    caching_device = scope._caching_device
  if partitioner is None:
    partitioner = scope._partitioner
  if not context.executing_eagerly():
    if reuse is None:
      reuse = scope._reuse
  else:
    # Eager mode overrides whatever the caller passed.
    reuse = AUTO_REUSE

  full_name = scope.name + "/" + name if scope.name else name
  # Variable names only depend on variable_scope (full_name here),
  # not name_scope, so we reset it below for the time of variable creation.
  with ops.name_scope(None):
    dtype = value_dtype
    # Check that `initializer` dtype and `dtype` are consistent before
    # replacing `dtype` with the scope default. `initializer` is never None
    # at this point, so only the callable check is needed.
    if dtype is not None and not callable(initializer):
      init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
      if init_dtype != dtype:
        raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                         "don't match." % (init_dtype, dtype))
    if constraint is None:
      constraint = scope._constraint
    if dtype is None:
      dtype = scope._dtype

    ev_store = _EmbeddingVariableStore(scope_store)
    # Forward the resolved values: the caller's `validate_shape`, the
    # scope-aware `constraint`, and the scope-aware `dtype`.
    return ev_store.get_variable(full_name,
                                 shape=embedding_dim,
                                 dtype=dtype,
                                 ktype=key_dtype,
                                 initializer=initializer,
                                 regularizer=regularizer,
                                 reuse=reuse,
                                 trainable=trainable,
                                 collections=collections,
                                 caching_device=caching_device,
                                 partitioner=partitioner,
                                 validate_shape=validate_shape,
                                 constraint=constraint)