in evals/elsuite/multistep_web_tasks/solvers/strong_solver/strong_solver.py [0:0]
def _cut_messages_to_fit(self, messages: OpenAICreateChatPrompt) -> OpenAICreateChatPrompt:
    """Drop (or, as a last resort, truncate) messages until the prompt fits the context window.

    Strategy:
      1. If the prompt already fits, return it unchanged.
      2. Otherwise drop middle messages one at a time, oldest first, always
         keeping the system message (index 0) and the final message, until
         the remainder fits.
      3. If even "system + last message" is too long, truncate the last
         message's content to the remaining token budget.

    The caller's list and message dicts are never mutated; a new list (and,
    in the truncation case, deep-copied messages) is returned instead.

    Args:
        messages: chat messages; index 0 is assumed to be the system message.

    Returns:
        A prompt whose token count fits within ``self.context_length`` minus
        the space reserved for the response and the safety buffer.

    Raises:
        ValueError: if fewer than two messages are supplied.
    """
    # Token budget for the prompt: reserve room for the model's response
    # plus a fixed safety buffer.
    target_n_tokens = self.context_length - self.max_response_tokens - TOKEN_BUFFER
    logger.debug(f"{target_n_tokens = }")
    messages_tokens = [self.encoding.encode(msg["content"]) for msg in messages]
    # Each message costs its content tokens plus a fixed per-message overhead.
    messages_n_tokens = [len(tokens) + TOKENS_PER_MESSAGE for tokens in messages_tokens]
    total_n_tokens = sum(messages_n_tokens)
    logger.debug(f"{total_n_tokens = }")
    if total_n_tokens < target_n_tokens:
        logger.debug("initial prompt is short enough, returning!")
        return messages
    if len(messages) < 2:
        raise ValueError("Not enough messages (only 1, which is system)")
    # Drop middle messages (indices 1..len-2), oldest first, until we fit.
    # The fit check runs AFTER subtracting message i's tokens so the returned
    # slice matches the computed total exactly.  (The original checked before
    # subtracting, which dropped one message more than necessary and never
    # tested the removal of the last middle message.)
    for i in range(1, len(messages) - 1):
        logger.debug(f"truncating messages, {i = }, {total_n_tokens = }")
        total_n_tokens -= messages_n_tokens[i]
        if total_n_tokens < target_n_tokens:
            logger.debug(f"{len(messages) = }, keeping [:1] and [{i} + 1:]")
            return messages[:1] + messages[i + 1 :]
    # Dropping every middle message wasn't enough: only the system message and
    # the last message remain, so shorten the last message's content instead.
    remaining_messages = messages[:1] + messages[-1:]
    assert len(remaining_messages) == 2, "At this point, should only be two messages left"
    # Deep-copy before editing so we never mutate the caller's message dicts.
    # (The original wrote the truncated text into the caller's dict and then
    # returned an *untruncated* deepcopy, so the oversized prompt came back
    # unchanged.)
    result = copy.deepcopy(remaining_messages)
    # Content budget for the observation: total target minus the full cost of
    # the system message and the per-message overhead of the observation
    # itself.  Clamp at 0 — a negative slice bound would keep almost all of
    # the content instead of none of it.
    # NOTE(review): the "OBSERVATION: " prefix adds a few tokens that are
    # assumed to be absorbed by TOKEN_BUFFER — confirm.
    token_budget_for_obs = max(target_n_tokens - messages_n_tokens[0] - TOKENS_PER_MESSAGE, 0)
    truncated_content_tokens = messages_tokens[-1][:token_budget_for_obs]
    truncated_content_text = self.encoding.decode(truncated_content_tokens)
    logger.debug(f"{len(messages[-1]['content']) = }")
    logger.debug(f"{len(truncated_content_text) = }")
    logger.debug(f"{len(truncated_content_tokens) = }")
    logger.debug(
        f"final total length = {len(truncated_content_tokens) + messages_n_tokens[0] = }"
    )
    result[1]["content"] = f"OBSERVATION: {truncated_content_text}"
    return result