def pick_indices_from_probs()

in python/dpu_utils/tfutils/pick_indices.py [0:0]


def pick_indices_from_probs(probs: np.ndarray, num_picks: int, use_sampling: bool=False,
                            temperature: float=0.5) -> Iterable[int]:
    """Given an array of probabilities, pick up to `num_picks` unique indices from it.

    Indices whose (possibly temperature-adjusted) probability is not above
    SMALL_NUMBER are never returned, so fewer than `num_picks` indices may come back.

    Args:
        probs: Array of non-negative scores/probabilities; they need not be normalised.
            NOTE(review): presumably 1-D (indexed by position) -- confirm with callers.
        num_picks: Maximum number of indices to return. Values <= 0 yield an empty result.
        use_sampling: If True, draw indices at random (without replacement) from the
            temperature-adjusted distribution. Otherwise take the `num_picks`
            highest-probability indices.
        temperature: Only used when sampling. Probabilities are raised to the power
            1/temperature and renormalised, so values < 1 sharpen the distribution.

    Returns:
        In sampling mode, a set of picked indices. In greedy mode, a list of the
        top-scoring indices in no particular order (argpartition does not sort).
    """
    if num_picks <= 0 or len(probs) == 0:
        # Without this guard, `np.argpartition(probs, -0)[-0:]` would select *all*
        # indices, and an empty array makes argpartition / probs_cum[-1] raise.
        return set() if use_sampling else []

    if use_sampling:
        # First, consider the temperature for sampling:
        probs = probs ** (1.0 / temperature)
        normaliser = np.sum(probs)
        if not np.isfinite(normaliser) or normaliser <= 0:
            # Nothing can be sampled (all-zero or overflowed mass); normalising would
            # only produce NaNs, and no index would ever pass the SMALL_NUMBER check.
            return set()
        probs = probs / normaliser

        probs_cum = np.cumsum(probs)
        probs_cum[-1] = 1.0  # To protect against floating point oddness
        picked_indices = set()
        # Bound the number of draws: repeated or near-zero-probability draws are
        # discarded, so we may return fewer than num_picks indices.
        remaining_picks = num_picks * 10
        while len(picked_indices) < num_picks and remaining_picks > 0:
            remaining_picks -= 1
            picked_val = random.random()
            # Inverse-CDF sampling: first index whose cumulative mass reaches the draw.
            picked_index = np.argmax(probs_cum >= picked_val)  # type: int
            if picked_index not in picked_indices and probs[picked_index] > SMALL_NUMBER:
                picked_indices.add(picked_index)
        return picked_indices
    else:
        num_samples = min(num_picks, len(probs))
        # argpartition puts the num_samples largest entries in the last slots (unordered).
        top_k_indices = np.argpartition(probs, -num_samples)[-num_samples:]
        top_k_indices = [index for index in top_k_indices if probs[index] > SMALL_NUMBER]
        return top_k_indices