gollm/ollama.go (262 lines of code) (raw):

// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package gollm import ( "context" "encoding/json" "fmt" "net/url" "github.com/ollama/ollama/api" "k8s.io/klog/v2" ) func init() { RegisterProvider("ollama", ollamaFactory) } func ollamaFactory(ctx context.Context, u *url.URL) (Client, error) { return NewOllamaClient(ctx) } const ( defaultOllamaModel = "gemma2:latest" ) type OllamaClient struct { client *api.Client } type OllamaChat struct { client *api.Client model string history []api.Message tools []api.Tool } var _ Client = &OllamaClient{} func NewOllamaClient(ctx context.Context) (*OllamaClient, error) { client, err := api.ClientFromEnvironment() if err != nil { return nil, err } return &OllamaClient{ client: client, }, nil } func (c *OllamaClient) Close() error { return nil } func (c *OllamaClient) GenerateCompletion(ctx context.Context, request *CompletionRequest) (CompletionResponse, error) { req := &api.GenerateRequest{ Model: request.Model, Prompt: request.Prompt, Stream: ptrTo(false), } var ollamaResponse *OllamaCompletionResponse respFunc := func(resp api.GenerateResponse) error { ollamaResponse = &OllamaCompletionResponse{response: resp.Response} return nil } err := c.client.Generate(ctx, req, respFunc) if err != nil { return nil, err } return ollamaResponse, nil } func (c *OllamaClient) ListModels(ctx context.Context) ([]string, error) { modelResponse, err := c.client.List(ctx) if err != nil { return nil, err } var models []string for _, model := range modelResponse.Models { models = append(models, model.Name) } return models, nil } func (c *OllamaClient) SetResponseSchema(schema *Schema) error { return nil } func (c *OllamaClient) StartChat(systemPrompt, model string) Chat { return &OllamaChat{ client: c.client, model: model, history: []api.Message{ { Role: "system", Content: systemPrompt, }, }, } } type OllamaCompletionResponse struct { response string } func (r *OllamaCompletionResponse) Response() string { return r.response } func (r *OllamaCompletionResponse) UsageMetadata() any { return nil } func (c *OllamaChat) Send(ctx context.Context, contents ...any) (ChatResponse, error) { log := klog.FromContext(ctx) for _, content := range contents { switch v := content.(type) { case string: message := api.Message{ Role: "user", Content: v, } c.history = append(c.history, message) case FunctionCallResult: message := api.Message{ Role: "user", Content: fmt.Sprintf("Function call result: %s", v.Result), } c.history = append(c.history, message) default: return nil, fmt.Errorf("unsupported content type: %T", v) } } req := &api.ChatRequest{ Model: c.model, Messages: c.history, // set streaming to false Stream: new(bool), Tools: c.tools, } var ollamaResponse *OllamaChatResponse respFunc := func(resp api.ChatResponse) error { log.Info("recieved response from ollama", "resp", resp) ollamaResponse = &OllamaChatResponse{ ollamaResponse: resp, candidates: []*OllamaCandidate{ { parts: []OllamaPart{ { text: resp.Message.Content, toolCalls: resp.Message.ToolCalls, }, }, }, }, } c.history = append(c.history, resp.Message) return nil } err := c.client.Chat(ctx, req, respFunc) if err != nil { return nil, err } log.Info("ollama response", "parsed_response", ollamaResponse) return ollamaResponse, nil } func (c *OllamaChat) IsRetryableError(err error) bool { // TODO(droot): Implement this return false } func (c *OllamaChat) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) { // TODO: Implement streaming response, err := c.Send(ctx, contents...) if err != nil { return nil, err } return singletonChatResponseIterator(response), nil } type OllamaChatResponse struct { candidates []*OllamaCandidate ollamaResponse api.ChatResponse } var _ ChatResponse = &OllamaChatResponse{} func (r *OllamaChatResponse) MarshalJSON() ([]byte, error) { formatted := RecordChatResponse{ Raw: r.ollamaResponse, } return json.Marshal(&formatted) } func (r *OllamaChatResponse) String() string { return fmt.Sprintf("OllamaChatResponse{candidates=%v}", r.candidates) } func (r *OllamaChatResponse) UsageMetadata() any { return nil } func (r *OllamaChatResponse) Candidates() []Candidate { var cads []Candidate for _, candidate := range r.candidates { cads = append(cads, candidate) } return cads } type OllamaCandidate struct { parts []OllamaPart } func (r *OllamaCandidate) String() string { return r.parts[0].text } func (r *OllamaCandidate) Parts() []Part { var parts []Part for _, part := range r.parts { parts = append(parts, &OllamaPart{ text: part.text, toolCalls: part.toolCalls, }) } return parts } type OllamaPart struct { text string toolCalls []api.ToolCall } func (p *OllamaPart) AsText() (string, bool) { if len(p.text) > 0 { return p.text, true } return "", false } func (p *OllamaPart) AsFunctionCalls() ([]FunctionCall, bool) { if len(p.toolCalls) > 0 { var functionCalls []FunctionCall for _, toolCall := range p.toolCalls { functionCalls = append(functionCalls, FunctionCall{ Name: toolCall.Function.Name, Arguments: toolCall.Function.Arguments, }) } return functionCalls, true } return nil, false } func (c *OllamaChat) SetFunctionDefinitions(functionDefinitions []*FunctionDefinition) error { var tools []api.Tool for _, functionDefinition := range functionDefinitions { tools = append(tools, fnDefToOllamaTool(functionDefinition)) } c.tools = tools return nil } func fnDefToOllamaTool(fnDef *FunctionDefinition) api.Tool { tool := api.Tool{ Type: "function", Function: api.ToolFunction{ Name: fnDef.Name, Description: fnDef.Description, Parameters: struct { Type string `json:"type"` Required []string `json:"required"` Properties map[string]struct { Type string `json:"type"` Description string `json:"description"` Enum []string `json:"enum,omitempty"` } `json:"properties"` }{ Type: "object", Required: fnDef.Parameters.Required, Properties: map[string]struct { Type string `json:"type"` Description string `json:"description"` Enum []string `json:"enum,omitempty"` }{}, }, }, } for paramName, param := range fnDef.Parameters.Properties { tool.Function.Parameters.Properties[paramName] = struct { Type string `json:"type"` Description string `json:"description"` Enum []string `json:"enum,omitempty"` }{ Type: string(param.Type), Description: param.Description, } } return tool }