pkg/rules/langchain/llm_setup.go (119 lines of code) (raw):

// Copyright (c) 2024 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package langchain import ( "context" "reflect" _ "unsafe" "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/api" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/openai" ) //go:linkname openaiGenerateContentOnEnter github.com/tmc/langchaingo/llms/openai.openaiGenerateContentOnEnter func openaiGenerateContentOnEnter(call api.CallContext, llm *openai.LLM, ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption, ) { request := &langChainLLMRequest{ moduleName: "unKnown", operationName: "chat", } client := reflect.ValueOf(*llm).FieldByName("client") if client.IsValid() && !client.IsNil() { if client.Elem().FieldByName("Model").IsValid() { request.moduleName = client.Elem().FieldByName("Model").String() } if client.Elem().FieldByName("baseURL").IsValid() { request.serverAddress = client.Elem().FieldByName("baseURL").String() } } LLMBaseOnEnter(call, ctx, request, messages, options...) } //go:linkname openaiGenerateContentOnExit github.com/tmc/langchaingo/llms/openai.openaiGenerateContentOnExit func openaiGenerateContentOnExit(call api.CallContext, resp *llms.ContentResponse, err error) { data := call.GetData().(map[string]interface{}) request := langChainLLMRequest{} response := langChainLLMResponse{} ctx, ok := data["ctx"].(context.Context) if !ok { return } if err != nil { langChainLLMInstrument.End(ctx, request, response, err) return } request = data["request"].(langChainLLMRequest) if len(resp.Choices) > 0 { var finishReasons []string for _, choice := range resp.Choices { finishReasons = append(finishReasons, choice.StopReason) } response.responseFinishReasons = finishReasons if totalTokensAny, ok1 := resp.Choices[0].GenerationInfo["TotalTokens"]; ok1 { if totalTokens, ok2 := totalTokensAny.(int); ok2 { response.usageOutputTokens = int64(totalTokens) } } if reasoningTokensAny, ok1 := resp.Choices[0].GenerationInfo["ReasoningTokens"]; ok1 { if totalTokens, ok2 := reasoningTokensAny.(int); ok2 { request.usageInputTokens = int64(totalTokens) } } } langChainLLMInstrument.End(ctx, request, response, nil) } //go:linkname ollamaGenerateContentOnEnter github.com/tmc/langchaingo/llms/ollama.ollamaGenerateContentOnEnter func ollamaGenerateContentOnEnter(call api.CallContext, llm *ollama.LLM, ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption, ) { request := &langChainLLMRequest{ moduleName: "unKnown", operationName: "chat", } opt := reflect.ValueOf(*llm).FieldByName("options") if opt.IsValid() { if opt.FieldByName("model").IsValid() { request.moduleName = opt.FieldByName("model").String() } } LLMBaseOnEnter(call, ctx, request, messages, options...) } //go:linkname ollamaGenerateContentOnExit github.com/tmc/langchaingo/llms/ollama.ollamaGenerateContentOnExit func ollamaGenerateContentOnExit(call api.CallContext, resp *llms.ContentResponse, err error) { data := call.GetData().(map[string]interface{}) request := langChainLLMRequest{} response := langChainLLMResponse{} ctx, ok := data["ctx"].(context.Context) if !ok { return } if err != nil { langChainLLMInstrument.End(ctx, request, response, err) return } request = data["request"].(langChainLLMRequest) if totalTokensAny, ok1 := resp.Choices[0].GenerationInfo["TotalTokens"]; ok1 { if totalTokens, ok2 := totalTokensAny.(int); ok2 { response.usageOutputTokens = int64(totalTokens) } } langChainLLMInstrument.End(ctx, request, response, nil) } func LLMBaseOnEnter(call api.CallContext, ctx context.Context, req *langChainLLMRequest, messages []llms.MessageContent, options ...llms.CallOption, ) { llmsOpts := llms.CallOptions{} for _, opt := range options { opt(&llmsOpts) } if llmsOpts.Model != "" { req.moduleName = llmsOpts.Model } req.frequencyPenalty = llmsOpts.FrequencyPenalty req.presencePenalty = llmsOpts.PresencePenalty req.maxTokens = int64(llmsOpts.MaxTokens) req.temperature = llmsOpts.Temperature req.stopSequences = llmsOpts.StopWords req.topK = float64(llmsOpts.TopK) req.topP = llmsOpts.TopP req.seed = int64(llmsOpts.Seed) langCtx := langChainLLMInstrument.Start(ctx, *req) data := make(map[string]interface{}) data["ctx"] = langCtx data["request"] = *req call.SetData(data) }