def check_data_format_errors()

in 5-4o_fine_tuning/data_validator.py [0:0]


    def check_data_format_errors(self) -> dict:
        """
        Validate the chat-format JSONL training file (and validation file, if set).

        Each file is read as UTF-8, one JSON document per line, and every example
        is checked. Error counts per error type are logged for each file.

        Format validation checks:
        - Empty Dataset: Flags a file that contains no examples at all. Error type: empty_dataset.
        - Data Type Check: Checks whether each entry in the dataset is a dictionary (dict). Error type: data_type.
        - Presence of Message List: Checks that each entry has a non-empty `messages` list. Error type: missing_messages_list.
        - Message Data Type: Checks that each item of `messages` is a dictionary. Error type: message_data_type.
        - Message Keys Check: Validates that each message contains the keys role and content. Error type: message_missing_key.
        - Unrecognized Keys in Messages: Logs if a message has keys other than role, content, name, function_call, and weight. Error type: message_unrecognized_key.
        - Role Validation: Ensures the role is one of "system", "user", "assistant", or "function". Error type: unrecognized_role.
        - Content Validation: Verifies that content is a non-empty string; a message carrying a
          function_call may instead have content None (or any string). Error type: missing_content.
        - Assistant Message Presence: Checks that each conversation has at least one message from the assistant. Error type: example_missing_assistant_message.

        Returns:
            dict: Maps each checked file path to True if any format error was found
            in that file, otherwise False.

        Raises:
            ValueError: If a file path does not end with `.jsonl`, or (as
                json.JSONDecodeError, a ValueError subclass) if a line is not valid JSON.

        Example:
            errors = validator.check_data_format_errors()
        """
        # Tuples rather than sets: membership against a tuple compares with ==, so an
        # unhashable value (e.g. a list-valued "role") is reported instead of raising.
        allowed_keys = ("role", "content", "name", "function_call", "weight")
        allowed_roles = ("system", "user", "assistant", "function")

        files_to_check = [self.train_file]
        if self.validation_file:
            files_to_check.append(self.validation_file)

        data_format_errors = {file: False for file in files_to_check}

        for data_path in files_to_check:
            # str() so that pathlib.Path inputs are accepted as well as plain strings.
            if not str(data_path).endswith('.jsonl'):
                raise ValueError(
                    f"The provided dataset path `{data_path}` is not a valid JSONL file.")

            # Load dataset. JSONL is UTF-8, so don't rely on the platform default encoding.
            with open(data_path, encoding="utf-8") as f:
                dataset = [json.loads(line) for line in f]

            logger.info(
                f"Checking format errors in {data_path}")

            # Initial dataset stats. This is only a preview, so it must not crash on
            # malformed data; the real checks below report the problems.
            logger.info(f"Number of examples: {len(dataset)}")
            if not dataset:
                logger.error("\033[91mNo examples found.\033[0m")
            else:
                first_example = dataset[0]
                first_messages = (first_example.get("messages")
                                  if isinstance(first_example, dict) else None)
                if isinstance(first_messages, list):
                    logger.info("First example:")
                    for message in first_messages:
                        logger.info(message)
                else:
                    logger.error(
                        "\033[91mNo messages found in the first example.\033[0m")

            # Format error checks
            format_errors = defaultdict(int)
            if not dataset:
                # An empty file would otherwise pass every per-example check vacuously.
                format_errors["empty_dataset"] += 1

            for ex in dataset:
                if not isinstance(ex, dict):
                    format_errors["data_type"] += 1
                    continue

                messages = ex.get("messages", None)
                # A non-list value (e.g. a string or dict) is not a message list.
                if not messages or not isinstance(messages, list):
                    format_errors["missing_messages_list"] += 1
                    continue

                for message in messages:
                    if not isinstance(message, dict):
                        format_errors["message_data_type"] += 1
                        continue

                    if "role" not in message or "content" not in message:
                        format_errors["message_missing_key"] += 1

                    if any(k not in allowed_keys for k in message):
                        format_errors["message_unrecognized_key"] += 1

                    if message.get("role", None) not in allowed_roles:
                        format_errors["unrecognized_role"] += 1

                    content = message.get("content", None)
                    if message.get("function_call", None):
                        # Function-call messages may omit text, but any content must be a string.
                        content_ok = content is None or isinstance(content, str)
                    else:
                        content_ok = isinstance(content, str) and content != ""
                    if not content_ok:
                        format_errors["missing_content"] += 1

                if not any(isinstance(message, dict) and message.get("role", None) == "assistant"
                           for message in messages):
                    format_errors["example_missing_assistant_message"] += 1

            if format_errors:
                data_format_errors[data_path] = True
                logger.error(f"\033[91mFound errors in {data_path}:\033[0m")
                for k, v in format_errors.items():
                    logger.error(f"\033[91m  {k}: {v}\033[0m")
            else:
                logger.info(f"\033[92mNo errors found for {data_path}\033[0m")
            logger.info("-----------------------------------")
        return data_format_errors