`generate_answers()`

defined in glan-instruct/glan.py


def generate_answers(all_questions, model_name="gpt-4o", max_tokens=1024, batch_size=5, **kwargs):
    """
    Generate answers to the questions using LangChain pipeline. Please refer to section 2.4 of the paper.

    Args:
        all_questions: List of prompt-variable dicts; each must supply the keys
            ``subject``, ``level``, ``subtopics`` and ``question`` used by the
            prompt templates below.
        model_name: Azure OpenAI deployment name to route the request to.
        max_tokens: Per-answer token budget; also interpolated into the system
            prompt as an instruction to the model.
        batch_size: Number of questions per ``chain.batch`` call (also used as
            the max concurrency for that call).
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        List of answer strings in question order. A batch that fails after
        ``MAX_RETRIES`` rate-limit retries, or on any other exception, is
        skipped, so the result may be shorter than ``all_questions``.
    """
    from langchain.schema.output_parser import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
    from langchain_openai import AzureChatOpenAI

    llm = AzureChatOpenAI(
        temperature=0,
        max_tokens=max_tokens,
        openai_api_version="2024-09-01-preview",
        azure_deployment=model_name,
    )

    system_prompt = """Answer the question. Keep the answer short and concise. The topic, level, and subtopic of this question are as follows.

    ## Subject: {subject}
    ## Level: {level}
    ## Subtopics: {subtopics}

    Respond "DO NOT KNOW" if not sure about the answer.
    """
    system_prompt += f"Answer must be less than {max_tokens} token length."

    system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)
    human_prompt = [
        {
            "type": "text",
            "text": "{question}"
        },
    ]
    human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)

    prompt = ChatPromptTemplate.from_messages(
        [
            system_message_template,
            human_message_template
        ]
    )

    chain = prompt | llm | StrOutputParser()
    logger.info("===== Generating Answers")
    t0 = time.time()
    all_answers = []
    # Hoisted loop invariant: total number of batches for the progress display.
    total_batches = (len(all_questions) + batch_size - 1) // batch_size

    with tqdm(total=len(all_questions), desc="Processing Answers") as pbar:
        for i in range(0, len(all_questions), batch_size):
            minibatch = all_questions[i:i+batch_size]

            # BUGFIX: reset per batch. Previously `answers` was only assigned
            # inside the try-block, so a batch that failed outright (or ran out
            # of retries) raised UnboundLocalError on the first batch or
            # silently re-appended the PREVIOUS batch's answers on later ones.
            answers = []
            retries = 0
            while retries <= MAX_RETRIES:
                try:
                    answers = chain.batch(minibatch, {"max_concurrency": batch_size})
                    break  # Exit the retry loop once successful
                except RateLimitError as rate_limit_error:
                    # Linear backoff: DELAY_INCREMENT, 2*DELAY_INCREMENT, ...
                    delay = (retries + 1) * DELAY_INCREMENT
                    logger.warning(f"{rate_limit_error}. Retrying in {delay} seconds...")
                    time.sleep(delay)
                    retries += 1

                    if retries > MAX_RETRIES:
                        logger.error("Max retries reached this batch. Skipping to next batch.")
                        break
                except Exception as e:
                    # Non-rate-limit errors are not retried; skip this batch.
                    logger.error(f"Error in process_inputs: {e}")
                    break

            all_answers.extend(answers)
            pbar.set_postfix({"current_batch": f"{i//batch_size + 1}/{total_batches}"})
            pbar.update(len(minibatch))
    t1 = time.time()
    timespan = format_timespan(t1 - t0)
    logger.info(f"Generating Answer dataset took {timespan}")

    return all_answers