in scripts/metadata_to_text.py [0:0]
def bins_to_text(dataset, text_bins, column_name, output_column_name, leading_split_for_bins="train", batch_size = 4, num_workers = 1, std_tolerance=5, save_dir=None, only_save_plot=False, lower_range=None, bin_edges=None):
    '''
    Compute bins of `column_name` from the splits `leading_split_for_bins` and apply text bins to every split.

    Args:
        dataset: list of dataset dicts. Each one is iterated for split names, indexed by split to read
            `column_name`, and labelled through its `.map(...)` method.
        text_bins: ordered labels, one per bin; `len(text_bins)` equal-width bins are computed.
        column_name: numeric column whose values are binned.
        output_column_name: name of the column that receives the text label.
        leading_split_for_bins: a string or a list of strings. A split contributes to the bin computation
            when any of these strings is a substring of its name. `None` uses every split.
        batch_size, num_workers: forwarded to `.map` as `batch_size` and `num_proc`.
        std_tolerance: values at least `std_tolerance` standard deviations away from the mean are left out
            of the bin computation (they are still labelled, clamped to the edge bins). `None` disables it.
        save_dir: when set, plots are produced through `visualize_bins_to_text`.
        only_save_plot: when True (and `bin_edges` is None), return right after computing the edges,
            without labelling the dataset.
        lower_range: when not None, the first bin starts at this value (0 is a valid lower range).
        bin_edges: precomputed edges; when given, no statistics are computed and they are used as-is.

    Returns:
        (dataset, bin_edges): the labelled dataset list (or the untouched input when returning early
        because of `only_save_plot`) and the bin edges used.

    Raises:
        ValueError: if no non-NaN value is left to compute the bins from.
    '''
    if bin_edges is None:
        # Accept a single split pattern or a collection of them.
        if isinstance(leading_split_for_bins, str):
            split_patterns = [leading_split_for_bins]
        else:
            split_patterns = leading_split_for_bins

        values = []
        for df in dataset:
            for split in df:
                if split_patterns is None or any(pattern in split for pattern in split_patterns):
                    values.extend(df[split][column_name])

        # dtype=float maps None to NaN so that missing values are dropped together with NaNs.
        values = np.asarray(values, dtype=float)
        values = values[~np.isnan(values)]

        # filter out outliers
        filtered_values = values
        if std_tolerance is not None and values.size > 0:
            std = np.std(values)
            # With zero spread every deviation is 0 and the strict `<` would discard every value.
            if std > 0:
                filtered_values = values[np.abs(values - np.mean(values)) < std_tolerance * std]

        if filtered_values.size == 0:
            # Without this, np.histogram falls back to the arbitrary range [0, 1] and every label is meaningless.
            raise ValueError(
                f"No valid values found in column '{column_name}' to compute bins for '{output_column_name}' "
                f"(leading_split_for_bins={leading_split_for_bins!r}, std_tolerance={std_tolerance!r})."
            )

        if save_dir is not None:
            visualize_bins_to_text(values, filtered_values, "Before filtering", "After filtering", text_bins, save_dir, output_column_name, lower_range=lower_range)
        # speaking_rate can easily have outliers
        if save_dir is not None and output_column_name=="speaking_rate":
            visualize_bins_to_text(filtered_values, filtered_values, "After filtering", "After filtering", text_bins, save_dir, f"{output_column_name}_after_filtering", lower_range=lower_range)

        values = filtered_values
        # `is not None` so that lower_range=0 is honoured.
        hist_range = (lower_range, values.max()) if lower_range is not None else None
        _, bin_edges = np.histogram(values, bins=len(text_bins), range=hist_range)

        if only_save_plot:
            return dataset, bin_edges
    else:
        print(f"Already computed bin edges have been passed for {output_column_name}. Will use: {bin_edges}.")

    def batch_association(batch):
        # searchsorted(side="left") returns i with bin_edges[i-1] < value <= bin_edges[i], i.e. bin i-1.
        # NOTE(review): np.histogram uses half-open [a, b) bins, so a value lying exactly on an interior
        # edge is counted in the upper bin there but labelled with the lower bin here. Kept as-is so labels
        # stay consistent with previously computed / passed bin_edges.
        index_bins = np.searchsorted(bin_edges, batch, side="left")
        # clip when values are outside of the main bins:
        # it happens when value = min or has been filtered out from bins computation
        bin_indices = np.clip(index_bins - 1, 0, len(text_bins) - 1)
        return {
            output_column_name: [text_bins[i] for i in bin_indices]
        }

    dataset = [df.map(batch_association, batched=True, batch_size=batch_size, input_columns=[column_name], num_proc=num_workers) for df in dataset]
    return dataset, bin_edges