def generate_vocab()

in tensorflow_text/tools/wordpiece_vocab/generate_vocab.py
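
The function relies on module-level imports and absl flags defined elsewhere in generate_vocab.py. A plausible set, inferred from the identifiers used in the body (tempfile, tf, beam, tft, tft_beam, schema_utils, utils, FLAGS), is sketched below; the actual module may differ:

import tempfile

from absl import flags
import apache_beam as beam
import tensorflow.compat.v1 as tf  # tf.FixedLenFeature suggests the TF1-style API surface
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import schema_utils
from tensorflow_text.tools.wordpiece_vocab import utils

# Expected to define lang_set, text_key, language_code_key, smoothing_exponent,
# max_word_length, and vocab_file, per the FLAGS references in the function body.
FLAGS = flags.FLAGS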


def generate_vocab(data_file, vocab_file, metrics_file, raw_metadata, params,
                   min_token_frequency=2):
  """Returns a pipeline generating a vocab and writing the output.

  Args:
    data_file: TFRecord file of serialized tf.train.Example protos to read
    vocab_file: path in which to write the vocab
    metrics_file: path in which to write the metrics
    raw_metadata: schema for dataset
    params: parameters for wordpiece vocab learning algorithm
    min_token_frequency: the min frequency for a token to be included
  """

  lang_set = set(FLAGS.lang_set.split(','))

  # Schema to format metrics as CSV.
  csv_schema = schema_utils.schema_from_feature_spec({
      'lang': tf.FixedLenFeature([], tf.string),
      'sample_count': tf.FixedLenFeature([], tf.int64),
      'micro_drop_char_percent': tf.FixedLenFeature([], tf.string),
      'macro_drop_char_percent': tf.FixedLenFeature([], tf.string),
      'micro_compress_ratio': tf.FixedLenFeature([], tf.string),
      'macro_compress_ratio': tf.FixedLenFeature([], tf.string),
      'unweighted_en_wp_overlap_percent': tf.FixedLenFeature([], tf.string),
      'weighted_en_wp_overlap_percent': tf.FixedLenFeature([], tf.string),
  })

  columns = ['lang',
             'sample_count',
             'micro_drop_char_percent',
             'macro_drop_char_percent',
             'micro_compress_ratio',
             'macro_compress_ratio',
             'unweighted_en_wp_overlap_percent',
             'weighted_en_wp_overlap_percent']

  example_converter = tft.coders.ExampleProtoCoder(raw_metadata.schema,
                                                   serialized=False)

  def run_vocab():
    """Creates a pipeline to generate wordpiece vocab over a corpus."""

    vocab_pipeline = beam.Pipeline()

    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
      # Read raw data and convert to TF Transform encoded dict.
      raw_data = (
          vocab_pipeline
          | 'ReadInputData' >> beam.io.tfrecordio.ReadFromTFRecord(
              data_file, coder=beam.coders.ProtoCoder(tf.train.Example))
          | 'DecodeInputData' >> beam.Map(example_converter.decode))

      # Apply TF Transform.
      (transformed_data, _), _ = (
          (raw_data, raw_metadata)
          | 'FilterLangAndExtractToken' >> tft_beam.AnalyzeAndTransformDataset(
              utils.count_preprocessing_fn(FLAGS.text_key,
                                           FLAGS.language_code_key)))

      # Filter by languages.
      tokens = (
          transformed_data
          | 'FilterByLang' >> beam.ParDo(utils.FilterTokensByLang(lang_set)))

      # Calculate smoothing coefficients.
      coeffs = (
          tokens
          | 'CalculateSmoothingCoefficients' >> beam.CombineGlobally(
              utils.CalculateCoefficients(FLAGS.smoothing_exponent)))

      # Apply smoothing, aggregate counts, and sort words by count.
      _ = (
          tokens
          | 'ApplyExponentialSmoothing' >> beam.ParDo(
              utils.ExponentialSmoothing(), beam.pvalue.AsSingleton(coeffs))
          | 'SumCounts' >> beam.CombinePerKey(sum)
          | 'FilterLowCounts' >> beam.ParDo(utils.FilterByCount(
              FLAGS.max_word_length, min_token_frequency))
          | 'MergeAndSortCounts' >> beam.CombineGlobally(utils.SortByCount())
          | 'LearnVocab' >> beam.ParDo(utils.LearnVocab(params))
          | 'Flatten' >> beam.FlatMap(lambda x: x + '\n')
          | 'WriteVocab' >> beam.io.WriteToText(vocab_file,
                                                shard_name_template='',
                                                append_trailing_newlines=False))
    return vocab_pipeline

  def run_metrics():
    """Creates a pipeline to measure wordpiece vocab metrics over a corpus."""

    metrics_pipeline = beam.Pipeline()

    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
      # Read raw data and convert to TF Transform encoded dict.
      raw_data = (
          metrics_pipeline
          | 'ReadInputData' >> beam.io.tfrecordio.ReadFromTFRecord(
              data_file, coder=beam.coders.ProtoCoder(tf.train.Example))
          | 'DecodeInputData' >> beam.Map(example_converter.decode))

      # Apply transform to wordpiece-tokenize input.
      (metrics_transformed_data, _), _ = (
          (raw_data, raw_metadata)
          | 'WordpieceTokenizeInput' >> tft_beam.AnalyzeAndTransformDataset(
              utils.metrics_preprocessing_fn(FLAGS.vocab_file,
                                             FLAGS.text_key,
                                             FLAGS.language_code_key)))

      # Initialize CSV coder. Aggregate values for each lang, calculate metrics,
      # and write the output to a CSV file.
      csv_converter = tft.coders.CsvCoder(columns, csv_schema)
      _ = (
          metrics_transformed_data
          | 'CompileTokenInfo' >> beam.ParDo(utils.CompileTokenizationInfo())
          | 'CombineStatsForLang' >> beam.CombineGlobally(utils.AggregateLang())
          | 'CalculateMetrics' >> beam.ParDo(utils.CalculateMetrics())
          | 'EncodeMetrics' >> beam.Map(csv_converter.encode)
          | 'WriteMetrics' >> beam.io.WriteToText(
              metrics_file, shard_name_template='', header=','.join(columns)))
    return metrics_pipeline

  vocab_pipeline = run_vocab()
  vocab_pipeline.run().wait_until_finish()

  metrics_pipeline = run_metrics()
  metrics_pipeline.run().wait_until_finish()
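
For context, a hypothetical invocation might look like the sketch below. It assumes the imports shown earlier; raw_metadata is built from the dataset's feature spec via tensorflow_transform's dataset_metadata, and params holds the wordpiece learner's settings (defined in the companion learner module and not shown here). Paths, feature names, and the params variable are placeholders:

from tensorflow_transform.tf_metadata import dataset_metadata

raw_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec({
        'text': tf.FixedLenFeature([], tf.string),           # placeholder feature names
        'language_code': tf.FixedLenFeature([], tf.string),
    }))

generate_vocab(
    data_file='/tmp/corpus.tfrecord',     # placeholder path
    vocab_file='/tmp/wordpiece.vocab',    # placeholder path
    metrics_file='/tmp/metrics.csv',      # placeholder path
    raw_metadata=raw_metadata,
    params=wordpiece_learner_params,      # hypothetical; built per the learner's expected params
    min_token_frequency=2)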