in 5-4o_fine_tuning/data_validator.py [0:0]
def check_train_data_stats(self):
    """Log data warnings and token-count statistics for the training file.

    Reads ``self.train_file`` (a ``.jsonl`` file of chat-format examples,
    each with a ``"messages"`` list) and logs:

    - Missing system/user messages: counts conversations lacking a
      ``"system"`` or ``"user"`` role message.
    - Number of messages per example: distribution of conversation length,
      an indicator of dialogue complexity.
    - Total tokens per example: distribution of total token counts, useful
      for estimating fine-tuning cost.
    - Assistant tokens per example: distribution of tokens in assistant
      messages only, useful for gauging assistant verbosity.

    Raises:
        ValueError: if ``self.train_file`` does not end in ``.jsonl`` or a
            line is not valid JSON.
    """
    data_path = self.train_file
    if not data_path.endswith('.jsonl'):
        raise ValueError(f"Invalid JSONL file: `{data_path}`")
    # JSONL is UTF-8 by convention; be explicit so the platform default
    # encoding cannot corrupt the read.
    with open(data_path, encoding="utf-8") as f:
        dataset = []
        for lineno, line in enumerate(f, start=1):
            try:
                dataset.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Surface the offending line number instead of an opaque failure.
                raise ValueError(
                    f"Invalid JSON on line {lineno} of `{data_path}`") from err
    logger.info(f"Checking data stats in {data_path}")

    def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
        # Per-message overhead (tokens_per_message) plus content tokens,
        # plus tokens_per_name when a "name" field is present; the trailing
        # +3 accounts for the assistant-reply priming tokens.
        return sum(
            tokens_per_message
            + sum(
                len(encoding.encode(value))
                + (tokens_per_name if key == "name" else 0)
                for key, value in message.items()
            )
            for message in messages
        ) + 3

    def num_assistant_tokens_from_messages(messages):
        # Tokens in assistant-authored content only.
        return sum(
            len(encoding.encode(message["content"]))
            for message in messages
            if message["role"] == "assistant"
        )

    def print_distribution(values, name):
        logger.info(f"\n#### Distribution of {name}:")
        logger.info(f"  min / max: {min(values)}, {max(values)}")
        logger.info(
            f"  mean / median: {np.mean(values)}, {np.median(values)}")
        # Fixed: label said p5/p95 but the original computed the 10th/90th
        # percentiles; use 0.05/0.95 so the numbers match the label.
        logger.info(
            f"  p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")

    n_missing_system = sum(
        1 for ex in dataset
        if not any(message["role"] == "system" for message in ex["messages"]))
    n_missing_user = sum(
        1 for ex in dataset
        if not any(message["role"] == "user" for message in ex["messages"]))
    n_messages = [len(ex["messages"]) for ex in dataset]
    convo_lens = [num_tokens_from_messages(ex["messages"]) for ex in dataset]
    assistant_message_lens = [
        num_assistant_tokens_from_messages(ex["messages"]) for ex in dataset]

    logger.info(
        f"\n\033[94mNum examples missing system message:\033[0m\n{n_missing_system}")
    logger.info(
        f"\n\033[94mNum examples missing user message:\033[0m\n{n_missing_user}")
    print_distribution(
        n_messages, "\033[92mnum_messages_per_example\033[0m")
    print_distribution(
        convo_lens, "\033[92mnum_total_tokens_per_example\033[0m")
    print_distribution(assistant_message_lens,
                       "\033[92mnum_assistant_tokens_per_example\033[0m")