aoai/token_count_utils.py (244 lines of code) (raw):
import json
import os
import re
from collections import defaultdict
from typing import Dict, List, Tuple, Union

import numpy as np
import tiktoken

#logger = logging.getLogger(__name__)
from logger import logger
def validate_json(parsed_data) -> Tuple[bool, str]:
    """
    Validate one parsed chat example to ensure it has the required keys and values.

    Args:
        parsed_data: (dict): Parsed JSON data.
    Returns:
        tuple: ``(is_valid, key)``. ``key`` is ``"passed"`` when the example is
        valid; otherwise it is a short identifier of the first problem found
        (e.g. ``"missing_messages_list"``, ``"content_empty"``). Any unexpected
        exception raised while checking is logged and reported as
        ``"error_during_validation"``.
    """
    try:
        # Check if 'messages' key exists and is a list
        if "messages" not in parsed_data or not isinstance(parsed_data["messages"], list):
            logger.warning("Invalid JSON: 'messages' key is missing or not a list.")
            return False, "missing_messages_list"

        # Check if each message has the required keys according to its 'role'
        for message in parsed_data["messages"]:
            if not isinstance(message, dict):
                logger.warning(f"Invalid JSON: Each message should be a dictionary. Found: {type(message)}")
                return False, "missing_message_dict"

            # Check if 'role' key exists and is of type string
            if "role" not in message or not isinstance(message["role"], str):
                logger.warning("Invalid JSON: Each message should contain a 'role' key of type string.")
                return False, "missing_role_key"

            role = message["role"]
            if role in ("system", "user"):
                # 'system' and 'user' roles must carry non-blank string 'content'
                content = message.get("content")
                if not isinstance(content, str):
                    logger.warning(f"Invalid JSON: '{role}' role must have a 'content' key of type string.")
                    return False, "content_key_missing"
                if not content.strip():
                    logger.warning(f"Invalid JSON: '{role}' role 'content' cannot be empty or only whitespace.")
                    return False, "content_empty"
            elif role == "assistant":
                # The 'assistant' role must have at least one of 'content' or 'tool_calls'
                content = message.get("content")
                if not isinstance(content, str) and "tool_calls" not in message:
                    logger.warning("Invalid JSON: 'assistant' role must have either 'content' or 'tool_calls'.")
                    return False, "content_or_tool_calls_missing"
                if isinstance(content, str):
                    if not content.strip():
                        logger.warning("Invalid JSON: 'assistant' role 'content' cannot be empty or only whitespace.")
                        return False, "content_empty"
                elif content is not None:
                    # A null 'content' next to 'tool_calls' is accepted; any
                    # other non-string value is rejected with the same key the
                    # system/user branches use for a non-string 'content'.
                    logger.warning("Invalid JSON: 'assistant' role 'content' must be a string if provided.")
                    return False, "content_key_missing"
                if "tool_calls" in message and not isinstance(message["tool_calls"], list):
                    logger.warning("Invalid JSON: 'tool_calls' must be a list if provided.")
                    return False, "tool_calls_not_list"
            elif role == "tool":
                # 'tool' role must have a 'tool_call_id'
                if "tool_call_id" not in message or not isinstance(message["tool_call_id"], str):
                    logger.warning("Invalid JSON: 'tool' role must have a 'tool_call_id' key of type string.")
                    return False, "tool_call_id_missing"
            else:
                logger.warning(f"Invalid JSON: Unknown role '{role}'.")
                return False, "unknown_role"

        # Validate 'tools' key (only its container type is checked)
        if "tools" in parsed_data and not isinstance(parsed_data["tools"], list):
            logger.warning("Invalid JSON: 'tools' key must be a list if present.")
            return False, "tools_not_list"

        return True, "passed"
    except Exception as e:
        # Best-effort validator: never propagate, report a generic failure key.
        logger.warning(f"Error during validation: {e}")
        return False, "error_during_validation"
def validate_jsonl(jsonl_files):
    """
    Validate every line of each JSONL file and log a per-file error summary.

    Each line is decoded with ``json.loads`` and the decoded record is checked
    with ``validate_json``. Lines that cannot be decoded are counted under the
    ``"invalid_json"`` key so that a file containing broken lines is never
    reported as fully valid.

    Args:
        jsonl_files: Iterable of paths to JSONL files (opened as UTF-8 text).
    Returns:
        None. Results are written to the module logger only.
    """
    for jsonl_path in jsonl_files:
        # Format error checks
        format_errors = defaultdict(int)
        # Keep the 1-based source line number next to every decoded record so
        # later messages point at the real line even if earlier lines failed.
        dataset = []
        logger.info('*' * 50)
        logger.info(f"### [JSONL_VALIDATION] Processing file: {jsonl_path}")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                try:
                    dataset.append((line_no, json.loads(line)))
                except json.JSONDecodeError as e:
                    logger.warning(f"Line {line_no}: Invalid JSON format - {e}")
                    format_errors["invalid_json"] += 1
                except Exception as e:
                    logger.warning(f"Line {line_no}: Unexpected error - {e}")
                    format_errors["invalid_json"] += 1

        for line_no, data in dataset:
            is_valid, error_key = validate_json(data)
            if not is_valid:
                logger.warning(f"Validation failed for line {line_no}")
                format_errors[error_key] += 1

        if format_errors:
            for k, v in format_errors.items():
                logger.info(f"{k}: {v}")
        else:
            logger.info(f"{jsonl_path}: All examples are valid")
        logger.info('*' * 50)
