in 1_synthetic-qa-generation/reasoningplaning/evolve.py [0:0]
def monteCarloEstimation(self, node: "Nde"):
    """Run one round of Monte Carlo rollouts from ``node`` and update its reward.

    Asks ``self.branchMaker`` for ``self.numBranches`` completions of
    ``node.solution_prefix`` and scores every non-empty completion against
    ``self.expected_answer``. The resulting counts are accumulated on
    ``node``. The reward is the cumulative fraction of *evaluated* rollouts
    (over all visits) that were judged correct.

    Side effects:
      * ``node.numVisited`` is incremented once per call.
      * Every non-empty rollout is attached via ``node.addBranch``;
        incorrect ones are also recorded via ``node.addFalseBranch``.
      * ``self.numOfTtlBranches`` and ``node.numOfTtlBranches`` grow by the
        number of rollouts actually evaluated. ``node.trueBranches`` grows by
        the number judged correct.
      * If the reward is non-zero, this round's correct rollouts are added
        to the tree via ``self.addTrueBranch2Tree``.
      * If the reward is strictly between 0 and 1, this round's incorrect
        rollouts are pushed to ``self.candidatePool``. Each one gets the
        priority returned by ``self.computeSelectionScore``.

    Args:
        node: Tree node whose ``solution_prefix`` is expanded.
    """
    trueBranches = []
    falseBranches = []
    genBranches = self.branchMaker.generateBranch(node.solution_prefix, self.numBranches)
    node.numVisited += 1

    for branch in genBranches:
        # Skip failed/empty generations; `not branch` covers both None and "".
        if not branch:
            continue
        self.numOfTtlBranches += 1
        node.addBranch(branch)
        # Judge the final answer of the complete solution (prefix + rollout).
        if node.solution_prefix:
            fullSolution = (node.solution_prefix + '\n\n' + branch).strip()
        else:
            fullSolution = branch
        if self.branchMaker.evaluateCorrectness(fullSolution, self.expected_answer):
            trueBranches.append(branch)
        else:
            falseBranches.append(branch)
            node.addFalseBranch(branch)  # Track incorrect branches

    # Count only rollouts that were actually evaluated. Adding
    # self.numBranches here treated skipped empty generations as failures,
    # deflated the reward, and disagreed with self.numOfTtlBranches.
    node.numOfTtlBranches += len(trueBranches) + len(falseBranches)
    node.trueBranches += len(trueBranches)
    node.reward = node.trueBranches / node.numOfTtlBranches if node.numOfTtlBranches > 0 else 0

    if node.reward == 0.0:
        # No correct rollout has ever been observed from this node; nothing to expand.
        return

    # 0 < reward <= 1: grow the tree with this round's correct rollouts.
    for branch in trueBranches:
        self.addTrueBranch2Tree(node, branch)

    if node.reward < 1.0:
        # Mixed outcomes: keep the incorrect rollouts as candidates for further search.
        for branch in falseBranches:
            priority = self.computeSelectionScore(node, branch)
            self.candidatePool.addOrUpdate(node, branch, priority)