in easy_rec/python/compat/dynamic_variable.py [0:0]
def __init__(self,
             dimension,
             initializer=None,
             var_type=None,
             name=None,
             constraint=None,
             trainable=True,
             key_type=None,
             dtype=None,
             mode=None,
             variable_def=None,
             import_scope=None,
             **kwargs):
  self._indices = None
  if variable_def is not None:
    # Restore an existing variable from a serialized VariableDef
    # (graph-import path), recovering the embedding dimension, key type
    # and dtype from the imported handle op's attributes.
    super(DynamicVariable, self)._init_from_proto(
        variable_def, import_scope=import_scope, validate_shape=False)
    g = ops.get_default_graph()
    handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope),
        allow_operation=False)
    self._dimension = handle.op.get_attr('shape').dim[-1].size
    self._key_type = handle.op.get_attr('key_type')
    self._handle_dtype = handle.op.get_attr('dtype')
    self._mode = None
    self._config = {}
    self._name = variable_def.variable_name.split(':')[0]
    self._trainable = variable_def.trainable
    self._dummy_handle = handle
    self._handle = handle
    # The initializer op carries the TF-side init as its first control
    # input; recover the TF handle from it.
    init_op = g.as_graph_element(variable_def.initializer_name)
    self._initializer_op = init_op
    init_tf = init_op.control_inputs[0]
    # init_dummy = init_op.control_inputs[1]
    self._tf_handle = init_tf.inputs[0]
    return
  # Fresh-construction path: record configuration and pick defaults.
  self._key_type = key_type if key_type is not None else tf.int64
  self._handle_dtype = dtype if dtype is not None else tf.float32
  self._dimension = dimension
  self._mode = mode
  self._config = json.dumps(kwargs)
  self._config_dict = kwargs
  if var_type == 'hybrid' and self._key_type != tf.int64:
    raise NotImplementedError(
        'only key_type tf.int64 is supported in HKV backend')
  if name is None:
    global dynamic_variable_count
    name = 'sok_dynamic_Variable_' + str(dynamic_variable_count)
    dynamic_variable_count += 1
  var_type = 'hbm' if var_type is None else var_type
  self._var_type = var_type
  # Initialize the base ResourceVariable as a [1, dimension] proxy; the
  # actual key/value storage lives behind the dummy handle created below.
  self._base = super(DynamicVariable, self)
  self._base.__init__(
      initial_value=[[0.0] * dimension],
      trainable=trainable,
      name=name + '/proxy',
      dtype=self._handle_dtype,
      constraint=constraint,
      distribute_strategy=None,
      synchronization=None,
      aggregation=None,
      shape=[None, dimension],
  )
  with ops.init_scope():
    with ops.name_scope(name) as name_scope:
      self._dummy_name = ops.name_from_scope_name(name_scope)
      if context.executing_eagerly():
        self._dummy_name = '%s_%d' % (name, ops.uid())
      with ops.NullContextmanager():
        shape = [None, dimension]
        initializer = '' if initializer is None else initializer
        self._initializer = initializer
        # Create the SOK-backed dummy resource handle for the dynamic
        # embedding table.
        handle = dynamic_variable_ops.dummy_var_handle(
            container='DummyVariableContainer',
            shared_name=self._dummy_name,
            key_type=self._key_type,
            dtype=self._handle_dtype,
            shape=shape,
        )
        if isinstance(initializer, str):
          # A string initializer (e.g. 'uniform') is interpreted by the
          # kernel itself.
          init_op = dynamic_variable_ops.dummy_var_initialize(
              handle,
              initializer=initializer,
              var_type=var_type,
              unique_name=self._dummy_name,
              key_type=self._key_type,
              dtype=self._handle_dtype,
              config=self._config,
          )
        else:
          # Otherwise the initializer is a variable: read its value once
          # it has itself been initialized.
          with tf.control_dependencies([initializer._initializer_op]):
            initial_val = initializer.read_value()
            init_op = dynamic_variable_ops.dummy_var_initialize(
                handle,
                initializer=initial_val,
                var_type=var_type,
                unique_name=self._dummy_name,
                key_type=self._key_type,
                dtype=self._handle_dtype,
                config=self._config,
            )
        # TODO: Add is_initialized_op
        # is_initialized_op = ops.convert_to_tensor(True)

        # Keep TF's own handle, then make the dummy handle the default so
        # resource ops route to the SOK table.
        self._tf_handle = self._handle
        self._dummy_handle = handle
        # Note that the default handle will be sok's handle
        self._handle = self._dummy_handle
        # Group the proxy variable's initializer with the dummy-table init.
        self._initializer_op = tf.group([self._initializer_op, init_op])
        # self._is_initialized_op = tf.group(
        #     [self._is_initialized_op, is_initialized_op])
  # Attach shape/dtype metadata to the handle so shape inference on
  # downstream resource ops sees a [None, dimension] tensor of the
  # variable's dtype.
  handle_data = (
      resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
      .HandleData())
  handle_data.is_set = True
  handle_data.shape_and_type.append(
      resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
      .HandleShapeAndType(
          shape=self.shape.as_proto(), dtype=self.dtype.as_datatype_enum))
  resource_variable_ops._set_handle_shapes_and_types(
      self._handle,
      handle_data,
      graph_mode=not context.executing_eagerly())
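
# Usage sketch (illustrative, not part of the original file): how a caller
# might construct a DynamicVariable. Only the keyword names come from the
# signature above; the argument values are assumptions for the example.
#
#   import tensorflow as tf
#
#   # 16-dim dynamic embedding table with int64 keys on the HBM backend
#   # (the default when var_type is None); the string initializer is
#   # interpreted by the dummy_var_initialize kernel.
#   emb = DynamicVariable(
#       dimension=16,
#       initializer='uniform',
#       var_type='hbm',
#       key_type=tf.int64,
#       dtype=tf.float32,
#       name='item_embedding')
#
#   # Passing a tf.Variable instead of a string routes through the
#   # read_value() branch above.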