llm/go-client/frontend/handlers/chat.go (197 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package handlers import ( "context" "fmt" "io" "log" "net/http" "regexp" "runtime/debug" "time" ) import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" ) import ( "github.com/apache/dubbo-go-samples/llm/config" "github.com/apache/dubbo-go-samples/llm/go-client/frontend/service" chat "github.com/apache/dubbo-go-samples/llm/proto" ) type ChatHandler struct { svc chat.ChatService ctxManager *service.ContextManager } func NewChatHandler(svc chat.ChatService, mgr *service.ContextManager) *ChatHandler { return &ChatHandler{ svc: svc, ctxManager: mgr, } } func (h *ChatHandler) Index(c *gin.Context) { session := sessions.Default(c) ctxID := session.Get("current_context") if ctxID == nil { ctxID = h.ctxManager.CreateContext() session.Set("current_context", ctxID) err := session.Save() if err != nil { return } } c.HTML(http.StatusOK, "index.html", gin.H{ "title": "LLM Chat", }) } func (h *ChatHandler) Chat(c *gin.Context) { session := sessions.Default(c) ctxID, ok := session.Get("current_context").(string) if !ok { h.NewContext(c) ctxID, ok = session.Get("current_context").(string) if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get context"}) return } } var req struct { Message string `json:"message"` Bin string `json:"bin"` Model string `json:"model"` } if err := c.BindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request payload"}) return } var img string if len(req.Bin) > 0 { re := regexp.MustCompile(`^data:image/([a-zA-Z]+);base64,([^"]+)$`) // this regex does not support file types like svg matches := re.FindStringSubmatch(req.Bin) if len(matches) != 3 { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid base64 data format"}) return } img = matches[2] } h.ctxManager.AppendMessage(ctxID, &chat.ChatMessage{ Role: "human", Content: req.Message, Bin: []byte(img), }) messages := h.ctxManager.GetHistory(ctxID) stream, err := h.svc.Chat(context.Background(), &chat.ChatRequest{ Messages: messages, Model: req.Model, }) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } defer func() { if err := stream.Close(); err != nil { log.Println("Error closing stream:", err) } }() c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "close") responseCh := make(chan string, 100) // use buffer go func() { defer func() { if r := recover(); r != nil { log.Printf("Recovered in stream processing: %v\n%s", r, debug.Stack()) } close(responseCh) }() resp := "" for { select { case <-c.Request.Context().Done(): // client disconnect log.Println("Client disconnected, stopping stream processing") return default: if !stream.Recv() { if err := stream.Err(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) log.Printf("Stream receive error: %v", err) } h.ctxManager.AppendMessage(ctxID, &chat.ChatMessage{ Role: "ai", Content: resp, Bin: nil, }) return } content := stream.Msg().Content resp += content responseCh <- content } } }() // SSE stream output cfg, err := config.GetConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) return } timeout := cfg.TimeoutSeconds c.Stream(func(w io.Writer) bool { select { case chunk, ok := <-responseCh: if !ok { return false } c.SSEvent("message", gin.H{"content": chunk}) return true case <-time.After(time.Duration(timeout) * time.Second): log.Println("Stream time out") return false case <-c.Request.Context().Done(): log.Println("Client disconnected") return false } }) } func (h *ChatHandler) NewContext(c *gin.Context) { session := sessions.Default(c) newCtxID := h.ctxManager.CreateContext() session.Set("current_context", newCtxID) if err := session.Save(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) return } c.JSON(http.StatusOK, gin.H{ "context_id": newCtxID, }) } func (h *ChatHandler) ListContexts(c *gin.Context) { session := sessions.Default(c) currentCtx := session.Get("current_context").(string) contexts := h.ctxManager.List() c.JSON(http.StatusOK, gin.H{ "current": currentCtx, "contexts": contexts, }) } func (h *ChatHandler) SwitchContext(c *gin.Context) { var req struct { ContextID string `json:"context_id"` } if err := c.BindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } exists := h.ctxManager.Consists(req.ContextID) if !exists { c.JSON(http.StatusNotFound, gin.H{"error": "context not found"}) return } session := sessions.Default(c) session.Set("current_context", req.ContextID) if err := session.Save(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) return } c.JSON(http.StatusOK, gin.H{ "message": "context switched", }) }