easy_rec/python/compat/dynamic_variable.py (340 lines of code) (raw):
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import tensorflow as tf
from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
from sparse_operation_kit.experiment.communication import num_gpus
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
# from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.resource_variable_ops import ResourceVariable
from tensorflow.python.ops.resource_variable_ops import variable_accessed
# from tensorflow.python.util import object_identity
# Stock ``ResourceVariable.from_proto``, captured before this module replaces
# it (see the assignment after the DynamicVariable class).
# ``DynamicVariable.from_proto`` falls back to this for ordinary variables.
_resource_var_from_proto = ResourceVariable.from_proto
# Counter used to build unique default names for unnamed DynamicVariables.
dynamic_variable_count = 0
class DynamicVariable(ResourceVariable):
  """Abbreviated as ``sok.experiment.DynamicVariable``.

  A variable that allocates memory dynamically.

  Internally the object is a ``ResourceVariable`` "proxy" (named
  ``<name>/proxy``, shape ``[None, dimension]``) paired with a SOK dummy
  variable handle created by ``dummy_var_handle``. The active handle decides
  the mode:

  * dynamic mode (the default after construction): the active handle is the
    SOK dummy handle. Only ``sparse_read``, ``scatter_add``, ``scatter_sub``
    and ``scatter_update`` are supported. Most dense ``ResourceVariable``
    operations raise ``NotImplementedError``.
  * static mode (entered with ``to_static``): the active handle is the proxy
    ``ResourceVariable`` handle, which holds the rows read for a given set of
    indices. The regular ``ResourceVariable`` API is delegated to.
    ``to_dynamic`` writes those rows back and returns to dynamic mode.

  Parameters
  ----------
  dimension: int
      The last dimension of this variable (that is, the embedding vector
      size of the embedding table).
  initializer: string or variable, optional
      How to initialize this variable. A string is forwarded to the backend.
      Currently only "random" or the string form of a float value (meaning a
      constant initializer) are supported. If omitted, an empty string is
      forwarded, which the original documentation describes as the
      "random" default. A non-string value must expose ``_initializer_op``
      and ``read_value()`` (for example another variable). Its value is
      read after its initializer has run and used as the initial value.
  var_type: string, optional
      Selects the backend. Defaults to ``'hbm'``. ``'hybrid'`` selects the
      HKV backend, which only supports tf.int64 as key_type. When using HKV,
      set the init_capacity and max_capacity values (presumably passed via
      ``**kwargs``) to powers of 2.
  name: string, optional
      Base name of the variable. Defaults to ``sok_dynamic_Variable_<n>``
      with a module-wide counter ``n``.
  constraint, trainable:
      Forwarded to the proxy ``ResourceVariable``.
  key_type: dtype, optional
      The data type of indices. Unlike a static tensorflow variable, this
      variable is allocated dynamically and contains a hash table inside it,
      so the index type must be specified to construct the hash table.
      Defaults to tf.int64.
  dtype: dtype, optional
      The data type of values. Defaults to tf.float32.
  mode: optional
      Stored as-is and exposed through the ``mode`` property.
  variable_def, import_scope: optional
      When ``variable_def`` is given, the variable is rebuilt from an
      existing graph instead of being created (see ``from_proto``).
  **kwargs:
      Backend configuration. Serialized to JSON and passed to the initializer
      op. The dict is also kept as ``config_dict``.

  Example
  -------
  .. code-block:: python

      import numpy as np
      import tensorflow as tf
      import horovod.tensorflow as hvd
      from sparse_operation_kit import experiment as sok

      v = sok.DynamicVariable(dimension=3, initializer="13")
      print("v.shape:", v.shape)
      print("v.size:", v.size)

      indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
      embedding = tf.nn.embedding_lookup(v, indices)
      print("embedding:", embedding)
      print("v.shape:", v.shape)
      print("v.size:", v.size)
  """

  def __init__(self,
               dimension,
               initializer=None,
               var_type=None,
               name=None,
               constraint=None,
               trainable=True,
               key_type=None,
               dtype=None,
               mode=None,
               variable_def=None,
               import_scope=None,
               **kwargs):
    """Create a new dynamic variable, or restore one from ``variable_def``.

    See the class docstring for the meaning of the arguments.

    Raises
    ------
    NotImplementedError
        If ``var_type == 'hybrid'`` and ``key_type`` is not tf.int64.
    """
    self._indices = None
    if variable_def is not None:
      # Restore path: let the base class rebuild the proxy ResourceVariable,
      # then recover the SOK-specific attributes from the DummyVarHandle op.
      super(DynamicVariable, self)._init_from_proto(
          variable_def, import_scope=import_scope, validate_shape=False)
      self._base = super(DynamicVariable, self)
      g = ops.get_default_graph()
      handle = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.variable_name, import_scope=import_scope),
          allow_operation=False)
      self._dimension = handle.op.get_attr('shape').dim[-1].size
      self._key_type = handle.op.get_attr('key_type')
      # Must be ``_handle_dtype``: ``handle_dtype`` and ``size`` read it.
      self._handle_dtype = handle.op.get_attr('dtype')
      self._mode = None
      self._config = {}
      # These are not recovered from the graph here. Set them so the
      # corresponding properties and __repr__ do not raise AttributeError.
      self._config_dict = {}
      self._var_type = None
      self._initializer = None
      self._name = variable_def.variable_name.split(':')[0]
      self._dummy_name = self._name
      self._trainable = variable_def.trainable
      self._dummy_handle = handle
      self._handle = handle
      init_op = g.as_graph_element(variable_def.initializer_name)
      self._initializer_op = init_op
      # Assumes the first control input of the grouped initializer is the
      # proxy variable's initializer, whose first input is the proxy handle.
      # TODO(review): confirm against the graph produced by the create path.
      init_tf = init_op.control_inputs[0]
      self._tf_handle = init_tf.inputs[0]
      return

    self._key_type = key_type if key_type is not None else tf.int64
    self._handle_dtype = dtype if dtype is not None else tf.float32
    self._dimension = dimension
    self._mode = mode
    # Extra keyword arguments are the backend config. The initializer op
    # receives them as a JSON string; the raw dict is kept for config_dict.
    self._config = json.dumps(kwargs)
    self._config_dict = kwargs
    if var_type == 'hybrid' and self._key_type != tf.int64:
      raise NotImplementedError(
          'only key_type tf.int64 is supported in HKV backend')
    if name is None:
      global dynamic_variable_count
      name = 'sok_dynamic_Variable_' + str(dynamic_variable_count)
      dynamic_variable_count += 1
    var_type = 'hbm' if var_type is None else var_type
    self._var_type = var_type
    self._base = super(DynamicVariable, self)
    # The proxy ResourceVariable, shape [None, dimension]. Its handle is the
    # active handle while in static mode.
    self._base.__init__(
        initial_value=[[0.0] * dimension],
        trainable=trainable,
        name=name + '/proxy',
        dtype=self._handle_dtype,
        constraint=constraint,
        distribute_strategy=None,
        synchronization=None,
        aggregation=None,
        shape=[None, dimension],
    )
    with ops.init_scope():
      with ops.name_scope(name) as name_scope:
        self._dummy_name = ops.name_from_scope_name(name_scope)
        if context.executing_eagerly():
          # Append a uid so the shared_name differs per instance in eager mode.
          self._dummy_name = '%s_%d' % (name, ops.uid())
        with ops.NullContextmanager():
          shape = [None, dimension]
          initializer = '' if initializer is None else initializer
          self._initializer = initializer
          handle = dynamic_variable_ops.dummy_var_handle(
              container='DummyVariableContainer',
              shared_name=self._dummy_name,
              key_type=self._key_type,
              dtype=self._handle_dtype,
              shape=shape,
          )
          if isinstance(initializer, str):
            init_op = dynamic_variable_ops.dummy_var_initialize(
                handle,
                initializer=initializer,
                var_type=var_type,
                unique_name=self._dummy_name,
                key_type=self._key_type,
                dtype=self._handle_dtype,
                config=self._config,
            )
          else:
            # Read the source variable only after its own initializer ran.
            with tf.control_dependencies([initializer._initializer_op]):
              initial_val = initializer.read_value()
              init_op = dynamic_variable_ops.dummy_var_initialize(
                  handle,
                  initializer=initial_val,
                  var_type=var_type,
                  unique_name=self._dummy_name,
                  key_type=self._key_type,
                  dtype=self._handle_dtype,
                  config=self._config,
              )
          # TODO: Add is_initialized_op
          self._tf_handle = self._handle
          self._dummy_handle = handle
          # The default (active) handle is SOK's dummy handle, i.e. the
          # variable starts in dynamic mode.
          self._handle = self._dummy_handle
          self._initializer_op = tf.group([self._initializer_op, init_op])
      # Attach shape/dtype metadata of the proxy to the active handle.
      handle_data = (
          resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
          .HandleData())
      handle_data.is_set = True
      handle_data.shape_and_type.append(
          resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
          .HandleShapeAndType(
              shape=self.shape.as_proto(), dtype=self.dtype.as_datatype_enum))
      resource_variable_ops._set_handle_shapes_and_types(
          self._handle,
          handle_data,
          graph_mode=False if context.executing_eagerly() else True)

  def is_static(self):
    """Return True if the active handle is the proxy ResourceVariable handle."""
    return self._handle is self._tf_handle

  def to_static(self, indices, lookup_only=False):
    """Switch to static mode, copying the rows for ``indices`` into the proxy.

    Parameters
    ----------
    indices: tf.Tensor
        Keys whose rows are read from the dynamic table. They are remembered
        so that ``to_dynamic`` can write the rows back.
    lookup_only: bool
        Forwarded to ``sparse_read`` (and from there to the backend read op).

    Returns
    -------
    The result of assigning the read rows to the proxy variable.

    Raises
    ------
    RuntimeError
        If the variable is already static or still holds indices from a
        previous ``to_static`` call.
    """
    if self.is_static() or self._indices is not None:
      raise RuntimeError('to_static() must be called in dynamic mode.')
    # Pass lookup_only by keyword: the second positional slot of
    # sparse_read is ``name``.
    buffer = self.sparse_read(indices, lookup_only=lookup_only)
    self._indices = indices
    self._handle = self._tf_handle
    return self.assign(buffer)

  def to_dynamic(self):
    """Write the static rows back to the dynamic table and leave static mode.

    Returns
    -------
    The op returned by ``scatter_update`` on the dummy variable.

    Raises
    ------
    RuntimeError
        If the variable is not in static mode.
    """
    if not self.is_static():
      raise RuntimeError('to_dynamic() must be called in static mode.')
    buffer = self.read_value()
    sparse_delta = ops.IndexedSlices(buffer, self._indices, self.shape)
    self._indices = None
    self._handle = self._dummy_handle
    return self.scatter_update(sparse_delta)

  @property
  def name(self):
    """Name of the SOK dummy handle tensor."""
    return self._dummy_handle.name

  def __repr__(self):
    if self.is_static():
      return self._base.__repr__()
    return "<sok.DynamicVariable '%s' shape=%s dtype=%s>" % (
        self._dummy_name,
        self.shape,
        self.dtype.name,
    )

  @property
  def size(self):
    """Shape tensor reported by the backend for the dummy variable.

    Presumably ``[num_rows, dimension]``; TODO confirm against the SOK op.
    """
    return dynamic_variable_ops.dummy_var_shape(
        self._dummy_handle, key_type=self._key_type, dtype=self._handle_dtype)

  @property
  def indices(self):
    """Indices captured by ``to_static`` (None while in dynamic mode)."""
    return self._indices

  @property
  def dimension(self):
    """Embedding vector size (last dimension)."""
    return self._dimension

  def get_shape(self):
    """Return ``[dimension]`` as a Python list (not the proxy's shape)."""
    return [self._dimension]

  @property
  def key_type(self):
    """Data type of the keys/indices."""
    return self._key_type

  @property
  def handle_dtype(self):
    """Data type of the stored values."""
    return self._handle_dtype

  @property
  def backend_type(self):
    """Backend name (``var_type``); None for variables restored from a proto."""
    return self._var_type

  @property
  def config_dict(self):
    """Extra constructor keyword arguments used as backend config."""
    return self._config_dict

  @property
  def mode(self):
    return self._mode

  @property
  def num_gpus(self):
    """Number of GPUs reported by SOK's communication module."""
    return num_gpus()

  @property
  def initializer_str(self):
    """Initializer given at construction ('' when none was passed)."""
    return self._initializer

  def key_map(self, indices):
    """Identity mapping of indices (hook for subclasses)."""
    return indices

  # -------------------------------------------------------------------------
  # Methods supported both in static mode and dynamic mode
  # -------------------------------------------------------------------------

  def sparse_read(self, indices, name=None, lookup_only=False):
    """Gather rows for ``indices``.

    In static mode this delegates to ``ResourceVariable.sparse_read``. In
    dynamic mode int32 indices are cast to int64, and the read goes through
    the SOK dummy handle with ``lookup_only`` forwarded to the backend op.
    """
    if self.is_static():
      return self._base.sparse_read(indices, name)
    variable_accessed(self)
    if indices.dtype == tf.int32:
      indices = tf.cast(indices, tf.int64)
    return dynamic_variable_ops.dummy_var_sparse_read(
        self._dummy_handle,
        indices,
        dtype=self._handle_dtype,
        lookup_only=lookup_only)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Subtract ``sparse_delta`` (an IndexedSlices) from the variable.

    Raises TypeError in dynamic mode if ``sparse_delta`` is not IndexedSlices.
    """
    if self.is_static():
      return self._base.scatter_sub(sparse_delta, use_locking, name)
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
    # The backend only offers scatter_add, so add the negated values.
    return dynamic_variable_ops.dummy_var_scatter_add(
        self._dummy_handle,
        sparse_delta.indices,
        ops.convert_to_tensor(-sparse_delta.values, self.dtype),
    )

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Add ``sparse_delta`` (an IndexedSlices) to the variable.

    Raises TypeError in dynamic mode if ``sparse_delta`` is not IndexedSlices.
    """
    if self.is_static():
      return self._base.scatter_add(sparse_delta, use_locking, name)
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
    return dynamic_variable_ops.dummy_var_scatter_add(
        self._dummy_handle,
        sparse_delta.indices,
        ops.convert_to_tensor(sparse_delta.values, self.dtype),
    )

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Overwrite rows with ``sparse_delta`` (an IndexedSlices).

    Raises TypeError in dynamic mode if ``sparse_delta`` is not IndexedSlices.
    """
    if self.is_static():
      return self._base.scatter_update(sparse_delta, use_locking, name)
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
    return dynamic_variable_ops.dummy_var_scatter_update(
        self._dummy_handle,
        sparse_delta.indices,
        ops.convert_to_tensor(sparse_delta.values, self.dtype),
    )

  # -------------------------------------------------------------------------
  # Methods not supported both in static mode and dynamic mode
  # -------------------------------------------------------------------------

  def __deepcopy__(self, *args, **kwargs):
    raise NotImplementedError('__deepcopy__() is not supported.')

  def __reduce__(self, *args, **kwargs):
    raise NotImplementedError('__reduce__() is not supported.')

  def to_proto(self, *args, **kwargs):
    """Serialize via ``ResourceVariable.to_proto``."""
    return super(DynamicVariable, self).to_proto(*args, **kwargs)

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    """Rebuild a variable from a ``VariableDef``.

    Definitions whose variable name contains '/DummyVarHandle' become a
    DynamicVariable. All others use the stock ``ResourceVariable.from_proto``.
    """
    if '/DummyVarHandle' in variable_def.variable_name:
      return DynamicVariable(
          dimension=0, variable_def=variable_def, import_scope=import_scope)
    return _resource_var_from_proto(variable_def, import_scope)

  def set_shape(self, *args, **kwargs):
    raise NotImplementedError('set_shape() is not supported.')

  # -------------------------------------------------------------------------
  # Methods only supported in static mode: each delegates to
  # ResourceVariable in static mode and raises NotImplementedError in
  # dynamic mode (where the active handle is not a ResourceVariable handle).
  # -------------------------------------------------------------------------

  def is_initialized(self, name=None):
    """Always return True.

    No is_initialized op exists for the dummy variable yet (see the TODO in
    ``__init__``), so both modes report the variable as initialized.
    ``name`` defaults to None, matching ``ResourceVariable.is_initialized``.
    """
    return True

  def _read_variable_op(self):
    if self.is_static():
      return self._base._read_variable_op()
    raise NotImplementedError(
        '_read_variable_op() is not supported in dynamic mode.')

  def value(self):
    if self.is_static():
      return self._base.value()
    raise NotImplementedError('value() is not supported in dynamic mode.')

  def _dense_var_to_tensor(self, *args, **kwargs):
    if self.is_static():
      return self._base._dense_var_to_tensor(*args, **kwargs)
    raise NotImplementedError(
        '_dense_var_to_tensor() is not supported in dynamic mode.')

  def _gather_saveables_for_checkpoint(self):
    if self.is_static():
      return self._base._gather_saveables_for_checkpoint()
    raise NotImplementedError(
        '_gather_saveables_for_checkpoint() is not supported in dynamic mode.')

  def gather_nd(self, *args, **kwargs):
    if self.is_static():
      return self._base.gather_nd(*args, **kwargs)
    raise NotImplementedError('gather_nd() is not supported in dynamic mode.')

  def assign_add(self, *args, **kwargs):
    if self.is_static():
      return self._base.assign_add(*args, **kwargs)
    raise NotImplementedError('assign_add() is not supported in dynamic mode.')

  def assign_sub(self, *args, **kwargs):
    if self.is_static():
      return self._base.assign_sub(*args, **kwargs)
    raise NotImplementedError('assign_sub() is not supported in dynamic mode.')

  def assign(self, *args, **kwargs):
    if self.is_static():
      return self._base.assign(*args, **kwargs)
    raise NotImplementedError('assign() is not supported in dynamic mode.')

  def scatter_max(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_max(*args, **kwargs)
    raise NotImplementedError('scatter_max() is not supported in dynamic mode.')

  def scatter_min(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_min(*args, **kwargs)
    raise NotImplementedError('scatter_min() is not supported in dynamic mode.')

  def scatter_mul(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_mul(*args, **kwargs)
    raise NotImplementedError('scatter_mul() is not supported in dynamic mode.')

  def scatter_div(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_div(*args, **kwargs)
    raise NotImplementedError('scatter_div() is not supported in dynamic mode.')

  def scatter_dim(self, *args, **kwargs):
    """Deprecated misspelling of ``scatter_div``, kept for compatibility."""
    return self.scatter_div(*args, **kwargs)

  def batch_scatter_update(self, *args, **kwargs):
    if self.is_static():
      return self._base.batch_scatter_update(*args, **kwargs)
    raise NotImplementedError(
        'batch_scatter_update() is not supported in dynamic mode.')

  def scatter_nd_add(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_nd_add(*args, **kwargs)
    raise NotImplementedError(
        'scatter_nd_add() is not supported in dynamic mode.')

  def scatter_nd_sub(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_nd_sub(*args, **kwargs)
    raise NotImplementedError(
        'scatter_nd_sub() is not supported in dynamic mode.')

  def scatter_nd_update(self, *args, **kwargs):
    if self.is_static():
      return self._base.scatter_nd_update(*args, **kwargs)
    raise NotImplementedError(
        'scatter_nd_update() is not supported in dynamic mode.')

  def _strided_slice_assign(self, *args, **kwargs):
    if self.is_static():
      return self._base._strided_slice_assign(*args, **kwargs)
    raise NotImplementedError(
        '_strided_slice_assign() is not supported in dynamic mode.')

  def __int__(self, *args, **kwargs):
    if self.is_static():
      return self._base.__int__(*args, **kwargs)
    raise NotImplementedError('__int__() is not supported in dynamic mode.')
# Route ResourceVariable.from_proto through DynamicVariable.from_proto.
# SOK dummy-variable definitions ('/DummyVarHandle' in the name) are then
# rebuilt as DynamicVariables; everything else falls back to the stock
# implementation saved in _resource_var_from_proto.
setattr(ResourceVariable, 'from_proto', DynamicVariable.from_proto)
# @tf.RegisterGradient("DummyVarSparseRead")
# def _SparseReadGrad(op, grad):
# """Gradient for sparse_read."""
# handle = op.inputs[0]
# indices = op.inputs[1]
# key_type = op.get_attr("key_type")
# dtype = op.get_attr("dtype")
# variable_shape = dynamic_variable_ops.dummy_var_shape(handle, key_type=key_type, dtype=dtype)
# size = array_ops.expand_dims(array_ops.size(indices), 0)
# values_shape = array_ops.concat([size, variable_shape[1:]], 0)
# grad = array_ops.reshape(grad, values_shape)
# indices = array_ops.reshape(indices, size)
# return (ops.IndexedSlices(grad, indices, variable_shape), None)
def export(var):
  """Abbreviated as ``sok.experiment.export``.

  Export the indices and value tensor from the given variable.

  Parameters
  ----------
  var: sok.DynamicVariable
      The variable to extract indices and values from.

  Returns
  -------
  indices: tf.Tensor
      The indices of the given variable (copied under a CPU device scope).
  values: tf.Tensor
      The values of the given variable (copied under a CPU device scope).

  Raises
  ------
  TypeError
      If ``var`` is not a DynamicVariable. Previously this path crashed
      with an UnboundLocalError.
  """
  if not isinstance(var, DynamicVariable):
    raise TypeError('export() expects a DynamicVariable, got %s' % type(var))
  indices, values = dynamic_variable_ops.dummy_var_export(
      var.handle, key_type=var.key_type, dtype=var.handle_dtype)
  with tf.device('CPU'):
    indices = tf.identity(indices)
    values = tf.identity(values)
  return indices, values
def assign(var, indices, values):
  """Abbreviated as ``sok.experiment.assign``.

  Assign the indices and value tensor to the target variable.

  Parameters
  ----------
  var: sok.DynamicVariable
      The target variable of the assignment.
  indices: tf.Tensor
      Indices to be assigned to the variable. They are cast to the
      variable's key_type first.
  values: tf.Tensor
      Values to be assigned to the variable.

  Returns
  -------
  The op returned by the backend assign for a DynamicVariable, or None when
  ``var`` is not a DynamicVariable (nothing is assigned in that case).
  """
  if not isinstance(var, DynamicVariable):
    return None
  # Keep the cast result; it was previously computed and discarded.
  indices = tf.cast(indices, var.key_type)
  return dynamic_variable_ops.dummy_var_assign(var.handle, indices, values)