# neuron-explainer/neuron_explainer/explanations/prompt_builder.py
from __future__ import annotations

from enum import Enum
from typing import TypedDict, Union

import tiktoken


# A single chat message in the role/content format accepted by the
# /chat/completions endpoint.
HarmonyMessage = TypedDict(
    "HarmonyMessage",
    {
        "role": str,
        "content": str,
    },
)
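
# At runtime a HarmonyMessage is just a dict, e.g. (illustrative):
#     HarmonyMessage(role="user", content="Hello")  # -> {"role": "user", "content": "Hello"}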


class PromptFormat(str, Enum):
    """
    Different ways of formatting the components of a prompt into the format accepted by the
    relevant API server endpoint.
    """

    NONE = "none"
    """Suitable for use with models that don't use special tokens for instructions."""

    INSTRUCTION_FOLLOWING = "instruction_following"
    """Suitable for instruction-following models that use <|endofprompt|>."""

    HARMONY_V4 = "harmony_v4"
    """
    Suitable for Harmony models that use a structured turn-taking role+content format. Generates
    a list of HarmonyMessage dicts that can be sent to the /chat/completions endpoint.
    """

    @classmethod
    def from_string(cls, s: str) -> PromptFormat:
        for prompt_format in cls:
            if prompt_format.value == s:
                return prompt_format
        raise ValueError(f"{s} is not a valid PromptFormat")
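
# Illustrative usage (not part of the original module):
#     PromptFormat.from_string("harmony_v4")  # -> PromptFormat.HARMONY_V4
#     PromptFormat.from_string("gpt-4")       # -> raises ValueError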


class Role(str, Enum):
    """See https://platform.openai.com/docs/guides/chat"""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class PromptBuilder:
    """Class for accumulating components of a prompt and then formatting them into an output."""

    def __init__(self) -> None:
        self._messages: list[HarmonyMessage] = []

    def add_message(self, role: Role, message: str) -> None:
        self._messages.append(HarmonyMessage(role=role, content=message))

    def prompt_length_in_tokens(self, prompt_format: PromptFormat) -> int:
        # TODO(sbills): Make the model/encoding configurable. This implementation assumes GPT-4.
        encoding = tiktoken.get_encoding("cl100k_base")
        if prompt_format == PromptFormat.HARMONY_V4:
            # Approximately-correct implementation adapted from this documentation:
            # https://platform.openai.com/docs/guides/chat/introduction
            num_tokens = 0
            for message in self._messages:
                num_tokens += (
                    4  # every message follows <|im_start|>{role/name}\n{content}<|im_end|>\n
                )
                num_tokens += len(encoding.encode(message["content"], allowed_special="all"))
            num_tokens += 2  # every reply is primed with <|im_start|>assistant
            return num_tokens
        else:
            prompt_str = self.build(prompt_format)
            assert isinstance(prompt_str, str)
            return len(encoding.encode(prompt_str, allowed_special="all"))
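
    # Worked example (illustrative): for HARMONY_V4 with two messages whose contents encode
    # to 10 and 5 tokens respectively, the estimate is (4 + 10) + (4 + 5) + 2 = 25 tokens.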

    def build(
        self, prompt_format: PromptFormat, *, allow_extra_system_messages: bool = False
    ) -> Union[str, list[HarmonyMessage]]:
        """
        Validates the messages added so far (reasonable alternation of assistant vs. user, etc.)
        and returns either a regular string (maybe with <|endofprompt|> tokens) or a list of
        HarmonyMessages suitable for use with the /chat/completions endpoint.

        The `allow_extra_system_messages` parameter allows the caller to specify that the prompt
        should be allowed to contain system messages after the very first one.
        """
        # Copy each message dict so this method can modify the copies and so the caller can't
        # modify the internal state of this object.
        messages = [message.copy() for message in self._messages]
        expected_next_role = Role.SYSTEM
        for message in messages:
            role = message["role"]
            assert role == expected_next_role or (
                allow_extra_system_messages and role == Role.SYSTEM
            ), f"Expected message from {expected_next_role} but got message from {role}"
            if role == Role.SYSTEM:
                expected_next_role = Role.USER
            elif role == Role.USER:
                expected_next_role = Role.ASSISTANT
            elif role == Role.ASSISTANT:
                expected_next_role = Role.USER

        if prompt_format == PromptFormat.INSTRUCTION_FOLLOWING:
            # Mark the end of the prompt by appending <|endofprompt|> to the last user message.
            last_user_message = None
            for message in messages:
                if message["role"] == Role.USER:
                    last_user_message = message
            assert last_user_message is not None
            last_user_message["content"] += "<|endofprompt|>"

        if prompt_format == PromptFormat.HARMONY_V4:
            return messages
        elif prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
            return "".join(message["content"] for message in messages)
        else:
            raise ValueError(f"Unknown prompt format: {prompt_format}")