def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
    """Look up the token limit recorded for ``model``.

    Azure-style prefixes are rewritten first: a leading ``gpt35`` / ``gpt-35``
    becomes ``gpt-3.5`` and a leading ``gpt4`` becomes ``gpt-4``.

    Args:
        model: (str): Model name or Azure alias.
    Returns:
        int: Token limit stored in the table below for that model.
    Raises:
        KeyError: If the (normalized) name is not in the table.
    """
    # Handle common azure model names/aliases
    alias_rewrites = (
        (r"^gpt\-?35", "gpt-3.5"),
        (r"^gpt4", "gpt-4"),
    )
    for pattern, replacement in alias_rewrites:
        model = re.sub(pattern, replacement, model)

    names_by_limit = {
        4096: (
            "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-instruct",
        ),
        16385: (
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-0125",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo-16k-0613",
            "gpt-3.5-turbo-1106",
        ),
        8192: (
            "gpt-4",
            "gpt-4-0314",  # deprecate in Sep
            "gpt-4-0613",
        ),
        32768: (
            "gpt-4-32k",
            "gpt-4-32k-0314",  # deprecate in Sep
            "gpt-4-32k-0613",
        ),
        128000: (
            "gpt-4-turbo",
            "gpt-4-turbo-2024-04-09",
            "gpt-4-1106-preview",
            "gpt-4-0125-preview",
            "gpt-4-turbo-preview",
            "gpt-4-vision-preview",
            "gpt-4o",
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4o-mini",
            "gpt-4o-mini-2024-07-18",
        ),
    }
    limit_by_name = {
        name: limit
        for limit, names in names_by_limit.items()
        for name in names
    }
    return limit_by_name[model]
def percentile_used(input, model="gpt-3.5-turbo-0613"):
    """Return the share of the model's token limit consumed by ``input``.

    Despite the name, the result is a plain ratio (token count divided by the
    limit from ``get_max_token_limit``), not a value scaled to 0-100.

    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name; used both for counting and for the limit.
    Returns:
        float: ``count_token(input, model) / get_max_token_limit(model)``.
    """
    # Count with the same model whose limit is used as the denominator.
    return count_token(input, model=model) / get_max_token_limit(model)
def token_left(input: Union[str, List, Dict], model="gpt-3.5-turbo-0613") -> int:
    """Compute how many tokens remain for an OpenAI model after ``input``.

    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name.
    Returns:
        int: The model's token limit minus the token count of ``input``.
    """
    limit = get_max_token_limit(model)
    consumed = count_token(input, model=model)
    return limit - consumed
def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613") -> int:
    """Count the tokens an OpenAI model would see for ``input``.

    Strings are counted as raw text; lists and dicts are counted as chat
    messages.

    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name.
    Returns:
        int: Number of tokens from the input.
    Raises:
        ValueError: If ``input`` is not a str, list or dict.
    """
    # Pick the counter first, then make a single call at the end.
    if isinstance(input, str):
        counter = _num_token_from_text
    elif isinstance(input, (list, dict)):
        counter = _num_token_from_messages
    else:
        raise ValueError(f"input must be str, list or dict, but we got {type(input)}")
    return counter(input, model=model)
def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
    """Return how many tokens ``text`` encodes to for ``model``.

    If tiktoken raises ``KeyError`` for the model name, a warning is logged
    and the ``cl100k_base`` encoding is used instead.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
        tokenizer = tiktoken.get_encoding("cl100k_base")
    token_ids = tokenizer.encode(text)
    return len(token_ids)
def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of messages.

    A single message dict is accepted and treated as a one-element list.
    Adapted from the OpenAI cookbook notebook "How to count tokens with
    tiktoken".
    """
    if isinstance(messages, dict):
        messages = [messages]

    encoding_name = "o200k_base" if "gpt-4o" in model else "cl100k_base"
    encoding = tiktoken.get_encoding(encoding_name)

    if model == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        # and, if there's a name, the role is omitted
        tokens_per_message, tokens_per_name = 4, -1
    else:
        tokens_per_message, tokens_per_name = 3, 1

    # every reply is primed with <|start|>assistant<|message|>
    total = 3
    for message in messages:
        total += tokens_per_message
        for key, value in message.items():
            if value is None:
                continue
            if isinstance(value, str):
                text = value
            else:
                # non-string values (e.g. function calls) are counted as JSON
                try:
                    text = json.dumps(value)
                except TypeError:
                    logger.warning(
                        f"Value {value} is not a string and cannot be converted to json. It is a type: {type(value)} Skipping."
                    )
                    continue
            total += len(encoding.encode(text))
            if key == "name":
                total += tokens_per_name
    return total
