in evals/elsuite/function_deduction/eval.py
def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random):
    test_inputs = rng.sample(range(101), 3)
    values = sample.values
    expected = tuple(sample.values[test_input] for test_input in test_inputs)

    cs = CurrentState(self.n_rounds, self.mode, test_inputs)
    task_state = TaskState(
        prompts.task_description.format(inputs=test_inputs, n_rounds=self.n_rounds),
        current_state=cs,
    )
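
    # Run up to n_rounds of dialogue: each round the solver either asks for a new
    # input/output pair or guesses the outputs for the three test inputs.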
    for round_ix in range(self.n_rounds):
        raw_response = solver(task_state).output
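        # A response that fails to parse counts as an incorrectly formatted round.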
        try:
            ints = self._parse_raw_response(raw_response)
        except ValueError:
            cs.incorrect_format_rounds += 1
            answer = prompts.incorrect_format
        else:
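            # A single integer is a request for a new input/output pair; anything
            # else is treated as a guess at the three test outputs.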
            if len(ints) == 1:
                ask = ints[0]
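                # Asking for the value at one of the test inputs is not allowed.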
                result = values[ask] if ask not in test_inputs else None
                cs.ask_update(ask, result)
                if result is None:
                    answer = prompts.test_input_not_allowed.format(inputs=test_inputs)
                else:
                    answer = prompts.new_value.format(in_=ask, out=result)
            else:
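                # A correct guess ends the sample early; otherwise the solver gets
                # feedback on the bad guess.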
                cs.guess_update(ints, expected)
                if cs.success:
                    break
                else:
                    answer = self._bad_guess_answer(test_inputs, ints, expected)
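
        # Append this round's exchange so the solver sees the full conversation next round.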
        task_state.messages += [
            Message("assistant", raw_response),
            Message("system", answer),
        ]
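
    # Record per-sample metrics; num_rounds is reported only when the function
    # was deduced successfully.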
    evals.record.record_metrics(
        sample_ix=sample.sample_ix,
        success=cs.success,
        num_rounds=cs.round_ix if cs.success else None,
        ask_rounds=cs.ask_rounds,
        guess_rounds=cs.guess_rounds,
        incorrect_format_rounds=cs.incorrect_format_rounds,
        repeated_rounds=len(cs.parsed_responses) - len(set(cs.parsed_responses)),
        code="lambda x: " + sample.code,
        complexity=sample.complexity,
    )