# token_probability_distributions_per_percent_masked_bucket
# (excerpt from muse/training_utils.py)


def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
    """Sample per-token probability distributions, one image per percent-masked bucket.

    For each of 10 buckets of "percent of tokens masked", pick the first image in
    the batch that falls in that bucket, take the predicted probability
    distribution at that image's first masked position, and record every vocab
    probability as a row for plotting/logging.

    Args:
        logits: Model logits; assumed shape (batch, seq_len, vocab_size) — TODO confirm.
        input_ids: Token ids, assumed shape (batch, seq_len); positions equal to
            ``mask_id`` are the masked tokens.
        mask_id: Token id that marks a masked position.

    Returns:
        pandas.DataFrame with columns ``bucket`` (int bucket index) and
        ``masked_pixel_prob`` (one float per vocab entry of the sampled
        distribution). Buckets with no sample in this batch are skipped.
    """
    probs = F.softmax(logits, dim=-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    data = []

    for bucket_idx in range(total_buckets):
        # Batch indices of the images whose masked fraction falls in this bucket.
        # BUGFIX: the previous code did `masked_buckets[masked_buckets == bucket_idx]`,
        # which selects the bucket *values* (all equal to bucket_idx), not the batch
        # indices — so it always sampled the image at batch position `bucket_idx`.
        indices_for_bucket = (masked_buckets == bucket_idx).nonzero(as_tuple=True)[0]

        # It's ok if none were noised in the range of this bucket. This
        # function will be called for a later training step where it's likely
        # there will be an element noised in the range.
        if indices_for_bucket.shape[0] == 0:
            continue

        # One representative image per bucket is enough for the distribution plot.
        index_for_bucket = indices_for_bucket[0]

        image_probs = probs[index_for_bucket]

        # find the index of a masked pixel for the image
        input_ids_for_image = input_ids[index_for_bucket]
        masked_pixels_probs = image_probs[input_ids_for_image == mask_id]

        # Take the distribution at the first masked position only.
        masked_pixel_probs = masked_pixels_probs[0]

        masked_pixel_probs = masked_pixel_probs.cpu().numpy()

        for masked_pixel_prob in masked_pixel_probs:
            data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})

    df = pd.DataFrame(data)

    return df