def check_train_data_stats()

in 5-4o_fine_tuning/data_validator.py [0:0]


    def check_train_data_stats(self):
        """
        Log data warnings and token-count statistics for the training file.

        Reports:
        - Missing system/user messages: number of conversations lacking a
          "system" or "user" message (critical for defining the assistant's
          behavior and initiating the conversation).
        - Messages per example: distribution of conversation lengths,
          a proxy for dialogue complexity.
        - Total tokens per example: distribution of per-conversation token
          counts — important for understanding fine-tuning costs.
        - Assistant tokens per example: distribution of tokens in the
          assistant's messages — useful for gauging the assistant's verbosity.

        Raises:
            ValueError: if ``self.train_file`` does not have a ``.jsonl``
                extension.
        """
        data_path = self.train_file
        if not data_path.endswith('.jsonl'):
            raise ValueError(f"Invalid JSONL file: `{data_path}`")

        # JSONL: one JSON-encoded conversation per line. Read as UTF-8
        # explicitly so behavior does not depend on the platform locale.
        with open(data_path, encoding='utf-8') as f:
            dataset = [json.loads(line) for line in f]

        logger.info(f"Checking data stats in {data_path}")

        def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
            # Fixed per-message overhead plus the tokens of every field
            # value; a "name" field costs one extra token. The trailing +3
            # accounts for the reply-priming tokens.
            return sum(
                tokens_per_message
                + sum(
                    len(encoding.encode(value))
                    + (tokens_per_name if key == "name" else 0)
                    for key, value in message.items()
                )
                for message in messages
            ) + 3

        def num_assistant_tokens_from_messages(messages):
            # Count tokens in assistant turns only; other roles are ignored.
            return sum(
                len(encoding.encode(message["content"]))
                for message in messages
                if message["role"] == "assistant"
            )

        def print_distribution(values, name):
            logger.info(f"\n#### Distribution of {name}:")
            logger.info(f"  min / max: {min(values)}, {max(values)}")
            logger.info(
                f"  mean / median: {np.mean(values)}, {np.median(values)}")
            # Fixed: the label says p5/p95 but the quantile arguments were
            # 0.1/0.9 (p10/p90); use 0.05/0.95 so the numbers match the label.
            logger.info(
                f"  p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")

        n_missing_system = sum(1 for ex in dataset if not any(
            message["role"] == "system" for message in ex["messages"]))
        n_missing_user = sum(1 for ex in dataset if not any(
            message["role"] == "user" for message in ex["messages"]))
        n_messages = [len(ex["messages"]) for ex in dataset]
        convo_lens = [num_tokens_from_messages(
            ex["messages"]) for ex in dataset]
        assistant_message_lens = [num_assistant_tokens_from_messages(
            ex["messages"]) for ex in dataset]

        # ANSI escapes (\033[94m blue, \033[92m green, \033[0m reset)
        # highlight the headings in terminal output.
        logger.info(
            f"\n\033[94mNum examples missing system message:\033[0m\n{n_missing_system}")
        logger.info(
            f"\n\033[94mNum examples missing user message:\033[0m\n{n_missing_user}")
        print_distribution(
            n_messages, "\033[92mnum_messages_per_example\033[0m")
        print_distribution(
            convo_lens, "\033[92mnum_total_tokens_per_example\033[0m")
        print_distribution(assistant_message_lens,
                           "\033[92mnum_assistant_tokens_per_example\033[0m")