1_synthetic-qa-generation/reasoningplaning/evolve.py (697 lines of code) (raw):

import heapq import math import random import os import re from concurrent.futures import ThreadPoolExecutor, as_completed import itertools import logging # logger = None def setup_logging(seed_data: str): log_filename = f"{seed_data.replace('.jsonl', '').replace('.json', '')}.log" if not os.path.exists(log_filename): os.mknod(log_filename) logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.FileHandler(log_filename), logging.StreamHandler()], ) return logging.getLogger(__name__) import sys import ast import json import uuid from typing import List, Tuple, Dict, Any, Optional import pandas as pd import numpy as np from enum import Enum import time import torch from datasets import Dataset, DatasetDict from transformers import pipeline from transformers.pipelines.pt_utils import KeyDataset from tqdm.auto import tqdm import markdown from bs4 import BeautifulSoup from datasets import load_dataset import os, openai from dotenv import load_dotenv from openai import AzureOpenAI, RateLimitError load_dotenv() # take environment variables from .env. NUM_TRY_BEFORE_SEARCH = 16 def separateSteps(steps: List[str], mode: str = 'join') -> Any: delimiter = "\n\n" if mode == 'join': if not isinstance(steps, list): raise TypeError("For 'join' mode, 'steps' must be a list of strings.") return delimiter.join(steps) elif mode == 'split': if not isinstance(steps, str): raise TypeError("For 'split' mode, 'steps' must be a string.") return steps.split(delimiter) else: raise ValueError("Mode should be either 'join' or 'split'.") # Helper function to check correctness of a generated response def checkCorrectness(generated_response: str, expected_answer: str) -> bool: sentences = re.split( r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', generated_response.strip() ) last_sentence = sentences[-1] if sentences else '' return expected_answer.strip() in last_sentence.strip() class BranchMaker: def __init__(self, azureAiService): self.azureAiService = azureAiService self.default_prompt = ( "Please complete the answer for the question based on the given steps without generating existing steps again, " "and separate your following steps using \\n\\n.\n\n" ) # self.azureAiService.startService() def generateBranch(self, node_prefix: str, num_copies) -> List[str]: """ Combine the default prompt with the node prefix and generate a response. Parameters: - node_prefix (str): The current solution prefix. Returns: - str: Generated response from LLM. """ prompt = self.default_prompt + node_prefix batch_response = self.azureAiService.generateResponse(prompt, num_copies) return batch_response # Assuming the response format has ['role'] entries and 'assistant' response def evaluateCorrectness(self, response: str, expected_answer: str) -> bool: return checkCorrectness(response, expected_answer) # Define the node class class Nde: def __init__(self, solution_prefix: str, parent: Optional['Nde'] = None): self.solution_prefix = solution_prefix # Solution prefix as a single string self.parent = parent # Reference to the parent node self.numVisited = 0 # Visit count (number of times selected) self.numOfTtlBranches = 0 # Total number of branches generated from this node self.trueBranches = 0 # Number of correct branches self.reward: Optional[float] = None # Monte Carlo estimation (c/k) self.branchesQ: Dict[str, float] = {} # Q(s, r): estimated value for each branch self.branches: List[str] = [] # Set of all branches from this node self.falseBranches: List[str] = [] # List of incorrect branches self.children: List['Nde'] = [] # List of child states def addBranch(self, branch: str): self.branches.append(branch) def addFalseBranch(self, branch: str): if branch not in self.falseBranches: self.falseBranches.append(branch) def get_full_solution(self) -> str: # Return the complete solution from the root to this node if self.parent: return self.parent.get_full_solution() + '\n\n' + self.solution_prefix else: return self.solution_prefix def getNewTxt(self) -> str: """ Return the new text added at this node compared to the parent. """ if self.parent: parent_text = self.parent.solution_prefix new_text = self.solution_prefix[len(parent_text):].strip() return new_text else: # Root node (the question) return self.solution_prefix.strip() def getTxtWithLabels(self) -> Dict[str, Any]: """ Return a nested dictionary where each node contains: - 'text': The new text at this node. - 'mc_value': The reward value at this node. - 'children': A list of child nodes with the same structure. """ data = { 'text': self.getNewTxt(), 'mc_value': self.reward, 'children': [child.getTxtWithLabels() for child in self.children] } return data # Define the Search Tree class class SearchTree: def __init__(self): self.root: Optional[Nde] = None self.nodes: List[Nde] = [] # List of all states def addNode(self, node: Nde): self.nodes.append(node) # Define the Candidate Pool as a priority queue with update capability class CandidatePool: def __init__(self): self.heap: List[Tuple[float, int]] = [] # Heap of (-priority, unique_id) self.entry_finder: Dict[int, Tuple[float, int]] = {} # Maps unique_id to (-priority, unique_id) self.counter = itertools.count() # Unique sequence count self.id_to_branch: Dict[int, Tuple[Nde, str]] = {} # Maps unique_id to (node, branch) self.latest_id_per_branch: Dict[Tuple[int, str], int] = {} # Maps (node_id, branch) to unique_id def addOrUpdate(self, node: Nde, branch: str, priority: float): """ Add a new branch or update the priority of an existing branch. Parameters: - node (Nde): The node associated with the branch. - branch (str): The branch string. - priority (float): The new priority score. """ node_id = id(node) # Unique identifier for the node object branch_key = (node_id, branch) # Check if the branch already exists in the pool if branch_key in self.latest_id_per_branch: # Previous unique_id exists; it is now outdated old_unique_id = self.latest_id_per_branch[branch_key] # Mark the old entry as invalid by removing it from entry_finder if old_unique_id in self.entry_finder: del self.entry_finder[old_unique_id] del self.id_to_branch[old_unique_id] # Assign a new unique_id for the updated branch unique_id = next(self.counter) self.latest_id_per_branch[branch_key] = unique_id # Add the new entry to the heap and mappings heapq.heappush(self.heap, (-priority, unique_id)) # Max-heap using negative priority self.entry_finder[unique_id] = (-priority, unique_id) self.id_to_branch[unique_id] = (node, branch) def pop(self) -> Tuple[Optional[Nde], Optional[str]]: """ Pop the branch with the highest priority. Returns: - Tuple[Optional[Nde], Optional[str]]: The node and branch string, or (None, None) if empty. """ while self.heap: neg_priority, unique_id = heapq.heappop(self.heap) # Check if this unique_id is still valid if unique_id in self.entry_finder: # Valid entry node, branch = self.id_to_branch.pop(unique_id) del self.entry_finder[unique_id] # Remove from latest_id_per_branch node_id = id(node) branch_key = (node_id, branch) if self.latest_id_per_branch.get(branch_key) == unique_id: del self.latest_id_per_branch[branch_key] return node, branch # Else, outdated entry; skip return None, None def is_empty(self) -> bool: return not self.entry_finder MAX_ITERATIONS = 1 MAX_RETRIES = 2 GRND_TRUTH_COL = "final_answer" def mdToText(md, do_md_to_text=True): if not do_md_to_text: return md assert md is not None, "Markdown is None" html = markdown.markdown(md) soup = BeautifulSoup(html, features='html.parser') return soup.get_text() class Mutation(Enum): FRESH_START = 0 ADD_CONSTRAINTS = 1 DEEPEN = 2 CONCRETIZE = 3 INCREASE_REASONING = 4 COMPLICATE = 5 SWITCH_TOPIC = 6 # Retrieved from https://github.com/nlpxucan/WizardLM/tree/main base_depth_instruction = "I want you act as a Prompt Rewriter.\r\n \ Your objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., chatgpt and GPT4) a bit harder to handle.\r\n \ But the rewritten prompt must be reasonable and must be understood and responded by humans.\r\n \ Your rewriting cannot omit the non-text parts such as the table and code in #The Given Prompt#:. Also, please do not omit the input in #The Given Prompt#. \r\n \ You SHOULD complicate the given prompt using the following method: \r\n\ {} \r\n\ You should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #The Given Prompt#. \r\n\ '#The Given Prompt#', '#Rewritten Prompt#', 'given prompt' and 'rewritten prompt' are not allowed to appear in #Rewritten Prompt#\r\n" base_breadth_instruction = "I want you act as a Prompt Creator.\r\n\ Your goal is to draw inspiration from the #Given Prompt# to create a brand new prompt.\r\n\ This new prompt should belong to the same domain as the #Given Prompt# but be even more rare.\r\n\ The LENGTH and complexity of the #Created Prompt# should be similar to that of the #Given Prompt#.\r\n\ The #Created Prompt# must be reasonable and must be understood and responded by humans.\r\n\ '#Given Prompt#', '#Created Prompt#', 'given prompt' and 'created prompt' are not allowed to appear in #Created Prompt#\r\n" complicate_prompt = base_depth_instruction.format("#Given Prompt# to make it slightly more complicated.'") constraints_prompt = base_depth_instruction.format("Please add one more constraints/requirements into #The Given Prompt#'") deepen_prompt = base_depth_instruction.format("If #The Given Prompt# contains inquiries about certain issues, the depth and breadth of the inquiry can be increased.") concretizing_prompt = base_depth_instruction.format("Please replace general concepts with more specific concepts.") reasoning_prompt = base_depth_instruction.format("If #The Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning.") class WizardLM: def __init__( self, llm_pipeline: pipeline = None, seed_data: List[str] = None, column_names: List[str] = ["instruction"], num_rows: int = 10, min_len_chars: int = 512, max_len_chars: int = 1024, verbose: bool = False, language: str = "Chinese", expCnst=0.125, alpha=0.5, beta=0.9, maxSolutionLen=512, numBranches=4, maxSrch=4, maxBranches=40, saveAsTree=True, ): """ Open-Source Implementation of https://arxiv.org/abs/2304.12244 :param llm_pipeline: Pipeline that takes a HF dataset containing one string column and returns a list of strings :param seed_data: Optional data to create Q:A pairs from, list of strings containing prompts :param num_rows: Number of desired Q:A pairs :param min_len_bytes: Lower limit for prompt length in bytes :param max_len_bytes: Upper limit for prompt length in bytes :param verbose: Whether to enable verbose printing. """ self.branchMaker = BranchMaker(llm_pipeline) self.expected_answer = None self.expCnst = expCnst self.alpha = alpha self.beta = beta self.maxSolutionLen = maxSolutionLen self.numBranches = numBranches self.maxSrch = maxSrch self.maxBranches = maxBranches self.saveAsTree = saveAsTree self.mct = SearchTree() self.candidatePool = CandidatePool() self.numSrched = 0 self.numOfTtlBranches = 0 self.maxIdx = -100 self.llm_pipeline = llm_pipeline self.column_names = column_names self.num_rows = num_rows self.verbose = verbose self.seed_text_dict = dict() self.seed_data = seed_data self.prompts = dict() self.final_prompts = dict() self.final_answers = [] self.min_len_bytes = min_len_chars self.max_len_bytes = max_len_chars self.prompt_templates = dict() self.prompt_templates['base'] = "" seed = None np.random.seed(seed) self.language = language self.prompt_templates[Mutation.FRESH_START] = \ self.prompt_templates['base'] + \ f"""Write one question or request containing one or more of the following words. Write in {self.language}.: <PROMPT>""" self.prompt_templates[Mutation.COMPLICATE] = \ self.prompt_templates['base'] + \ f"""{complicate_prompt}\nWrite in {self.language}. #Given Prompt#: <PROMPT> """ self.prompt_templates[Mutation.ADD_CONSTRAINTS] = \ self.prompt_templates['base'] + \ f"""{constraints_prompt}\nWrite in {self.language}. #The Given Prompt#: <PROMPT> """ self.prompt_templates[Mutation.DEEPEN] = \ self.prompt_templates['base'] + \ f"""{deepen_prompt}\nWrite in {self.language}. #The Given Prompt#: <PROMPT> """ self.prompt_templates[Mutation.CONCRETIZE] = \ self.prompt_templates['base'] + \ f"""{concretizing_prompt}\nWrite in {self.language}. #The Given Prompt#: <PROMPT> """ self.prompt_templates[Mutation.INCREASE_REASONING] = \ self.prompt_templates['base'] + \ f"""{reasoning_prompt}\nWrite in {self.language}. #The Given Prompt#: <PROMPT> """ self.prompt_templates[Mutation.SWITCH_TOPIC] = \ self.prompt_templates['base'] + \ f"""{base_breadth_instruction}\nWrite in {self.language}. #Given Prompt#: <PROMPT> """ def run(self): self.createSeedPrompts() self.createPrompts() self.createAnswers() list_qa = [] for i in range(len(self.final_prompts)): if len(self.final_answers[i]) > 10: list_qa.append( { 'input': self.final_prompts[i], 'output': self.final_answers[i], } ) with open(f"{self.seed_data.replace('.jsonl', '').replace('json', '')}.%s.json" % str(uuid.uuid4())[:4], "wt") as f: f.write(json.dumps(list_qa, indent=2, ensure_ascii=False)) def monteCarloEstimation(self, node: Nde): numTrueBranches = 0 # Correct branches count falseBranches = [] trueBranches = [] genBranches = self.branchMaker.generateBranch(node.solution_prefix, self.numBranches) node.numVisited += 1 for i, branch in enumerate(genBranches): if branch is None or not branch: continue self.numOfTtlBranches += 1 # Generate branch r_i node.addBranch(branch) # Evaluate correctness of final answer in branch fullSolution = (node.solution_prefix + '\n\n' + branch).strip() if node.solution_prefix else branch isCorrect = self.branchMaker.evaluateCorrectness(fullSolution, self.expected_answer) if isCorrect: numTrueBranches += 1 trueBranches.append(branch) else: falseBranches.append(branch) node.addFalseBranch(branch) # Track incorrect branches # Update total branches and correct branches node.numOfTtlBranches += self.numBranches node.trueBranches += numTrueBranches node.reward = node.trueBranches / node.numOfTtlBranches if node.numOfTtlBranches > 0 else 0 # logger.info(f"Monte Carlo Estimation for Nde ID {self.mct.nodes.index(node)}: reward = {node.reward:.2f}, Total Rollouts = {node.numOfTtlBranches}, Correct Rollouts = {node.trueBranches}\n") if node.reward == 1.0: # Add all correct branches to the tree as new states for branch in trueBranches: self.addTrueBranch2Tree(node, branch) elif node.reward == 0.0: # Nde is incorrect; no further action return else: # 0 < reward(s) < 1.0 # Add correct branches to the tree for branch in trueBranches: self.addTrueBranch2Tree(node, branch) # Add incorrect branches to candidate pool with updated priorities for branch in falseBranches: priority = self.computeSelectionScore(node, branch) self.candidatePool.addOrUpdate(node, branch, priority) def computeHrdScor(self, node: Nde, branch: str) -> float: # Count words in the branch word_count = len(branch.split()) length_penalty = word_count / self.maxSolutionLen hardness = (self.alpha ** (1 - node.reward)) * (self.beta ** length_penalty) return hardness def computeVistScor(self, node: Nde) -> float: N_total = sum(s.numVisited for s in self.mct.nodes) if N_total == 0: N_total = 1 # Prevent division by zero seldomvisit = self.expCnst * (math.sqrt(N_total)) / (1 + node.numVisited) return seldomvisit def computeSelectionScore(self, node: Nde, branch: str) -> float: hardness = self.computeHrdScor(node, branch) seldomvisit = self.computeVistScor(node) score = hardness + seldomvisit return score def selectionPhase(self) -> Tuple[Optional[Nde], Optional[str]]: selected_node, selected_branch = self.candidatePool.pop() return selected_node, selected_branch def addTrueBranch2Tree(self, parent_node: Nde, branch: str): new_solution_prefix = (parent_node.solution_prefix + '\n\n' + branch).strip() if parent_node.solution_prefix else branch new_node = Nde(solution_prefix=new_solution_prefix, parent=parent_node) new_node.reward = 1.0 # Since the branch is correct new_node.numOfTtlBranches = 0 new_node.trueBranches = 0 self.mct.addNode(new_node) parent_node.children.append(new_node) # Add to parent's children def expansionPhaseBinSrch(self, parent_node: Nde, branch: str): """ Parameters: - parent_node (Nde): The node from which the branch was selected. - branch (str): The branch string that was selected and is incorrect. """ # Separate the branch into individual steps steps = separateSteps(branch, mode='split') # Perform binary search to find incorrect steps self.binSrchIncorrectStep(parent_node, steps, 0, len(steps) - 1) def binSrchIncorrectStep(self, s_ast: Nde, steps: List[str], left: int, right: int): """ Recursively call bin search Parameters: - s_ast (Nde): The selected parent node. - steps (List[str]): The branch steps as a list. - left (int): Left index of the current search interval. - right (int): Right index of the current search interval. """ if left > right: return mid = (left + right) // 2 new_steps = steps[left:mid + 1] if new_steps: prefix_solution = s_ast.solution_prefix + '\n\n' + separateSteps(new_steps, mode='join') else: assert False prefix_solution = s_ast.solution_prefix # Create new node s_new s_new = Nde(solution_prefix=prefix_solution.strip(), parent=s_ast) self.mct.addNode(s_new) s_ast.children.append(s_new) # Perform Monte Carlo estimate self.monteCarloEstimation(s_new) if s_new.reward == 0: # Found incorrect step; continue searching in the left half to find earlier incorrect steps self.binSrchIncorrectStep(s_ast, steps, left, mid - 1) else: self.binSrchIncorrectStep(s_new, steps, mid + 1, right) def maintenancePhase(self, node: Nde): for branch in node.falseBranches: # Since we've already determined these branches are incorrect, no need to re-evaluate correctness priority = self.computeSelectionScore(node, branch) self.candidatePool.addOrUpdate(node, branch, priority) # logger.info(f"Updated Incorrect Rollout: '{branch}' with new priority: {priority:.4f}") def collectSolutionPrefixes(self) -> List[Dict[str, Any]]: collected_data = [] for node in self.mct.nodes: solution_prefix = node.solution_prefix mc_value = node.reward collected_data.append({ "solution_prefix": solution_prefix, "mc_value": mc_value }) return collected_data def collectTreeStructure(self) -> Dict[str, Any]: if self.mct.root: tree_data = self.mct.root.getTxtWithLabels() return tree_data return {} def resetPrmState(self): self.expected_answer = None self.mct = SearchTree() # Reset search tree self.candidatePool = CandidatePool() # Reset candidate pool self.numSrched = 0 self.numOfTtlBranches = 0 self.collected_data = [] # Clear collected data def genPrm(self, question: str, answer: str) -> List: self.resetPrmState() logger.info(f"Running genPrm for question: '{question}'\n") # Initialization initial_node = Nde(solution_prefix=question, parent=None) self.expected_answer = answer self.mct.root = initial_node self.mct.addNode(initial_node) self.numSrched = 0 # Monte Carlo Estimation for initial_node self.monteCarloEstimation(initial_node) # Main loop while self.numSrched < self.maxSrch and self.numOfTtlBranches < self.maxBranches and not self.candidatePool.is_empty(): # Selection Phase selected_node, selected_branch = self.selectionPhase() if selected_node is None or selected_branch is None: # logger.info("No more candidates to explore. Terminating search.\n") break self.expansionPhaseBinSrch(selected_node, selected_branch) # Maintenance Phase self.maintenancePhase(selected_node) # Increment search count self.numSrched += 1 if self.saveAsTree: data = self.collectTreeStructure() else: data = self.collectSolutionPrefixes() return data def shouldProcessQuestion(self, question: Dict[str, str]) -> bool: prompt = question[self.column_names] correct_answer = question[GRND_TRUTH_COL] has_correct = False has_incorrect = False initial_batch_answers = self.branchMaker.generateBranch(prompt, NUM_TRY_BEFORE_SEARCH) for answer in initial_batch_answers: if answer and self.branchMaker.evaluateCorrectness(answer, correct_answer): has_correct = True else: has_incorrect = True if has_correct and has_incorrect: logger.info(f"Question passed filter: {question['problem']}") return True return False def processQuestion(self, question: Dict[str, str]): # logger.info(f"Processing question with genPrm: {question[self.column_names]}") reasoning_steps = self.genPrm(question[self.column_names], question[GRND_TRUTH_COL]) collected_data = { "question": question[self.column_names], GRND_TRUTH_COL: question[GRND_TRUTH_COL], "reasoning_steps": reasoning_steps, } return collected_data def saveQuestionData(self, collected_data: Dict, index: int, output_path: str): collected_data["question_id"] = index with open(output_path, "a") as fd: line = json.dumps(collected_data) #json.dumps(list_q, indent=2, ensure_ascii=False) fd.write(f"{line}\n") logger.debug(f"Question {index} is saved to {output_path}") def runQuestionOnly(self): self.createSeedPrompts() self.createPrompts() list_q = [] for k in self.final_prompts: list_q.append( { "idx": self.final_prompts[k]["idx"], "preidx": self.final_prompts[k]["preidx"], self.column_names: k, GRND_TRUTH_COL: self.final_prompts[k][GRND_TRUTH_COL] # 'input': self.final_prompts[k], } ) del self.final_prompts output_file = f"{self.seed_data.replace('.jsonl', '').replace('.json', '')}.%s.json" % str(uuid.uuid4())[:4] processed_count = 0 for question in list_q: if self.shouldProcessQuestion(question): collected_data = self.processQuestion(question) self.saveQuestionData(collected_data, question['idx'], output_file) processed_count += 1 else: logger.info(f"Skipping question: {question[self.column_names]}") # Log summary logger.info( f"Total questions processed by genPrm: {processed_count}/{len(list_q)} inf file>> {output_file}" ) def createSeedPrompts(self): """ Turn self.seed_data into a list of strings of text self.source_text_list Each text string can represent as little as a word, or as much as document. Just has to be representative of some concept or body of text. :return: None """ if isinstance(self.seed_data, str) and os.path.exists(self.seed_data): data = load_dataset("json", data_files=self.seed_data) self.seed_text_dict = dict() for d in data['train']: s = "" if isinstance(self.column_names, str): s = d[self.column_names] else: assert False, "column_names must be a str" for col in self.column_names: s += d[col] + "\n" # self.seed_text_dict.append(s.strip()) self.seed_text_dict[s.strip()] = { "idx": d["idx"], GRND_TRUTH_COL: d[GRND_TRUTH_COL] } if int(d["idx"]) > self.maxIdx: self.maxIdx = int(d["idx"]) assert self.seed_text_dict, "data import failed, got empty list" self.maxIdx = self.maxIdx + 10 def createPrompts(self): logger.info("Creating %d prompts." % self.num_rows) assert self.seed_text_dict, "must have seed text list" t0 = time.time() self.prompts.clear() # for i in range(self.num_rows): # new_prompt = np.random.choice(self.seed_text_dict) # self.prompts.append(new_prompt) #@#FORTST comment above and replaced by below self.num_rows = len(self.seed_text_dict) for new_prompt in self.seed_text_dict: # self.prompts.append(new_prompt) self.prompts[new_prompt] = { "idx": self.seed_text_dict[new_prompt]["idx"], GRND_TRUTH_COL: self.seed_text_dict[new_prompt][GRND_TRUTH_COL] } i = 0 logger.info(f"length of self prompts={len(self.prompts)}") while self.mutate(i): logger.info("Iteration: %d" % i) i += 1 if i >= MAX_ITERATIONS: logger.info("Reached maximum number of iterations.") break t1 = time.time() logger.info("Done creating %d prompts in %.4f seconds." % (len(self.final_prompts), t1 - t0)) #@# include the original prompt into final_prompts for k in self.seed_text_dict: self.final_prompts[k] = { "idx": self.seed_text_dict[k]["idx"], "preidx": int(-100), "preproblem": "", GRND_TRUTH_COL: self.seed_text_dict[k][GRND_TRUTH_COL] } def createAnswers(self): logger.info("Creating answers for %d prompts." % len(self.final_prompts)) t0 = time.time() ds = self.convertListToDataset(self.final_prompts) self.final_answers = self.llm_pipeline(ds['train']) t1 = time.time() logger.info("Done creating answers for %d prompts in %.4f seconds." % (ds['train'].num_rows, t1 - t0)) def convertListToDataset(self, text_list): df = pd.DataFrame({'text': text_list}) ds = DatasetDict() ds['train'] = Dataset.from_pandas(df) return ds def mutate(self, iteration): assert len(self.prompts) == self.num_rows or len(self.prompts) == len(self.seed_text_dict) list_prompts = [] mutations = [] original_prompts = [] # for i in range(self.num_rows): for k in self.prompts: mutation = np.random.choice(Mutation) mutations.append(mutation) # if mutation == Mutation.FRESH_START: # mutation = Mutation.COMPLICATE before = k #self.prompts[i] prompt = self.prompt_templates[mutation].replace("<PROMPT>", before) if mutation == Mutation.SWITCH_TOPIC: prompt += "#Created Prompt#:\r\n" else: prompt += "#Rewritten Prompt:\r\n" logger.info(f"Full prompt={prompt}") list_prompts.append(prompt) original_prompts.append(k) ds = self.convertListToDataset(list_prompts) assert ds['train'].num_rows == len(list_prompts) == self.num_rows == len(self.prompts) # Processing transformed prompts using the LLM pipeline t0 = time.time() after = self.llm_pipeline(ds['train']) assert len(after) == self.num_rows t1 = time.time() llm_pipeline_name = self.llm_pipeline.__class__.__name__ logger.info(f"{llm_pipeline_name} took {t1 - t0:.4f} seconds") for i in range(len(after)): after[i] = after[i].split("Prompt#:")[-1].strip() for pp in ['New Prompt:\n', 'New Prompt: ']: if after[i][:len(pp)] == pp: after[i] = after[i][len(pp):] after[i] = after[i].strip() #use_new_prompt, why = self.changeApproved(self.prompts[i], after[i]) use_new_prompt = True original_p = original_prompts[i] if self.verbose: logger.info("===========================") logger.info("Old Prompt: %s" % original_p) logger.info("Mutation: %s" % mutations[i].name) logger.info("New Prompt: %s" % after[i]) logger.info("===========================") if use_new_prompt: original_itm = self.prompts[original_p] self.maxIdx = self.maxIdx + 1 self.final_prompts[after[i]] = { "idx": self.maxIdx, "preidx": original_itm["idx"], "preproblem": original_p, GRND_TRUTH_COL: original_itm[GRND_TRUTH_COL] } del self.prompts[original_p] chosen_prmp = np.random.choice(list(self.seed_text_dict.keys())) self.prompts[chosen_prmp] = { "idx": self.seed_text_dict[chosen_prmp]["idx"], GRND_TRUTH_COL: self.seed_text_dict[chosen_prmp][GRND_TRUTH_COL] } # if self.max_len_bytes >= len(after[i]) >= self.min_len_bytes: # self.final_prompts.append(after[i]) # logger.info("Prompt was accepted, now have %d good prompts." % len(self.final_prompts)) # self.prompts[i] = np.random.choice(self.seed_text_dict) # logger.info("Creating new prompt.") # else: # self.prompts[i] = after[i] # logger.info("Prompt was successfully modified.") else: logger.info("Mutation rejected, will try again. Reason: %s" % why) # logger.info("", flush=True) logger.info("final_prompt=") logger.info(self.final_prompts) return len(self.final_prompts) <= self.num_rows def changeApproved(self, before, after): if before == after: return False, "same" if after.count('\n') > after.count(" ") * 2: return False, "too many lines" if after.count('\n') == after.count("- ") > 10: return False, "too many items" if self.prompt_templates['base'] and self.prompt_templates['base'] in after: return False, "prompt leaked 1" if "#New Prompt#" in after: return False, "prompt leaked 2" if "new prompt" in after.lower(): return False, "prompt leaked 3" if "how can i assist" in after.lower(): return False, "AI" if "as an ai" in after.lower(): return False, "AI" if "gpt" in after.lower() and "gpt" not in before.lower(): return False, "AI" if "ai assistant" in after.lower(): return False, "AI" if "i'm sorry" in after.lower() and "sorry" not in before.lower() and len(after) < 400: return False, "sorry" if False: # too slow in general, not needed prompt = """Are the two following prompts equal to each other? To be equal, they must meet two requirements: 1. Both prompts have the same constraints and requirements. 2. Both prompts have the same depth and breath of the inquiry. First prompt: %s Second prompt: %s Answer with 'Equal' or 'Not Equal'. No need to explain the reason.""" % (before, after) answer = self.llm_pipeline(prompt) if 'not equal' not in answer.lower(): return False, "equal" return True, "ok" class AzureGPTPipeline: def __init__(self, model_name, **kwargs): self.model_name = model_name self.model_type = "aoai" self.kwargs = kwargs self.client = AzureOpenAI( azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT"), api_key = os.getenv("AZURE_OPENAI_API_KEY"), api_version = os.getenv("AZURE_OPENAI_API_VERSION") ) def __call__(self, dataset, **kwargs): ret = [] gen_count = 0 for d in dataset: logger.info(f"Generating {gen_count+1} of {len(dataset)}") response = None retries = 0 while not response and retries < MAX_RETRIES: try: response = self.client.chat.completions.create( model=self.model_name, messages=[{"role": "user", "content": d['text']}], **kwargs ) except RateLimitError as e: logger.info("Rate limit exceeded. Retrying in 10 seconds...") retries += 1 time.sleep(10) if response: ret.append(response.choices[0].message.content) else: ret.append("") gen_count += 1 if gen_count % 10 == 0: logger.info(gen_count) return ret def generateResponse(self, prompt: str, num_copies: int = 2) -> List[str]: if self.model_type == "aoai": return self.generateApi(prompt, num_copies) else: raise ValueError("Unsupported model_type.") def generateApi(self, prompt: str, num_rollouts) -> List[str]: def send_request(prompt): temperature = random.choice([0.7, 1.0])#(self.temperature_range) if self.model_type == "aoai": response = self.client.chat.completions.create( model=self.model_name, messages=[{"role": "user", "content": prompt}], max_tokens=self.kwargs["max_tokens"], temperature=self.kwargs["temperature"] # seed= ) output = response.choices[0].message.content else: assert False, "only support azure open ai" return output responses = [] with ThreadPoolExecutor(max_workers=num_rollouts) as executor: futures = [executor.submit(send_request, prompt) for _ in range(num_rollouts)] for future in tqdm(as_completed(futures), total=len(futures)): responses.append(future.result()) return responses import argparse if __name__ == "__main__": defaultseedfile = os.path.join(os.path.dirname(__file__),'samples/math_500_tst.json') parser = argparse.ArgumentParser(description='Options') parser.add_argument("--seed_file", type=str, default=defaultseedfile) parser.add_argument("--column_names", default="problem") #Instruction parser.add_argument("--temperature", type=int, default=0.7) parser.add_argument("--top_p", type=int, default=0.95) parser.add_argument("--model_name", type=str, default="gpt-4o") parser.add_argument("--num_branches", type=int, default=4) # how many branches we should explore when we complete rest part of a solution base on the existing part. parser.add_argument("--max_search", type=int, default=4) # the max limit of times we try to explore different solution of a problem using a MCTS like method parser.add_argument("--max_branches", type=int, default=40) # the max limit of the total number of branches we have explored when search different solutions for a problem args = parser.parse_args() # global logger logger = setup_logging(args.seed_file) llm_pipeline = AzureGPTPipeline( args.model_name, max_tokens=1024, temperature=args.temperature, top_p=args.top_p ) wizardlm = WizardLM( llm_pipeline=llm_pipeline, seed_data=args.seed_file, column_names=args.column_names, num_rows=2, # min_len_chars=args.min_len_chars, # max_len_chars=args.max_len_chars, language="English", verbose=True, numBranches = args.num_branches, maxSrch = args.max_search, maxBranches = args.max_branches, saveAsTree = True ) # if args.question_only: wizardlm.runQuestionOnly() # no need to gen answer, as there are ground truth in seed datas. # else: # wizardlm.run()