in tfx_addons/sampling/executor.py [0:0]
def sample_examples(data, null_classes, sampling_strategy):
  """Samples a PCollection of serialized examples keyed by class label."""

  # Helpers that reduce the per-class counts to a single target count:
  # the size of the smallest class (undersampling) or of the largest
  # class (oversampling).
  def find_minimum(elements):
    return min(elements or [0])

  def find_maximum(elements):
    return max(elements or [0])
  if sampling_strategy == spec.SamplingStrategy.UNDERSAMPLE:
    sample_fn = find_minimum
  elif sampling_strategy == spec.SamplingStrategy.OVERSAMPLE:
    sample_fn = find_maximum
  else:
    raise ValueError("Invalid value for sampling_strategy variable!")
  # Count examples per class, drop the null classes, and combine the counts
  # into a singleton PCollection holding the target number of examples.
  val = (data
         | "CountPerKey" >> beam.combiners.Count.PerKey()
         | "FilterNullCount" >> beam.Filter(
             lambda x: filter_null(x, null_vals=null_classes))
         | "Values" >> beam.Values()
         | "GetSample" >> beam.CombineGlobally(sample_fn))
  # Performs the actual sampling. Records are grouped by class label into a
  # K-V PCollection of {class_label: serialized examples}, null classes are
  # dropped, and sample_data resamples each class to the target count, which
  # is passed in as a singleton side input.
  res = (data
         | "GroupBylabel" >> beam.GroupByKey()
         | "FilterNull" >> beam.Filter(
             lambda x: filter_null(x, null_vals=null_classes))
         | "Sample" >> beam.FlatMapTuple(
             sample_data,
             sampling_strategy=sampling_strategy,
             side=beam.pvalue.AsSingleton(val)))
  # Records whose labels are in null_classes were excluded from sampling
  # above, so extract them here and pass them through unchanged.
  null = (data
          | "ExtractNull" >> beam.Filter(
              lambda x: filter_null(x, keep_null=True, null_vals=null_classes))
          | "NullValues" >> beam.Values())

  # Merge the sampled classes with the untouched null-class records.
  return (res, null) | "Merge PCollections" >> beam.Flatten()
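The function relies on collaborators defined elsewhere in the sampling package (`spec.SamplingStrategy`, `filter_null`, `sample_data`). The sketch below is a minimal, hedged illustration of the data flow only: the stand-in enum, the simplified `filter_null` and `sample_data` bodies, and the toy input data are all assumptions and are not the package's real implementations. It assumes the stand-ins are placed in the same module as `sample_examples` (e.g. pasted alongside it) so the function's module-level references resolve.

# --- Illustrative sketch only; stand-ins below are assumptions, not the
# --- real tfx_addons.sampling implementations.
import enum
import random

import apache_beam as beam


class spec:  # stand-in for tfx_addons.sampling.spec (assumption)
  class SamplingStrategy(enum.Enum):
    UNDERSAMPLE = 1
    OVERSAMPLE = 2


def filter_null(element, keep_null=False, null_vals=None):
  # Assumption: `element` is a (label, value) pair; an element is "null"
  # when its label appears in null_vals.
  is_null = null_vals is not None and element[0] in null_vals
  return is_null if keep_null else not is_null


def sample_data(key, values, sampling_strategy=None, side=0):
  # Assumption: emit `side` records per class, slicing a shuffled list for
  # undersampling and sampling with replacement for oversampling.
  values = list(values)
  if sampling_strategy == spec.SamplingStrategy.UNDERSAMPLE:
    random.shuffle(values)
    yield from values[:side]
  else:
    yield from random.choices(values, k=side)


if __name__ == "__main__":
  with beam.Pipeline() as pipeline:
    data = pipeline | beam.Create([
        ("cat", b"cat-1"), ("cat", b"cat-2"), ("cat", b"cat-3"),
        ("dog", b"dog-1"), ("dog", b"dog-2"),
        ("", b"unlabeled-1"),
    ])
    sampled = sample_examples(
        data, null_classes=[""],
        sampling_strategy=spec.SamplingStrategy.UNDERSAMPLE)
    _ = sampled | beam.Map(print)

With UNDERSAMPLE, each non-null class is cut down to the smallest class size (two records here), while the record with the empty label bypasses sampling and reappears in the merged output.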