in agora/cerebral_api/src/llm.py [0:0]
def classify_question(self, question: str, industry: str, role: str) -> str:
    """
    Classify the question into one of the predefined categories.

    The 'classify_question' prompt template is fetched through
    ``self.get_prompt`` for the given industry and role, filled in with the
    category list and the question, and sent as the system message of a chat
    completion request; the question itself is sent as the user message.

    Args:
        question: The question to classify
        industry: The industry context
        role: The role context

    Returns:
        str: One of "data", "relational", "documentation" or "greetings",
        or "unknown" when the model returns no content, returns something
        that is not one of those categories, or any error occurs while
        building the prompt or calling the API.
    """
    request_id = str(uuid.uuid4())
    logger.info(f"[{request_id}] Starting question classification")
    try:
        # Input parameters may contain user text, so only log them in verbose mode
        if VERBOSE:
            logger.debug(f"[{request_id}] Parameters:")
            logger.debug(f"[{request_id}] - Question: {question}")
            logger.debug(f"[{request_id}] - Industry: {industry}")
            logger.debug(f"[{request_id}] - Role: {role}")

        # Define categories
        categories = ["data", "relational", "documentation", "greetings"]
        logger.debug(f"[{request_id}] Available categories: {categories}")

        # Get and format prompt
        prompt_text = self.get_prompt('classify_question', industry, role).format(
            categories=', '.join(categories),
            question=question
        )
        if VERBOSE:
            logger.debug(f"[{request_id}] Generated prompt:")
            logger.debug(f"[{request_id}] {prompt_text}")

        # Make API request using the configured model
        logger.debug(f"[{request_id}] Sending request to OpenAI API")
        conversation = [
            {"role": "system", "content": prompt_text},
            {"role": "user", "content": question}
        ]
        response = self.client.chat.completions.create(
            model=self.CHATGPT_MODEL,
            messages=conversation,
            temperature=0,  # Keep temperature low for consistent classification
            max_tokens=10,  # Short response needed
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["\n"]
        )

        # Process response. The content can be missing (no choices, or a
        # message whose content is None); treat that as an unclassifiable
        # answer instead of letting it surface as an AttributeError below.
        choices = response.choices
        raw_response = choices[0].message.content if choices else None
        if raw_response is None:
            logger.warning(f"[{request_id}] Model returned no content")
            logger.warning(f"[{request_id}] Defaulting to 'unknown'")
            return "unknown"

        # Normalise the answer: drop the end-of-message marker, case, and any
        # surrounding quotes/punctuation so that e.g. '"Data."' still matches.
        cleaned_response = (
            raw_response.replace('<|im_end|>', '')
            .strip()
            .lower()
            .strip("\"'`.,:;! \t")
        )
        if VERBOSE:
            logger.debug(f"[{request_id}] Raw response: {raw_response}")
            logger.debug(f"[{request_id}] Cleaned response: {cleaned_response}")

        # Validate response
        if cleaned_response not in categories:
            logger.warning(f"[{request_id}] Unexpected category: {cleaned_response}")
            logger.warning(f"[{request_id}] Defaulting to 'unknown'")
            return "unknown"

        logger.info(f"[{request_id}] Question classified as: {cleaned_response}")
        return cleaned_response
    except Exception as e:
        # Deliberate best-effort boundary: classification failures must not
        # propagate to the caller, which only ever sees a category string.
        logger.error(f"[{request_id}] Error in classify_question: {str(e)}")
        if VERBOSE:
            # exc_info attaches the active traceback to the debug record
            logger.debug(f"[{request_id}] Full error traceback:", exc_info=True)
        return "unknown"