def get_donor_pools()

in mozetl/taar/taar_similarity.py [0:0]


def get_donor_pools(users_df, clusters_df, num_donors, random_seed=None):
    """Sample a pool of donors from the clustered users.

    The sampling happens in two passes: a stratified pass over the
    "prediction" column, where each cluster is sampled with a fraction
    equal to its share of the total rows, followed by a uniform pass that
    scales the result towards ``num_donors`` rows. Both passes are
    probabilistic, so the size of the returned pool is approximate.

    Args:
        users_df: Not used by this function; kept so existing callers
            do not have to change.
        clusters_df: DataFrame with a "prediction" column holding the
            cluster id assigned to each row.
        num_donors: Desired (approximate) number of rows in the pool.
        random_seed: Optional seed forwarded to both sampling passes.
            Any value other than ``None`` (including ``0``) is used.

    Returns:
        A ``(clusters, donor_pool_df)`` tuple: the list of cluster ids in
        ascending order and the sampled DataFrame. The "prediction"
        column is not removed from the sampled DataFrame.
    """
    cluster_population = clusters_df.groupBy("prediction").count().collect()

    # Order the histogram by ascending cluster id so that the returned
    # list of ids is deterministic.
    clusters_histogram = sorted(
        ((row["prediction"], row["count"]) for row in cluster_population),
        key=lambda entry: entry[0],
    )

    # Keep the cluster ids and their respective row counts separately.
    clusters = [cluster_id for cluster_id, _ in clusters_histogram]
    counts = [donor_count for _, donor_count in clusters_histogram]

    # Compute the proportion of rows that falls in each cluster.
    total_donors_in_clusters = sum(counts)
    sampling_proportions = {
        cluster_id: float(donor_count) / total_donors_in_clusters
        for cluster_id, donor_count in clusters_histogram
    }

    # Forward the seed whenever one was given. Comparing against None
    # (instead of relying on truthiness) keeps a seed of 0 usable.
    sampling_kwargs = {} if random_seed is None else {"seed": random_seed}
    donor_df = clusters_df.sampleBy(
        "prediction", fractions=sampling_proportions, **sampling_kwargs
    )

    current_sample_size = donor_df.count()
    if current_sample_size == 0:
        # Nothing was selected (e.g. empty input): return the empty sample
        # rather than dividing by zero below.
        return clusters, donor_df

    # Scale the stratified sample towards the requested number of donors.
    # A fraction above 1.0 is not valid when sampling without replacement,
    # so cap it: asking for more donors than available keeps all of them.
    pool_fraction = min(1.0, float(num_donors) / current_sample_size)
    donor_pool_df = donor_df.sample(False, pool_fraction, **sampling_kwargs)
    return clusters, donor_pool_df