arctic_inference/dynasor/entropy.py (101 lines of code) (raw):
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Dict, Any, Union, Optional
from arctic_inference.dynasor.evaluator import math_equal
def entropy(Plist: List[float]) -> float:
"""Calculate the Shannon entropy of a probability distribution.
Args:
Plist: List of probabilities that sum to 1
Returns:
float: The entropy value in bits (using log base 2)
"""
if len(Plist):
result = 0
for x in Plist:
result += (-x) * math.log(x, 2)
return result
else:
return 0
def norm(Olist: List[float]) -> List[float]:
"""Normalize a list of numbers to sum to 1.
Args:
Olist: List of numbers to normalize
Returns:
List[float]: Normalized list where sum equals 1
"""
s = sum(Olist)
return [o / s for o in Olist]
def count(Olist: List[Any]) -> List[float]:
"""Count occurrences of each unique element in a list.
Args:
Olist: List of elements to count
Returns:
List[float]: List of counts for each unique element
"""
x_dict = defaultdict(lambda: 0.0)
for x in Olist:
x_dict[x] += 1
cc = [c for _, c in x_dict.items()]
# print(cc)
return cc
def item_entropy(answers: List[Any]) -> float:
"""Calculate the entropy of a list of answers.
Args:
answers: List of answers to calculate entropy for
Returns:
float: Entropy value in bits
"""
return entropy(norm(count(answers)))
def count_not_empty(answers: List[str]) -> int:
"""Count the number of non-empty strings in a list.
Args:
answers: List of strings to check
Returns:
int: Number of non-empty strings
"""
return sum(1 for answer in answers if answer != "")
def equal_group(answers: List[Any]) -> bool:
"""Check if all answers in a list are equivalent.
Args:
answers: List of answers to compare
Returns:
bool: True if all answers are equivalent, False otherwise
"""
equiv_classes = []
for answer in answers:
weight = 1
flag = 0
for i, rep in enumerate(equiv_classes):
if math_equal(answer, rep):
flag = 1
break
if flag:
continue
equiv_classes.append(answer)
return len(equiv_classes) == 1
def majority_voting(answers: List[Any]) -> Any:
"""Find the most common answer using majority voting.
Args:
answers: List of answers to vote on
Returns:
Any: The most common answer
"""
equiv_classes = []
equiv_weights = []
max_vote = 0
for answer in answers:
weight = 1
flag = 0
for i, rep in enumerate(equiv_classes):
if math_equal(answer, rep):
flag = 1
equiv_weights[i] = equiv_weights[i] + weight
if equiv_weights[i] > max_vote:
max_vote = equiv_weights[i]
max_rep = answer
break
if flag:
continue
equiv_classes.append(answer)
equiv_weights.append(weight)
if max_vote == 0:
max_vote = weight
max_rep = answer
return max_rep
def obtain_answer(s: str) -> str:
"""Extract the first complete answer from a string by matching braces.
Args:
s: Input string containing potential answer
Returns:
str: The first complete answer found, or empty string if none found
"""
# Find first unpaired } by counting { and }
stack = []
for i, c in enumerate(s):
if c == "{":
stack.append(c)
elif c == "}":
if not stack: # No matching { found
return s[:i]
stack.pop()
return ""
uncertain_words = ["wait", "hold", "but", "okay", "no", "hmm"]
def is_certain_answer(probe_response_text: str, uncertain_words: List[str]) -> bool:
"""Check if the answer is certain by looking for uncertain words.
Args:
probe_response_text: Text to check for uncertainty
uncertain_words: List of words that indicate uncertainty
Returns:
bool: True if the answer is certain, False otherwise
"""
return not any(word in probe_response_text.lower() for word in uncertain_words)
def has_value(x: Any) -> bool:
"""Check if a value exists and is non-empty.
Args:
x: Value to check
Returns:
bool: True if value exists and is non-empty, False otherwise
"""
if x is None:
return False
if isinstance(x, str):
return len(x.strip()) > 0
if isinstance(x, list):
return len(x) > 0
return True
def should_early_exit(
answers: List[str],
probe_response_text: str,
uncertain_words: List[str],
continue_certain_bar: int,
is_certains: List[bool],
) -> bool:
"""Check if the answer is consistent and certain enough to exit early.
1. Number of answers should be greater than the threshold
2. The probe response text should not contain any uncertain words
3. The answers should be consistent
Args:
answers: List of answers to check
probe_response_text: Text of the probe response
uncertain_words: List of words that indicate uncertainty
continue_certain_bar: Threshold for number of consistent answers needed
is_certains: List of booleans indicating if each answer is certain
Returns:
bool: True if should exit early, False otherwise
"""
# Number of answers should be greater than the threshold
if len(answers) < continue_certain_bar:
return False
# The probe response text should not contain any uncertain words
probe_response_text_lower = probe_response_text.lower()
if any(word in probe_response_text_lower for word in uncertain_words):
return False
# The last answer window should be consistent
answer_candidates = answers[-continue_certain_bar:]
is_certains = is_certains[-continue_certain_bar:]
if equal_group(answer_candidates):
if count_not_empty(answer_candidates) == continue_certain_bar:
if sum(is_certains) == continue_certain_bar:
# logger.debug(f"Early exit on: {answer_candidates = } ({is_certains = })")
return True
return True