def generate_questions()

in glan-instruct/glan.py [0:0]


def generate_questions(
        class_sessions, key_concepts, subject, level, subtopics, model_name="gpt-4o", 
        num_iterations=2, num_questions_per_iteration=5, max_tokens=2048, batch_size=4, language="Korean", **kwargs
    ):
    """
    Generate homework questions from class sessions and key concepts using a
    LangChain pipeline. Please refer to section 2.4 of the paper.

    Args:
        class_sessions: Pool of class sessions to sample prompts from.
        key_concepts: Pool of key concepts to sample prompts from.
        subject: Subject name, attached to every question as metadata.
        level: Level descriptor, attached to every question as metadata.
        subtopics: Subtopics, attached to every question as metadata.
        model_name: Azure OpenAI deployment name.
        num_iterations: How many times to re-sample (sessions, concepts) and
            generate a fresh batch of questions.
        num_questions_per_iteration: Questions requested per iteration.
        max_tokens: Completion token limit; also injected into the prompt as a
            soft length constraint on the generated question.
        batch_size: Mini-batch size and max concurrency for chain.batch().
        language: Output language for the generated questions.
        **kwargs: Extra keyword arguments forwarded to AzureChatOpenAI.

    Returns:
        list[dict]: One dict per question with keys "question", "subject",
        "level", and "subtopics". Batches that fail after retries are skipped.
    """
    # Local imports keep the optional langchain dependency out of module load.
    from langchain_openai import AzureChatOpenAI
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import BaseOutputParser

    llm = AzureChatOpenAI(
        max_tokens=max_tokens,
        openai_api_version="2024-09-01-preview",
        azure_deployment=model_name,
        **kwargs
    )

    class CustomOutputParser(BaseOutputParser):
        # Wrap the raw LLM text into the dict schema consumed downstream.
        def parse(self, text: str):
            return {"question": text.strip()}

    prompt = PromptTemplate.from_template(
    """Based on the class session(s) {selected_class_sessions} and key concepts {selected_key_concepts}, generate a homework question.
    A question must be less than {max_tokens} tokens.
    Write in {language}.
    """
    )
    chain = prompt | llm | CustomOutputParser()

    questions = []
    for idx in range(num_iterations):
        t0 = time.time()
        logger.info(f"\t\t===== Generating Questions: Iteration {idx}")
        selected_class_sessions, selected_key_concepts = sample_class_sessions_and_key_concepts(class_sessions, key_concepts, single_session=True)

        # The same sampled (sessions, concepts) pair is reused for every
        # question in this iteration; variety comes from LLM sampling.
        batch_inputs = [{
            "selected_class_sessions": selected_class_sessions,
            "selected_key_concepts": selected_key_concepts,
            "max_tokens": max_tokens,
            "language": language
        } for _ in range(num_questions_per_iteration)]

        metadata = {"subject": subject, "level": level, "subtopics": subtopics}

        with tqdm(total=len(batch_inputs), desc="\t\tProcessing Questions") as pbar:
            for i in range(0, len(batch_inputs), batch_size):
                minibatch = batch_inputs[i:i+batch_size]

                # Reset per mini-batch. Without this, a batch that fails on its
                # first attempt would either raise NameError (very first batch)
                # or re-append the PREVIOUS batch's results as duplicates.
                questions_ = []
                retries = 0
                while retries <= MAX_RETRIES:
                    try:
                        questions_ = chain.batch(minibatch, {"max_concurrency": batch_size})
                        break  # Exit the retry loop once successful
                    except RateLimitError as rate_limit_error:
                        # Linear backoff: wait longer after each rate-limit hit.
                        delay = (retries + 1) * DELAY_INCREMENT
                        logger.warning(f"{rate_limit_error}. Retrying in {delay} seconds...")
                        time.sleep(delay)
                        retries += 1

                        if retries > MAX_RETRIES:
                            logger.error("Max retries reached this batch. Skipping to next batch.")
                            break
                    except Exception as e:
                        # Non-rate-limit errors are not retried; skip this batch.
                        logger.error(f"Error in process_inputs: {e}")
                        break  

                for q in questions_:
                    q.update(metadata)
                questions.extend(questions_)
                pbar.set_postfix({"current_batch": f"{i//batch_size + 1}/{(len(batch_inputs) + (batch_size-1))//batch_size}"})

                pbar.update(len(minibatch))
        
        t1 = time.time()
        logger.info(f"\t\tIteration {idx} took {t1 - t0:.4f} seconds.")

    return questions