server-go/main.go (149 lines of code) (raw):

// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "cmp" "context" "fmt" "log" "net/http" "os" "strings" "github.com/google/generative-ai-go/genai" "github.com/rs/cors" "google.golang.org/api/iterator" "google.golang.org/api/option" ) const modelName = "gemini-1.5-flash" const defaultPort = "9000" // Server state holding the context of the Gemini client and the generative model. type genaiServer struct { ctx context.Context model *genai.GenerativeModel } func main() { ctx := context.Background() // Access your API key as an environment variable to create a client. apiKey := os.Getenv("GOOGLE_API_KEY") client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { log.Fatalf("could not create Gemini client %v", err) } defer client.Close() model := client.GenerativeModel(modelName) server := &genaiServer{ ctx: ctx, model: model, } mux := http.NewServeMux() mux.HandleFunc("POST /chat", server.chatHandler) mux.HandleFunc("POST /stream", server.streamingChatHandler) // Add CORS middleware handler. c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedHeaders: []string{"Access-Control-Allow-Origin", "Content-Type"}, }) handler := c.Handler(mux) // Access preferred port the server must listen to as an environment variable if provided. port := cmp.Or(os.Getenv("PORT"), defaultPort) addr := "localhost:" + port log.Println("Listening on ", addr) log.Fatal(http.ListenAndServe(addr, handler)) } // part is a piece of model content or user query. It can hold only text pieces. An item in the JSON // encoded history array based on the role it represents (user / model) holds a single model // response / user query as an ordered array of text chunks. Each item in this array must comply to part. type part struct { // Piece of model content or user query. Text string } // content is the structure to which each item in the incoming JSON-encoded history array must // comply to. type content struct { // The producer of the content. Must be either 'user' or 'model'. Role string // Ordered `Parts` that constitute a single message. Parts []part } // chatRequest is the structure to which the incoming JSON-encoded value in the response body is // decoded. type chatRequest struct { // The query from the user to the model. Chat string // The history of the conversation between the user and the model in the current session. History []content } // chatHandler returns the complete response of the model to the client. Expects a JSON payload in // the request with the following format: // Request: // - chat: string // - history: [] // // Sends a JSON payload containing the model response to the client with the following format. // Response: // - text: string func (gs *genaiServer) chatHandler(w http.ResponseWriter, r *http.Request) { cr := &chatRequest{} if err := parseRequestJSON(r, cr); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } cs := gs.startChat(cr.History) res, err := cs.SendMessage(gs.ctx, genai.Text(cr.Chat)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } resTxt, err := responseString(res) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } renderResponseJSON(w, map[string]string{"text": resTxt}) } // streamingChatHandler continuously streams the response of the model to the client. Expects a // JSON payload in the request with the following format: // Request: // - chat: string, // - history: [], // // A partial response from the model contains a piece of text. func (gs *genaiServer) streamingChatHandler(w http.ResponseWriter, r *http.Request) { cr := &chatRequest{} if err := parseRequestJSON(r, cr); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } cs := gs.startChat(cr.History) iter := cs.SendMessageStream(gs.ctx, genai.Text(cr.Chat)) w.Header().Set("Content-Type", "text/event-stream") for { res, err := iter.Next() if err == iterator.Done { break } if err != nil { log.Println(err) break } resTxt, err := responseString(res) if err != nil { log.Println(err) break } fmt.Fprint(w, resTxt) if f, ok := w.(http.Flusher); ok { f.Flush() } } } // startChat starts a chat session with the model using the given history. func (gs *genaiServer) startChat(hist []content) *genai.ChatSession { cs := gs.model.StartChat() cs.History = transform(hist) return cs } // transform converts []content to a []*genai.Content that is accepted by the model's chat session. func transform(cs []content) []*genai.Content { gcs := make([]*genai.Content, len(cs)) for i, c := range cs { gcs[i] = c.transform() } return gcs } // transform converts content to genai.Content that is accepted by the model's chat session. func (c *content) transform() *genai.Content { gc := &genai.Content{} gc.Role = c.Role ps := make([]genai.Part, len(c.Parts)) for i, p := range c.Parts { ps[i] = genai.Text(p.Text) } gc.Parts = ps return gc } // responseString converts the model response of type genai.GenerateContentResponse to a string. func responseString(res *genai.GenerateContentResponse) (string, error) { // Only taking the first candidate since GenerationConfig.CandidateCount defaults to 1. if len(res.Candidates) > 0 { if cs := contentString(res.Candidates[0].Content); cs != nil { return *cs, nil } } return "", fmt.Errorf("invalid response from Gemini model") } // contentString converts genai.Content to a string. If the parts in the input content are of type // text, they are concatenated with new lines in between them to form a string. func contentString(c *genai.Content) *string { if c == nil || c.Parts == nil { return nil } cStrs := make([]string, len(c.Parts)) for i, part := range c.Parts { if pt, ok := part.(genai.Text); ok { cStrs[i] = string(pt) } else { return nil } } cStr := strings.Join(cStrs, "\n") return &cStr }