def sample_examples()

in tfx_addons/sampling/executor.py [0:0]


def sample_examples(data, null_classes, sampling_strategy):
  """Resamples a keyed PCollection and merges the null-class values back in.

  Args:
    data: Keyed PCollection of (label, example) pairs; it is counted per key
      and grouped by key below.
    null_classes: Forwarded to `filter_null` as `null_vals`. Presumably the
      labels that must bypass sampling -- confirm against `filter_null`.
    sampling_strategy: Either `spec.SamplingStrategy.UNDERSAMPLE` or
      `spec.SamplingStrategy.OVERSAMPLE`.

  Returns:
    A single PCollection flattening the output of `sample_data` together with
    the values of the elements selected by `filter_null(..., keep_null=True)`.

  Raises:
    ValueError: If `sampling_strategy` is neither of the two strategies above.
  """

  # Reducers applied to the per-class counts. Both fall back to 0 when handed
  # a falsy (e.g. empty) collection of counts.
  def find_minimum(elements):
    return min(elements) if elements else 0

  def find_maximum(elements):
    return max(elements) if elements else 0

  # Undersampling reduces the counts with the minimum, oversampling with the
  # maximum; anything else is rejected up front.
  undersampling = sampling_strategy == spec.SamplingStrategy.UNDERSAMPLE
  if not undersampling and not (
      sampling_strategy == spec.SamplingStrategy.OVERSAMPLE):
    raise ValueError("Invalid value for sampling_strategy variable!")
  sample_fn = find_minimum if undersampling else find_maximum

  # Predicates shared by the filtering steps below.
  def keep_sampled(kv):
    return filter_null(kv, null_vals=null_classes)

  def keep_null_only(kv):
    return filter_null(kv, keep_null=True, null_vals=null_classes)

  # Singleton PCollection holding the reduced (min or max) per-class count,
  # computed over the elements that pass `keep_sampled`.
  per_class_counts = data | "CountPerKey" >> beam.combiners.Count.PerKey()
  kept_counts = per_class_counts | "FilterNullCount" >> beam.Filter(
      keep_sampled)
  count_values = kept_counts | "Values" >> beam.Values()
  val = count_values | "GetSample" >> beam.CombineGlobally(sample_fn)

  # Group the examples by label and hand each kept group to `sample_data`,
  # with the reduced count supplied as a singleton side input.
  grouped = data | "GroupBylabel" >> beam.GroupByKey()
  kept_groups = grouped | "FilterNull" >> beam.Filter(keep_sampled)
  res = kept_groups | "Sample" >> beam.FlatMapTuple(
      sample_data,
      sampling_strategy=sampling_strategy,
      side=beam.pvalue.AsSingleton(val))

  # Elements selected with keep_null=True skip sampling entirely; only their
  # values are carried over to the merged output.
  null_pairs = data | "ExtractNull" >> beam.Filter(keep_null_only)
  null = null_pairs | "NullValues" >> beam.Values()

  return (res, null) | "Merge PCollections" >> beam.Flatten()