openai/chat.go
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package openai

import (
"encoding/json"
"io"
"log"
"net/http"
"reflect"
"strings"
"time"
"github.com/google-gemini/proxy-to-gemini/internal"
"github.com/google/generative-ai-go/genai"
)
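
// ChatCompletionsHandler implements the OpenAI /v1/chat/completions endpoint
// on top of the Gemini API: it translates the incoming request into a Gemini
// chat session and writes the result back in the OpenAI response format.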
func (h *handlers) ChatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
return
}
defer r.Body.Close()
var chatReq ChatCompletionRequest
if err := json.Unmarshal(body, &chatReq); err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to parse chat completions body: %v", err)
return
}
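	// Map the OpenAI sampling and length parameters onto the equivalent
	// Gemini generation settings.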
model := h.geminiClient.GenerativeModel(chatReq.Model)
model.GenerationConfig = genai.GenerationConfig{
CandidateCount: chatReq.N,
StopSequences: chatReq.Stop,
ResponseMIMEType: "text/plain",
MaxOutputTokens: chatReq.MaxTokens,
Temperature: chatReq.Temperature,
TopP: chatReq.TopP,
}
chat := model.StartChat()
var lastPart genai.Part
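	// Split the OpenAI message list: "system" messages become the system
	// instruction, the final message is sent as the new turn, and everything
	// in between is replayed as chat history.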
	for i, msg := range chatReq.Messages {
		if msg.Role == "system" {
			model.SystemInstruction = &genai.Content{
				Role:  msg.Role,
				Parts: []genai.Part{genai.Text(msg.Content)},
			}
			continue
		}
		if i == len(chatReq.Messages)-1 { // the last message
			// TODO(jbd): This hack strips away the role of the last message,
			// because the Gemini API Go SDK doesn't allow SendMessage to be
			// called with a list of contents.
			lastPart = genai.Text(msg.Content)
			break
		}
		// Gemini accepts only "user" and "model" as history roles, so map
		// OpenAI's "assistant" role to "model".
		role := msg.Role
		if role == "assistant" {
			role = "model"
		}
		chat.History = append(chat.History, &genai.Content{
			Role:  role,
			Parts: []genai.Part{genai.Text(msg.Content)},
		})
}
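	// Streaming requests are delegated to the streaming handler, which writes
	// chunks to the client as they arrive instead of a single response.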
if chatReq.Stream {
streamingChatCompletionsHandler(w, r, chatReq.Model, chat, lastPart)
return
}
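	// Non-streaming path: send the final message and wait for the complete
	// Gemini response.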
geminiResp, err := chat.SendMessage(r.Context(), lastPart)
if err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to generate content: %v", err)
return
}
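	// Translate the Gemini response into the OpenAI chat completion shape and
	// return it as JSON.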
resp := toOpenAIResponse(geminiResp, "chat.completion", chatReq.Model)
if err := json.NewEncoder(w).Encode(resp); err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat completions response: %v", err)
return
}
}
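
// toOpenAIResponse converts a Gemini GenerateContentResponse into an OpenAI
// chat completion response, copying token usage and flattening each
// candidate's text parts into a single message.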
func toOpenAIResponse(from *genai.GenerateContentResponse, object, model string) (to ChatCompletionResponse) {
to.Object = object
to.Created = time.Now().Unix()
to.Model = model
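	// Copy Gemini's token accounting into the OpenAI usage block when present.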
if from.UsageMetadata != nil {
to.Usage = Usage{
PromptTokens: from.UsageMetadata.PromptTokenCount,
CompletionTokens: from.UsageMetadata.CandidatesTokenCount,
TotalTokens: from.UsageMetadata.TotalTokenCount,
}
}
to.Choices = make([]ChatCompletionChoice, 0, len(from.Candidates))
for i, c := range from.Candidates {
var builder strings.Builder
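		// Concatenate the candidate's text parts; non-text parts (such as
		// function calls) are skipped with a log entry.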
for _, p := range c.Content.Parts {
content, ok := p.(genai.Text)
if !ok {
log.Printf("failed to process content part; type = %v", reflect.TypeOf(p))
continue
}
builder.WriteString(string(content))
}
		// OpenAI clients expect the "assistant" role rather than Gemini's
		// "model", so map it back before returning the choice.
		role := c.Content.Role
		if role == "model" {
			role = "assistant"
		}
		choice := ChatCompletionChoice{
			Index: i,
			Message: ChatMessage{
				Role:    role,
				Content: builder.String(),
			},
		}
finishReason := toGeminiFinishReason(c.FinishReason)
if finishReason != "" {
choice.FinishReason = finishReason
}
to.Choices = append(to.Choices, choice)
}
return to
}
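
// toGeminiFinishReason maps a Gemini finish reason code to the corresponding
// OpenAI finish_reason string; unrecognized codes map to an empty string.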
func toGeminiFinishReason(code genai.FinishReason) string {
switch code {
case genai.FinishReasonStop:
return "stop"
case genai.FinishReasonMaxTokens:
return "length"
case genai.FinishReasonRecitation:
return "content_filter"
case genai.FinishReasonSafety:
return "content_filter"
case genai.FinishReasonOther:
return "other"
default:
return ""
}
}