in glan-instruct/glan.py [0:0]
def generate_answers(all_questions, model_name="gpt-4o", max_tokens=1024, batch_size=5, **kwargs):
    """
    Generate answers to the questions using LangChain pipeline. Please refer to section 2.4 of the paper.

    Args:
        all_questions: List of prompt-variable dicts; each must supply the keys the
            prompt templates reference: "subject", "level", "subtopics", "question".
        model_name: Azure OpenAI deployment name to invoke.
        max_tokens: Completion token cap, also embedded in the system prompt.
        batch_size: Number of questions sent per chain.batch call (and its max_concurrency).
        **kwargs: Unused; accepted for call-site compatibility.

    Returns:
        List of answer strings. Batches that fail after MAX_RETRIES rate-limit
        retries, or on any other exception, contribute no entries (so the result
        may be shorter than all_questions).
    """
    from langchain.schema.output_parser import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
    from langchain_openai import AzureChatOpenAI

    llm = AzureChatOpenAI(
        temperature=0,
        max_tokens=max_tokens,
        openai_api_version="2024-09-01-preview",
        azure_deployment=model_name,
    )

    system_prompt = """Answer the question. Keep the answer short and concise. The topic, level, and subtopic of this question are as follows.
## Subject: {subject}
## Level: {level}
## Subtopics: {subtopics}
Respond "DO NOT KNOW" if not sure about the answer.
"""
    # Bake the length constraint into the prompt text now; {max_tokens} must not
    # remain a template variable, so use an f-string here.
    system_prompt += f"Answer must be less than {max_tokens} token length."
    system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)

    human_prompt = [
        {
            "type": "text",
            "text": "{question}"
        },
    ]
    human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)

    prompt = ChatPromptTemplate.from_messages(
        [
            system_message_template,
            human_message_template
        ]
    )
    chain = prompt | llm | StrOutputParser()

    logger.info(f"===== Generating Answers")
    t0 = time.time()
    all_answers = []
    # Loop-invariant: total number of batches, used only for the progress postfix.
    num_batches = (len(all_questions) + (batch_size - 1)) // batch_size

    with tqdm(total=len(all_questions), desc="Processing Answers") as pbar:
        for i in range(0, len(all_questions), batch_size):
            minibatch = all_questions[i:i+batch_size]
            # BUG FIX: reset per batch. Previously `answers` kept the prior
            # batch's results, so a failed batch re-extended stale answers
            # (duplicates, misaligned with questions) — or raised NameError
            # if the very first batch failed.
            answers = []
            retries = 0
            while retries <= MAX_RETRIES:
                try:
                    answers = chain.batch(minibatch, {"max_concurrency": batch_size})
                    break  # Exit the retry loop once successful
                except RateLimitError as rate_limit_error:
                    # Linear backoff: wait longer after each consecutive 429.
                    delay = (retries + 1) * DELAY_INCREMENT
                    logger.warning(f"{rate_limit_error}. Retrying in {delay} seconds...")
                    time.sleep(delay)
                    retries += 1

                    if retries > MAX_RETRIES:
                        logger.error(f"Max retries reached this batch. Skipping to next batch.")
                        break
                except Exception as e:
                    # Best-effort: any other error skips this batch rather than
                    # aborting the whole run.
                    logger.error(f"Error in process_inputs: {e}")
                    break

            all_answers.extend(answers)
            pbar.set_postfix({"current_batch": f"{i//batch_size + 1}/{num_batches}"})
            pbar.update(len(minibatch))

    t1 = time.time()
    timespan = format_timespan(t1 - t0)
    logger.info(f"Generating Answer dataset took {timespan}")
    return all_answers