ollama/ollama.go (143 lines of code) (raw):

// Copyright 2024 Google LLC // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // https://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package ollama provies handlers that proxies // ollama API calls to Gemini models. package ollama import ( "encoding/json" "io" "net/http" "strings" "time" "github.com/google-gemini/proxy-to-gemini/internal" "github.com/google/generative-ai-go/genai" "github.com/gorilla/mux" ) type handlers struct { client *genai.Client } func RegisterHandlers(r *mux.Router, client *genai.Client) { handlers := &handlers{client: client} r.HandleFunc("/api/generate", handlers.generateHandler) r.HandleFunc("/api/embed", handlers.embedHandler) } func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err) return } defer r.Body.Close() var req GenerateRequest if err := json.Unmarshal(body, &req); err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to unmarshal request body: %v", err) return } model := h.client.GenerativeModel(req.Model) model.GenerationConfig = genai.GenerationConfig{ Temperature: req.Options.Temperature, MaxOutputTokens: req.Options.NumPredict, TopK: req.Options.TopK, TopP: req.Options.TopP, } if req.Options.Stop != nil { model.GenerationConfig.StopSequences = []string{*req.Options.Stop} } if req.System != "" { model.SystemInstruction = &genai.Content{ Role: "system", Parts: []genai.Part{genai.Text(req.System)}, } } parts := []genai.Part{genai.Text(req.Prompt)} gresp, err := model.GenerateContent(r.Context(), parts...) if err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to generate content: %v", err) return } if len(gresp.Candidates) == 0 { internal.ErrorHandler(w, r, http.StatusInternalServerError, "no candidates returned") return } responseBuilder := &strings.Builder{} for _, part := range gresp.Candidates[0].Content.Parts { switch v := part.(type) { case genai.Text: responseBuilder.WriteString(string(v)) default: internal.ErrorHandler(w, r, http.StatusInternalServerError, "unsupported part type: %T", v) return } } if err := json.NewEncoder(w).Encode(&GenerateResponse{ Model: req.Model, Response: responseBuilder.String(), CreatedAt: time.Now(), PromptEvalCount: gresp.UsageMetadata.PromptTokenCount, EvalCount: gresp.UsageMetadata.TotalTokenCount, Done: true, }); err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode generate response: %v", err) return } } func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err) return } defer r.Body.Close() var req EmbedRequest if err := json.Unmarshal(body, &req); err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to unmarshal request body: %v", err) return } model := h.client.EmbeddingModel(req.Model) batch := model.NewBatch() for _, input := range req.Input { batch.AddContent(genai.Text(input)) } gresp, err := model.BatchEmbedContents(r.Context(), batch) if err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to create embedding: %v", err) return } embeddings := make([][]float32, 0, len(gresp.Embeddings)) for _, embedding := range gresp.Embeddings { embeddings = append(embeddings, embedding.Values) } if err := json.NewEncoder(w).Encode(&EmbedResponse{ Model: req.Model, Embeddings: embeddings, }); err != nil { internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode embeddings response: %v", err) return } } type GenerateRequest struct { Model string `json:"model,omitempty"` Prompt string `json:"prompt,omitempty"` Suffix string `json:"suffix,omitempty"` Options Options `json:"options,omitempty"` System string `json:"system,omitempty"` // TODO: Support images. // TODO: Support format. // TODO: Support streaming. } type GenerateResponse struct { Model string `json:"model,omitempty"` Response string `json:"response,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"` PromptEvalCount int32 `json:"prompt_eval_count,omitempty"` EvalCount int32 `json:"eval_count,omitempty"` Done bool `json:"done,omitempty"` } type Options struct { Temperature *float32 `json:"temperature,omitempty"` Stop *string `json:"stop,omitempty"` NumPredict *int32 `json:"num_predict,omitempty"` TopK *int32 `json:"top_k,omitempty"` TopP *float32 `json:"top_p,omitempty"` // TODO: Anything else to support? } type EmbedRequest struct { Model string `json:"model,omitempty"` Input []string `json:"input,omitempty"` } type EmbedResponse struct { Model string `json:"model,omitempty"` Embeddings [][]float32 `json:"embeddings,omitempty"` }