def newsgroups_dataset()

in example_zoo/tensorflow/probability/latent_dirichlet_allocation_edward2/trainer/latent_dirichlet_allocation_edward2.py [0:0]


def newsgroups_dataset(directory, split_name, num_words, shuffle_and_repeat):
  """20 newsgroups as a tf.data.Dataset.

  Loads the split file obtained via `download(directory, ...)`, drops its
  trailing empty row, and serves each document as a dense word-count vector.

  Args:
    directory: Directory passed to `download` to locate/fetch the split file.
    split_name: Split name substituted into `FILE_TEMPLATE` (e.g. "train" or
      "test").
    num_words: Vocabulary size; every word id must lie in `[0, num_words)`
      (SciPy rejects out-of-range coordinates for the given shape).
    shuffle_and_repeat: If True, shuffle the documents each epoch and repeat
      indefinitely (training); otherwise yield each document once, in order.

  Returns:
    A `tf.data.Dataset` whose elements are float32 tensors of shape
    `(num_words,)` holding the per-word counts of a single document.
  """
  # Each row is a variable-length list of word ids, so the file is stored as
  # an object-dtype array. NumPy >= 1.16.3 refuses to load such arrays unless
  # `allow_pickle=True`; `encoding="latin1"` lets Python 3 read arrays that
  # were pickled under Python 2.
  # NOTE(security): this unpickles the downloaded file; only use it with a
  # trusted data source.
  data = np.load(
      download(directory, FILE_TEMPLATE.format(split=split_name)),
      allow_pickle=True,
      encoding="latin1")
  # The last row is empty in both train and test.
  data = data[:-1]

  # Build one (document, word) coordinate per token and store them in a COO
  # matrix. Duplicate coordinates (repeated words) are summed when converting
  # to CSR, and CSR allows fast row slicing to fetch individual documents.
  num_documents = data.shape[0]
  row_lengths = np.array([len(row) for row in data], dtype=np.int64)
  row_ids = np.repeat(np.arange(num_documents, dtype=np.int64), row_lengths)
  if row_ids.size:
    column_ids = np.concatenate(
        [np.asarray(row, dtype=np.int64) for row in data])
  else:
    # No tokens at all: np.concatenate cannot be called on an empty list.
    column_ids = np.zeros(0, dtype=np.int64)

  sparse_matrix = scipy.sparse.coo_matrix(
      (np.ones(row_ids.size, dtype=np.float32), (row_ids, column_ids)),
      shape=(num_documents, num_words),
      dtype=np.float32).tocsr()

  dataset = tf.data.Dataset.range(num_documents)

  # For training, we shuffle each epoch and repeat the epochs.
  if shuffle_and_repeat:
    dataset = dataset.shuffle(num_documents).repeat()

  # Returns a single document as a dense TensorFlow tensor. The dataset is
  # stored as a sparse matrix outside of the graph.
  def get_row_py_func(idx):
    def get_row_python(idx_py):
      # A CSR row slice is a (1, num_words) matrix; take its only row.
      return sparse_matrix[idx_py].toarray()[0]

    py_func = tf.compat.v1.py_func(
        get_row_python, [idx], tf.float32, stateful=False)
    # py_func outputs have unknown static shape; restore it for downstream ops.
    py_func.set_shape((num_words,))
    return py_func

  dataset = dataset.map(get_row_py_func)
  return dataset