in glan-instruct/glan.py [0:0]
def generate_questions(
    class_sessions, key_concepts, subject, level, subtopics, model_name="gpt-4o",
    num_iterations=2, num_questions_per_iteration=5, max_tokens=2048, batch_size=4, language="Korean", **kwargs
):
    """
    Generate homework questions from class sessions and key concepts using a
    LangChain pipeline. Please refer to section 2.4 of the paper.

    Args:
        class_sessions: Candidate class sessions, sampled each iteration.
        key_concepts: Candidate key concepts, sampled each iteration.
        subject: Attached to every generated question as metadata.
        level: Attached to every generated question as metadata.
        subtopics: Attached to every generated question as metadata.
        model_name: Azure OpenAI deployment name for AzureChatOpenAI.
        num_iterations: Number of sampling/generation iterations.
        num_questions_per_iteration: Questions requested per iteration.
        max_tokens: LLM response cap; also interpolated into the prompt.
        batch_size: Minibatch size / max concurrency for chain.batch calls.
        language: Language the questions are written in.
        **kwargs: Extra keyword arguments forwarded to AzureChatOpenAI.

    Returns:
        list[dict]: One dict per question: {"question": <text>} merged with
        the subject/level/subtopics metadata. Failed batches are skipped.
    """
    from langchain_openai import AzureChatOpenAI
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import BaseOutputParser

    llm = AzureChatOpenAI(
        max_tokens=max_tokens,
        openai_api_version="2024-09-01-preview",
        azure_deployment=model_name,
        **kwargs
    )

    class CustomOutputParser(BaseOutputParser):
        # Wrap the raw LLM text in the dict shape consumed downstream.
        def parse(self, text: str):
            cleaned_text = text.strip()
            return {"question": cleaned_text}

    prompt = PromptTemplate.from_template(
        """Based on the class session(s) {selected_class_sessions} and key concepts {selected_key_concepts}, generate a homework question.
A question must be less than {max_tokens} tokens.
Write in {language}.
"""
    )
    chain = prompt | llm | CustomOutputParser()

    questions = []
    for idx in range(num_iterations):
        t0 = time.time()
        logger.info(f"\t\t===== Generating Questions: Iteration {idx}")
        selected_class_sessions, selected_key_concepts = sample_class_sessions_and_key_concepts(
            class_sessions, key_concepts, single_session=True
        )
        # Same sampled session/concepts repeated for every question this iteration.
        batch_inputs = [{
            "selected_class_sessions": selected_class_sessions,
            "selected_key_concepts": selected_key_concepts,
            "max_tokens": max_tokens,
            "language": language
        } for _ in range(num_questions_per_iteration)]
        metadata = {"subject": subject, "level": level, "subtopics": subtopics}

        with tqdm(total=len(batch_inputs), desc="\t\tProcessing Questions") as pbar:
            for i in range(0, len(batch_inputs), batch_size):
                minibatch = batch_inputs[i:i+batch_size]
                # BUGFIX: reset per minibatch. Previously `questions_` was only
                # assigned inside the try: a failure on the first batch raised
                # NameError below, and a failure on a later batch re-appended
                # the previous batch's stale results (duplicates).
                questions_ = []
                retries = 0
                while retries <= MAX_RETRIES:
                    try:
                        questions_ = chain.batch(minibatch, {"max_concurrency": batch_size})
                        break  # Exit the retry loop once successful
                    except RateLimitError as rate_limit_error:
                        # Linear backoff: wait longer after each rate-limit hit.
                        delay = (retries + 1) * DELAY_INCREMENT
                        logger.warning(f"{rate_limit_error}. Retrying in {delay} seconds...")
                        time.sleep(delay)
                        retries += 1
                        if retries > MAX_RETRIES:
                            logger.error("Max retries reached this batch. Skipping to next batch.")
                            break
                    except Exception as e:
                        # Non-rate-limit errors are not retried; skip this batch.
                        logger.error(f"Error in process_inputs: {e}")
                        break
                for q in questions_:
                    q.update(metadata)
                questions.extend(questions_)
                pbar.set_postfix({"current_batch": f"{i//batch_size + 1}/{(len(batch_inputs) + (batch_size-1))//batch_size}"})
                pbar.update(len(minibatch))
        t1 = time.time()
        logger.info(f"\t\tIteration {idx} took {t1 - t0:.4f} seconds.")
    return questions