def bins_to_text()

in scripts/metadata_to_text.py [0:0]


def bins_to_text(dataset, text_bins, column_name, output_column_name, leading_split_for_bins="train", batch_size = 4, num_workers = 1, std_tolerance=5, save_dir=None, only_save_plot=False, lower_range=None, bin_edges=None):
    '''
    Compute bins of `column_name` from the splits `leading_split_for_bins` and apply text bins to every split.

    Args:
        dataset: iterable of mappings from split name to a split; each mapping is indexed as
            `df[split][column_name]` and must expose a `.map(...)` method
            (presumably Hugging Face `DatasetDict` objects - confirm against the callers).
        text_bins: ordered list of labels, one per bin (lowest bin first).
        column_name: name of the numeric column that gets binned.
        output_column_name: name of the column that receives the text label.
        leading_split_for_bins: a string, or a list/tuple of strings. A split contributes to the
            bin computation when its name contains (substring match) at least one of them.
            `None` means that every split contributes.
        batch_size: batch size forwarded to `.map`.
        num_workers: forwarded to `.map` as `num_proc`.
        std_tolerance: values further than `std_tolerance` standard deviations from the mean are
            left out of the bin computation. `None` disables the filtering.
        save_dir: if not `None`, `visualize_bins_to_text` is called with this directory.
        only_save_plot: if `True`, return right after the bin edges are computed, without
            mapping the dataset.
        lower_range: optional lower bound of the histogram range (`0` is a valid bound).
        bin_edges: already computed bin edges. When given, nothing is computed from the data
            and these edges are used as they are.

    Returns:
        A tuple `(dataset, bin_edges)`. `dataset` is a list with one mapped object per input
        element, or the untouched input when `only_save_plot` is `True`.

    Raises:
        ValueError: if the bin edges have to be computed but no non-NaN value was collected
            from the selected splits.
    '''
    if bin_edges is None:
        # Normalise `leading_split_for_bins` to a tuple of name fragments (None = every split).
        # Testing `a_list in split_name` directly raises a TypeError, hence the `any` below.
        if leading_split_for_bins is None:
            leading_splits = None
        elif isinstance(leading_split_for_bins, str):
            leading_splits = (leading_split_for_bins,)
        else:
            leading_splits = tuple(leading_split_for_bins)

        values = []
        for df in dataset:
            for split in df:
                if leading_splits is None or any(name in split for name in leading_splits):
                    values.extend(df[split][column_name])

        values = np.array(values)
        values = values[~np.isnan(values)]
        if values.size == 0:
            raise ValueError(
                f"No values found in column '{column_name}' for the splits matching "
                f"{leading_split_for_bins!r}: cannot compute bins for {output_column_name}."
            )

        # filter out outliers
        filtered_values = values
        if std_tolerance is not None:
            # `<=` (and not `<`) so that a constant column, whose std is 0, keeps its values
            filtered_values = values[np.abs(values - np.mean(values)) <= std_tolerance * np.std(values)]

        if save_dir is not None:
            visualize_bins_to_text(values, filtered_values, "Before filtering", "After filtering", text_bins, save_dir, output_column_name, lower_range=lower_range)

        # speaking_rate can easily have outliers
        if save_dir is not None and output_column_name=="speaking_rate":
            visualize_bins_to_text(filtered_values, filtered_values, "After filtering", "After filtering", text_bins, save_dir, f"{output_column_name}_after_filtering", lower_range=lower_range)

        values = filtered_values
        # `is not None` so that a lower bound of 0 is honoured instead of being ignored
        histogram_range = (lower_range, values.max()) if lower_range is not None else None
        _, bin_edges = np.histogram(values, bins=len(text_bins), range=histogram_range)

        if only_save_plot:
            return dataset, bin_edges
    else:
        print(f"Already computed bin edges have been passed for {output_column_name}. Will use: {bin_edges}.")

    def batch_association(batch):
        # with side="left", a value v gets the index i such that bin_edges[i-1] < v <= bin_edges[i],
        # so i-1 is the index of its bin
        index_bins = np.searchsorted(bin_edges, batch, side="left")
        # do min(max(...)) when values are outside of the main bins
        # it happens when value = min or max or have been filtered out from bins computation
        batch_bins = [text_bins[min(max(i-1, 0), len(text_bins)-1)] for i in index_bins]
        return {
            output_column_name: batch_bins
        }

    dataset = [df.map(batch_association, batched=True, batch_size=batch_size, input_columns=[column_name], num_proc=num_workers) for df in dataset]
    return dataset, bin_edges