def _sparse_reshape_gpu()

in tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_ops.py [0:0]


def _sparse_reshape_gpu(sp_input, shape, name=None):
  """Reshapes a sparse tensor using the `tfra_sparse_reshape` op.

  Falls back to `tf.sparse.reshape` (after logging a warning) when
  `tfra_math_ops` does not expose `tfra_sparse_reshape`.

  When the requested shape and the static shape of the input are both known
  at construction time, the element counts are validated here and the result
  carries a constant `dense_shape` instead of the op's dynamic shape output.

  Args:
    sp_input: The input to reshape; converted with
      `_convert_to_sparse_tensor`.
    shape: The requested dense shape; cast to int64. At most one entry may be
      -1, meaning "infer this dimension".
    name: Optional name used for the name scope and for the op.

  Returns:
    A `SparseTensor` built from the reshaped indices, an identity of the input
    values, and the new dense shape.

  Raises:
    ValueError: If `shape` is statically known to contain more than one -1,
      or if the statically known element count of the input cannot match the
      requested shape.
  """
  if not hasattr(tfra_math_ops, 'tfra_sparse_reshape'):
    tf_logging.warn('`tfra.dynamic_embedding.sparse_reshape` is not'
                    ' found. Use tf.sparse.reshape instead.')
    return tf.sparse.reshape(sp_input, shape, name=name)

  sp_input = _convert_to_sparse_tensor(sp_input)
  shape = math_ops.cast(shape, dtype=dtypes.int64)
  with ops.name_scope(name, "SparseReshape", [sp_input]):
    reshaped_ind, reshaped_shape = tfra_math_ops.tfra_sparse_reshape(
        sp_input.indices, sp_input.dense_shape, shape, name=name)

    # Static view of the requested shape: a list with None for entries that
    # are not known, or None when even the rank is unknown.
    reshaped_shape_const = tensor_util.constant_value_as_shape(shape)
    reshaped_shape_const = (reshaped_shape_const.as_list()
                            if reshaped_shape_const.ndims is not None else None)

    if (reshaped_shape_const is not None and sp_input.shape.is_fully_defined()):
      # constant_value_as_shape tends to get more information about the partial
      # shape values, but here we specifically need to know if the *user* passed
      # a shape with 2+ unknown dimensions; and for that constant_value
      # provides either the user's direct value or None if only partial elements
      # are known via the python shape inference code.
      shape_const_by_user = tensor_util.constant_value(shape)
      if shape_const_by_user is not None:
        num_implied_by_user = sum(d == -1 for d in shape_const_by_user)
        if num_implied_by_user > 1:
          raise ValueError(
              "At most one dimension can be inferred (-1). Found: %s" %
              shape_const_by_user)
      original_reshaped_shape = list(reshaped_shape_const)  # A copy
      in_shape_size = np.prod(sp_input.shape.as_list())
      num_implied = sum(dim is None for dim in reshaped_shape_const)
      has_zero_dim = 0 in reshaped_shape_const

      if num_implied == 1 and has_zero_dim:
        # The known dimensions multiply to 0, so the implied dimension cannot
        # be obtained by division (that would divide by zero). Any shape
        # containing a 0 holds 0 elements, so a non-empty input can never
        # match; for an empty input the implied dimension is ambiguous, so the
        # shape tensor produced by the op is kept rather than inventing a
        # static value.
        if in_shape_size != 0:
          raise ValueError(
              "Cannot reshape a tensor with %d elements to shape %s "
              "(%d elements)." %
              (in_shape_size, original_reshaped_shape, 0))
      elif num_implied <= 1:
        if num_implied == 1:
          # Fill in the single unknown dimension from the remaining ones.
          implied_idx = original_reshaped_shape.index(None)
          non_implied_idx = (original_reshaped_shape[:implied_idx] +
                             original_reshaped_shape[implied_idx + 1:])
          reshaped_shape_const[implied_idx] = int(in_shape_size //
                                                  np.prod(non_implied_idx))
        reshaped_size = np.prod(reshaped_shape_const)
        if reshaped_size != in_shape_size:
          raise ValueError(
              "Cannot reshape a tensor with %d elements to shape %s "
              "(%d elements)." %
              (in_shape_size, original_reshaped_shape, reshaped_size))
        # The shape is fully resolved statically; prefer a constant over the
        # op's dynamic shape output.
        reshaped_shape = constant_op.constant(reshaped_shape_const,
                                              dtype=dtypes.int64)

    return sparse_tensor.SparseTensor(indices=reshaped_ind,
                                      values=array_ops.identity(
                                          sp_input.values),
                                      dense_shape=reshaped_shape)