seed/util/qa.py (39 lines of code) (raw):

import os
from typing import Dict, List, Tuple, Any, Union, Optional

from azure.ai.generative.synthetic.qa import QADataGenerator, QAType
from azure.ai.generative._telemetry import ActivityType, monitor_with_activity, ActivityLogger

activity_logger = ActivityLogger(__name__)
logger = activity_logger.package_logger
module_logger = activity_logger.module_logger


class CustomQADataGenerator(QADataGenerator):
    """QA data generator that reads its prompt templates from a caller-supplied directory.

    NOTE(review): ``_get_messages_from_file`` is not defined in this file; it is
    expected to come from ``QADataGenerator`` -- confirm against the base class.
    """

    def __init__(self, templates_dir: str, **kwargs):
        """Store the template directory, then initialise the base generator.

        Args:
            templates_dir: Directory that holds the prompt template files.
            **kwargs: Forwarded unchanged to ``QADataGenerator.__init__``.
        """
        # Assigned before the base initialiser runs, so the attribute already
        # exists if the base class touches templates during construction.
        self.templates_dir = templates_dir
        super().__init__(**kwargs)

    def _get_template(self, filename) -> str:
        """Return the UTF-8 text of ``filename`` located in ``self.templates_dir``.

        Raises:
            OSError: Propagated from ``open`` when the file cannot be read.
        """
        logger.debug("Getting prompt template from %s file", filename)
        template_path = os.path.join(self.templates_dir, filename)
        with open(template_path, encoding="utf-8") as template_file:
            return template_file.read()

    def _get_messages_for_qa_type(self, qa_type: QAType, text: str, num_questions: int) -> List:
        """Build the prompt messages for one QA type.

        The messages are loaded from the template file mapped to ``qa_type`` and
        the ``content`` of the last message is filled in with ``str.format``.

        Args:
            qa_type: QA type that selects the template file.
            text: Source text substituted for the ``text`` placeholder.
            num_questions: Substituted for ``num_questions``; not passed to the
                template when ``qa_type`` is ``QAType.SUMMARY``.

        Returns:
            The message list with its final entry's ``content`` formatted.

        Raises:
            KeyError: If ``qa_type`` has no entry in the template mapping.
        """
        logger.debug("Getting prompt messages for %s QA type", qa_type)
        filenames_by_type = {
            QAType.SHORT_ANSWER: "prompt_qa_short_answer.txt",
            QAType.LONG_ANSWER: "prompt_qa_long_answer.txt",
            QAType.BOOLEAN: "prompt_qa_boolean.txt",
            QAType.SUMMARY: "prompt_qa_summary.txt",
            QAType.CONVERSATION: "prompt_qa_conversation.txt",
        }
        messages = self._get_messages_from_file(filenames_by_type[qa_type])

        format_args: Dict[str, Any] = {"text": text}
        if qa_type == QAType.SUMMARY:
            # Summary prompts get a fixed word budget instead of a question count.
            format_args["num_words"] = 100
        else:
            format_args["num_questions"] = num_questions

        # Only the final message carries placeholders to substitute.
        final_message = messages[-1]
        final_message["content"] = final_message["content"].format(**format_args)
        return messages

    def _get_messages_for_modify_conversation(self, questions: List[str]) -> List:
        """Build the prompt messages used to modify a conversation.

        Args:
            questions: Questions rendered one per line, each prefixed with
                ``[Q]: ``, and substituted for the ``questions`` placeholder.

        Returns:
            The message list with its final entry's ``content`` formatted.
        """
        messages = self._get_messages_from_file("prompt_qa_conversation_modify.txt")
        rendered_questions = "\n".join(f"[Q]: {question}" for question in questions)
        final_message = messages[-1]
        final_message["content"] = final_message["content"].format(questions=rendered_questions)
        return messages