in tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_ops.py [0:0]
def _sparse_reshape_gpu(sp_input, shape, name=None):
  """Reshapes a `SparseTensor` using the TFRA `tfra_sparse_reshape` kernel.

  Mirrors `tf.sparse.reshape`. The custom kernel produces the new indices and
  dense shape. When both the input's static shape and the requested shape are
  known at graph-construction time, the output dense shape is replaced by a
  constant so downstream static shape inference can see it. If the loaded
  `tfra_math_ops` library does not define `tfra_sparse_reshape`, a warning is
  logged and the call is delegated to `tf.sparse.reshape`.

  Args:
    sp_input: The `SparseTensor` (or any value accepted by
      `_convert_to_sparse_tensor`) to reshape.
    shape: Integer tensor or list giving the requested dense shape. It is cast
      to int64. At most one entry may be -1.
    name: Optional name for the operation. It is also used as the name scope
      (default scope "SparseReshape").

  Returns:
    A `SparseTensor` whose indices and dense shape come from the reshape and
    whose values are `array_ops.identity(sp_input.values)`.

  Raises:
    ValueError: If the statically known `shape` contains more than one -1, or
      if the statically known element counts of the input and the requested
      shape differ.
  """
  if not hasattr(tfra_math_ops, 'tfra_sparse_reshape'):
    tf_logging.warn('`tfra.dynamic_embedding.sparse_reshape` is not'
                    ' found. Use tf.sparse.reshape instead.')
    return tf.sparse.reshape(sp_input, shape, name=name)

  sp_input = _convert_to_sparse_tensor(sp_input)
  shape = math_ops.cast(shape, dtype=dtypes.int64)

  # Bind the resolved scope name (as `tf.sparse.reshape` does). The kernel op
  # then takes the scope's own name instead of being nested inside it, which
  # would otherwise produce names like "<name>/<name>".
  with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
    reshaped_ind, reshaped_shape = tfra_math_ops.tfra_sparse_reshape(
        sp_input.indices, sp_input.dense_shape, shape, name=name)

    # constant_value_as_shape recovers partially known shapes (unknown entries
    # become None). constant_value only succeeds when the caller's whole
    # `shape` is a constant, and that is what the "-1 at most once" check needs.
    requested = tensor_util.constant_value_as_shape(shape)
    if requested.ndims is not None and sp_input.shape.is_fully_defined():
      static_shape = _static_sparse_reshape_shape(
          sp_input.shape.as_list(), requested.as_list(),
          tensor_util.constant_value(shape))
      if static_shape is not None:
        reshaped_shape = constant_op.constant(static_shape,
                                              dtype=dtypes.int64)

    return sparse_tensor.SparseTensor(indices=reshaped_ind,
                                      values=array_ops.identity(
                                          sp_input.values),
                                      dense_shape=reshaped_shape)


def _static_sparse_reshape_shape(in_shape, requested_shape, user_shape=None):
  """Statically resolves and validates the output dense shape of a reshape.

  Args:
    in_shape: Fully defined static dense shape of the input, as a list of ints.
    requested_shape: Requested output shape as a list in which unknown entries
      are None (e.g. from `constant_value_as_shape(...).as_list()`).
    user_shape: The requested shape as the constant value the caller passed,
      or None when it is not a graph-time constant.

  Returns:
    The requested shape with its single unknown entry inferred from the element
    count. Returns None when the shape cannot be resolved statically. That
    happens when there is more than one unknown entry, or when the input is
    empty and the known entries multiply to zero, so the unknown entry is
    ambiguous. In both cases the caller keeps the kernel's dynamic shape.

  Raises:
    ValueError: If `user_shape` contains more than one -1, or if the element
      counts of `in_shape` and the resolved shape differ.
  """
  mismatch_msg = ("Cannot reshape a tensor with %d elements to shape %s "
                  "(%d elements).")

  if user_shape is not None:
    num_implied_by_user = sum(d == -1 for d in user_shape)
    if num_implied_by_user > 1:
      raise ValueError(
          "At most one dimension can be inferred (-1). Found: %s" %
          user_shape)

  requested = list(requested_shape)  # Kept unmodified for error messages.
  resolved = list(requested_shape)
  in_size = np.prod(in_shape)
  num_implied = sum(dim is None for dim in requested)
  if num_implied > 1:
    return None

  if num_implied == 1:
    implied_idx = requested.index(None)
    known_size = np.prod(requested[:implied_idx] + requested[implied_idx + 1:])
    if known_size == 0:
      # Any product containing a zero is zero, so the implied dimension cannot
      # be recovered from the element count. The previous version divided by
      # zero here.
      if in_size != 0:
        raise ValueError(mismatch_msg % (in_size, requested, 0))
      return None
    resolved[implied_idx] = int(in_size // known_size)

  reshaped_size = np.prod(resolved)
  if reshaped_size != in_size:
    raise ValueError(mismatch_msg % (in_size, requested, reshaped_size))
  return resolved