def num_assistant_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") -> int:
    """Return the number of tokens in the string ``content`` of assistant messages.

    Only messages whose ``role`` is ``"assistant"`` contribute. Assistant
    messages without string content (missing, or ``None`` as on turns that
    only carry ``tool_calls``) contribute zero instead of raising.

    Args:
        messages: (list, dict): Chat messages; a single message dict is
            treated as a one-element list.
        model: (str): Model name; names containing ``"gpt-4o"`` use the
            ``o200k_base`` encoding, everything else ``cl100k_base``.
    Returns:
        int: Total encoded length of the assistant contents.
    """
    if isinstance(messages, dict):
        messages = [messages]

    if "gpt-4o" in model:
        encoding = tiktoken.get_encoding("o200k_base")
    else:
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for message in messages:
        if message.get("role") != "assistant":
            continue
        content = message.get("content")
        # encode() only accepts text; skip null/absent content.
        if isinstance(content, str):
            num_tokens += len(encoding.encode(content))
    return num_tokens
def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
    """Return the number of tokens used by a list of functions.

    Each entry may be a bare function description or a wrapper dict holding it
    under the ``"function"`` key. A missing or null ``description`` counts as
    empty text, and non-string values (e.g. numeric enum members) are measured
    on their JSON form, so such schemas no longer raise.

    NOTE(review): the fixed offsets (-2, +2, -3/+3, +11, +12) are kept exactly
    as found; they look like an empirical estimate of how function schemas are
    rendered into the prompt - confirm against the upstream source.

    Args:
        functions: (list): List of function descriptions that will be passed in model.
        model: (str): Model name.
    Returns:
        int: Number of tokens from the function descriptions.
    """
    if "gpt-4o" in model:
        encoding = tiktoken.get_encoding("o200k_base")
    else:
        encoding = tiktoken.get_encoding("cl100k_base")

    def _encoded_len(value) -> int:
        # encode() needs text; fall back to the JSON form for other values.
        text = value if isinstance(value, str) else json.dumps(value)
        return len(encoding.encode(text))

    num_tokens = 0
    for f in functions:
        function = f["function"] if "function" in f else f

        function_tokens = _encoded_len(function["name"])
        function_tokens += _encoded_len(function.get("description") or "")
        function_tokens -= 2

        parameters = function.get("parameters") or {}
        if "properties" in parameters:
            properties = parameters["properties"]
            for properties_key, v in properties.items():
                function_tokens += _encoded_len(properties_key)
                for field in v:
                    if field == "type":
                        function_tokens += 2
                        function_tokens += _encoded_len(v["type"])
                    elif field == "description":
                        function_tokens += 2
                        function_tokens += _encoded_len(v["description"])
                    elif field == "enum":
                        function_tokens -= 3
                        for o in v["enum"]:
                            function_tokens += 3
                            function_tokens += _encoded_len(o)
                    else:
                        logger.warning(f"Not supported field {field}")
            function_tokens += 11
            if len(properties) == 0:
                function_tokens -= 2

        num_tokens += function_tokens

    num_tokens += 12
    return num_tokens
def print_distribution(values, name):
    """Log summary statistics of ``values`` under the label ``name``.

    Logs min/max, mean/median and the 5th/95th percentiles. Nothing is logged
    when ``values`` is empty.

    Args:
        values: Sequence of numbers (anything ``len``, ``min``/``max`` and
            NumPy reductions accept).
        name: (str): Label used in the heading line.
    Returns:
        None.
    """
    if len(values) > 0:
        logger.info(f"### Distribution of {name}:")
        logger.info(f"min / max: {min(values)}, {max(values)}")
        logger.info(f"mean / median: {np.mean(values)}, {np.median(values)}")
        # Quantiles 0.05 / 0.95 so the numbers match the "p5 / p95" label.
        logger.info(f"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")
def print_stats_tokens(jsonl_files, model="gpt-4o-2024-05-13"):
    """Log token-count distributions for every JSONL file in ``jsonl_files``.

    For each example the total message tokens, the assistant-content tokens
    and, when the example declares any ``tools``, the function tokens are
    collected; the three distributions are then logged per file.

    Args:
        jsonl_files: Iterable of paths to JSONL files (opened as UTF-8 text).
        model: (str): Model name forwarded to the token counters.
    Returns:
        None. Results are written to the module logger only.
    """
    for jsonl_path in jsonl_files:
        logger.info('*' * 50)
        logger.info(f"### [TOKEN_STATS] Processing file: {jsonl_path}")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            # Blank lines carry no example; skip them instead of failing to decode.
            dataset = [json.loads(line) for line in f if line.strip()]

        total_tokens = []
        assistant_tokens = []
        function_tokens = []
        for ex in dataset:
            messages = ex.get("messages", {})
            functions = ex.get("tools")
            total_tokens.append(count_token(messages, model))
            assistant_tokens.append(num_assistant_tokens_from_messages(messages, model))
            # Any non-empty tool list counts, including a single tool.
            if functions:
                function_tokens.append(num_tokens_from_functions(functions, model))

        print_distribution(total_tokens, "total tokens")
        print_distribution(function_tokens, "function tokens")
        print_distribution(assistant_tokens, "assistant tokens")
        logger.info('*' * 50)