app/components/chat.tsx (229 lines of code) (raw):
"use client";
import React, { useState, useEffect, useRef } from "react";
import styles from "./chat.module.css";
import { AssistantStream } from "openai/lib/AssistantStream";
import Markdown from "react-markdown";
// @ts-expect-error - no types for this yet
import { AssistantStreamEvent } from "openai/resources/beta/assistants/assistants";
import { RequiredActionFunctionToolCall } from "openai/resources/beta/threads/runs/runs";
type MessageProps = {
role: "user" | "assistant" | "code";
text: string;
};
const UserMessage = ({ text }: { text: string }) => {
return <div className={styles.userMessage}>{text}</div>;
};
const AssistantMessage = ({ text }: { text: string }) => {
return (
<div className={styles.assistantMessage}>
<Markdown>{text}</Markdown>
</div>
);
};
const CodeMessage = ({ text }: { text: string }) => {
return (
<div className={styles.codeMessage}>
{text.split("\n").map((line, index) => (
<div key={index}>
<span>{`${index + 1}. `}</span>
{line}
</div>
))}
</div>
);
};
const Message = ({ role, text }: MessageProps) => {
switch (role) {
case "user":
return <UserMessage text={text} />;
case "assistant":
return <AssistantMessage text={text} />;
case "code":
return <CodeMessage text={text} />;
default:
return null;
}
};
type ChatProps = {
functionCallHandler?: (
toolCall: RequiredActionFunctionToolCall
) => Promise<string>;
};
const Chat = ({
functionCallHandler = () => Promise.resolve(""), // default to return empty string
}: ChatProps) => {
const [userInput, setUserInput] = useState("");
const [messages, setMessages] = useState([]);
const [inputDisabled, setInputDisabled] = useState(false);
const [threadId, setThreadId] = useState("");
// automatically scroll to bottom of chat
const messagesEndRef = useRef<HTMLDivElement | null>(null);
const scrollToBottom = () => {
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
};
useEffect(() => {
scrollToBottom();
}, [messages]);
// create a new threadID when chat component created
useEffect(() => {
const createThread = async () => {
const res = await fetch(`/api/assistants/threads`, {
method: "POST",
});
const data = await res.json();
setThreadId(data.threadId);
};
createThread();
}, []);
const sendMessage = async (text) => {
const response = await fetch(
`/api/assistants/threads/${threadId}/messages`,
{
method: "POST",
body: JSON.stringify({
content: text,
}),
}
);
const stream = AssistantStream.fromReadableStream(response.body);
handleReadableStream(stream);
};
const submitActionResult = async (runId, toolCallOutputs) => {
const response = await fetch(
`/api/assistants/threads/${threadId}/actions`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
runId: runId,
toolCallOutputs: toolCallOutputs,
}),
}
);
const stream = AssistantStream.fromReadableStream(response.body);
handleReadableStream(stream);
};
const handleSubmit = (e) => {
e.preventDefault();
if (!userInput.trim()) return;
sendMessage(userInput);
setMessages((prevMessages) => [
...prevMessages,
{ role: "user", text: userInput },
]);
setUserInput("");
setInputDisabled(true);
scrollToBottom();
};
/* Stream Event Handlers */
// textCreated - create new assistant message
const handleTextCreated = () => {
appendMessage("assistant", "");
};
// textDelta - append text to last assistant message
const handleTextDelta = (delta) => {
if (delta.value != null) {
appendToLastMessage(delta.value);
};
if (delta.annotations != null) {
annotateLastMessage(delta.annotations);
}
};
// imageFileDone - show image in chat
const handleImageFileDone = (image) => {
appendToLastMessage(`\n\n`);
}
// toolCallCreated - log new tool call
const toolCallCreated = (toolCall) => {
if (toolCall.type != "code_interpreter") return;
appendMessage("code", "");
};
// toolCallDelta - log delta and snapshot for the tool call
const toolCallDelta = (delta, snapshot) => {
if (delta.type != "code_interpreter") return;
if (!delta.code_interpreter.input) return;
appendToLastMessage(delta.code_interpreter.input);
};
// handleRequiresAction - handle function call
const handleRequiresAction = async (
event: AssistantStreamEvent.ThreadRunRequiresAction
) => {
const runId = event.data.id;
const toolCalls = event.data.required_action.submit_tool_outputs.tool_calls;
// loop over tool calls and call function handler
const toolCallOutputs = await Promise.all(
toolCalls.map(async (toolCall) => {
const result = await functionCallHandler(toolCall);
return { output: result, tool_call_id: toolCall.id };
})
);
setInputDisabled(true);
submitActionResult(runId, toolCallOutputs);
};
// handleRunCompleted - re-enable the input form
const handleRunCompleted = () => {
setInputDisabled(false);
};
const handleReadableStream = (stream: AssistantStream) => {
// messages
stream.on("textCreated", handleTextCreated);
stream.on("textDelta", handleTextDelta);
// image
stream.on("imageFileDone", handleImageFileDone);
// code interpreter
stream.on("toolCallCreated", toolCallCreated);
stream.on("toolCallDelta", toolCallDelta);
// events without helpers yet (e.g. requires_action and run.done)
stream.on("event", (event) => {
if (event.event === "thread.run.requires_action")
handleRequiresAction(event);
if (event.event === "thread.run.completed") handleRunCompleted();
});
};
/*
=======================
=== Utility Helpers ===
=======================
*/
const appendToLastMessage = (text) => {
setMessages((prevMessages) => {
const lastMessage = prevMessages[prevMessages.length - 1];
const updatedLastMessage = {
...lastMessage,
text: lastMessage.text + text,
};
return [...prevMessages.slice(0, -1), updatedLastMessage];
});
};
const appendMessage = (role, text) => {
setMessages((prevMessages) => [...prevMessages, { role, text }]);
};
const annotateLastMessage = (annotations) => {
setMessages((prevMessages) => {
const lastMessage = prevMessages[prevMessages.length - 1];
const updatedLastMessage = {
...lastMessage,
};
annotations.forEach((annotation) => {
if (annotation.type === 'file_path') {
updatedLastMessage.text = updatedLastMessage.text.replaceAll(
annotation.text,
`/api/files/${annotation.file_path.file_id}`
);
}
})
return [...prevMessages.slice(0, -1), updatedLastMessage];
});
}
return (
<div className={styles.chatContainer}>
<div className={styles.messages}>
{messages.map((msg, index) => (
<Message key={index} role={msg.role} text={msg.text} />
))}
<div ref={messagesEndRef} />
</div>
<form
onSubmit={handleSubmit}
className={`${styles.inputForm} ${styles.clearfix}`}
>
<input
type="text"
className={styles.input}
value={userInput}
onChange={(e) => setUserInput(e.target.value)}
placeholder="Enter your question"
/>
<button
type="submit"
className={styles.button}
disabled={inputDisabled}
>
Send
</button>
</form>
</div>
);
};
export default Chat;