gollm/azopenai.go (266 lines of code) (raw):
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gollm
import (
"context"
"encoding/json"
"fmt"
"net/url"
"os"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
)
// init registers azureOpenAIFactory under the "azopenai" provider name.
func init() {
RegisterProvider("azopenai", azureOpenAIFactory)
}
// azureOpenAIFactory builds an Azure OpenAI backed Client from the given
// provider URL by delegating to NewAzureOpenAIClient.
//
// On failure it returns a literal nil Client. Forwarding the
// (*AzureOpenAIClient, error) pair directly would wrap a typed nil pointer in
// the Client interface, so a caller comparing the returned Client against nil
// would wrongly see a non-nil value.
func azureOpenAIFactory(ctx context.Context, u *url.URL) (Client, error) {
	client, err := NewAzureOpenAIClient(ctx, *u)
	if err != nil {
		return nil, err
	}
	return client, nil
}
// AzureOpenAIClient wraps an Azure OpenAI SDK client.
type AzureOpenAIClient struct {
// client is the underlying SDK client used to issue chat completion requests.
client *azopenai.Client
}
// Compile-time check that *AzureOpenAIClient satisfies the Client interface.
var _ Client = (*AzureOpenAIClient)(nil)
// NewAzureOpenAIClient creates an AzureOpenAIClient.
//
// The endpoint is read from the AZURE_OPENAI_ENDPOINT environment variable.
// When u.Host is non-empty, u (with its scheme forced to "https") is used
// instead. An error is returned if neither yields an endpoint.
//
// If AZURE_OPENAI_API_KEY is set, the client authenticates with that key;
// otherwise it uses azidentity's default Azure credential. ctx is not used.
func NewAzureOpenAIClient(ctx context.Context, u url.URL) (*AzureOpenAIClient, error) {
	endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT")
	if u.Host != "" {
		// u is a copy, so rewriting the scheme does not affect the caller.
		u.Scheme = "https"
		endpoint = u.String()
	}
	if endpoint == "" {
		return nil, fmt.Errorf("AZURE_OPENAI_ENDPOINT environment variable not set")
	}

	// Key-based authentication takes precedence when an API key is configured.
	if apiKey := os.Getenv("AZURE_OPENAI_API_KEY"); apiKey != "" {
		sdkClient, err := azopenai.NewClientWithKeyCredential(endpoint, azcore.NewKeyCredential(apiKey), nil)
		if err != nil {
			return nil, err
		}
		return &AzureOpenAIClient{client: sdkClient}, nil
	}

	cred, err := azidentity.NewDefaultAzureCredential(nil)
	if err != nil {
		return nil, err
	}
	sdkClient, err := azopenai.NewClient(endpoint, cred, nil)
	if err != nil {
		return nil, err
	}
	return &AzureOpenAIClient{client: sdkClient}, nil
}
// Close is a no-op for this client and always returns nil.
func (c *AzureOpenAIClient) Close() error {
return nil
}
// GenerateCompletion sends request.Prompt as a single user message to the
// deployment named by request.Model and returns the content of the first
// choice in the reply.
//
// It returns an error if the service call fails, if the response contains no
// choices, or if the first choice has no message or no message content.
func (c *AzureOpenAIClient) GenerateCompletion(ctx context.Context, request *CompletionRequest) (CompletionResponse, error) {
	req := azopenai.ChatCompletionsOptions{
		Messages: []azopenai.ChatRequestMessageClassification{
			&azopenai.ChatRequestUserMessage{Content: azopenai.NewChatRequestUserMessageContent(request.Prompt)},
		},
		DeploymentName: &request.Model,
	}

	resp, err := c.client.GetChatCompletions(ctx, req, nil)
	if err != nil {
		return nil, err
	}

	// The emptiness check comes first so Choices[0] is only indexed when it
	// exists; the remaining checks guard the pointer dereference below.
	if len(resp.Choices) == 0 || resp.Choices[0].Message == nil || resp.Choices[0].Message.Content == nil {
		return nil, fmt.Errorf("invalid completion response: %v", resp)
	}

	return &AzureOpenAICompletionResponse{response: *resp.Choices[0].Message.Content}, nil
}
// ListModels is not implemented for Azure OpenAI; it always returns a nil
// slice and an error.
func (c *AzureOpenAIClient) ListModels(ctx context.Context) ([]string, error) {
return nil, fmt.Errorf("listing models not supported yet for Azure OpenAI")
}
// SetResponseSchema ignores schema and always returns nil; no response schema
// is stored or applied by this client.
func (c *AzureOpenAIClient) SetResponseSchema(schema *Schema) error {
return nil
}
// StartChat returns a new chat session bound to model. The session shares this
// client's SDK client and its history starts with systemPrompt as the single
// system message.
func (c *AzureOpenAIClient) StartChat(systemPrompt string, model string) Chat {
	systemMessage := &azopenai.ChatRequestSystemMessage{
		Content: azopenai.NewChatRequestSystemMessageContent(systemPrompt),
	}

	chat := &AzureOpenAIChat{
		client:  c.client,
		model:   model,
		history: []azopenai.ChatRequestMessageClassification{systemMessage},
	}
	return chat
}
// AzureOpenAICompletionResponse holds the text of a single completion.
type AzureOpenAICompletionResponse struct {
// response is the completion text returned by Response.
response string
}
// Response returns the stored completion text.
func (r *AzureOpenAICompletionResponse) Response() string {
return r.response
}
// UsageMetadata always returns nil; this response type carries no usage data.
func (r *AzureOpenAICompletionResponse) UsageMetadata() any {
return nil
}
// AzureOpenAIChat is a chat session that accumulates request messages and
// sends them to an Azure OpenAI deployment.
type AzureOpenAIChat struct {
// client is the SDK client used to issue chat completion requests.
client *azopenai.Client
// model is passed as the deployment name on every request.
model string
// history is the full list of messages sent with each request.
history []azopenai.ChatRequestMessageClassification
// tools holds the tool definitions sent with each request; it is replaced
// by SetFunctionDefinitions.
tools []azopenai.ChatCompletionsToolDefinitionClassification
}
// Send appends each element of contents to the chat history as a user message
// and then requests a chat completion for the whole history.
//
// Supported content types are string (sent verbatim) and FunctionCallResult
// (sent as text prefixed with "Function call result: "). Any other type aborts
// with an error; messages appended for earlier elements remain in the history.
//
// An error is also returned when the service call fails or the response has
// no choices. Only request messages are recorded here; the returned response
// is not appended to the history.
func (c *AzureOpenAIChat) Send(ctx context.Context, contents ...any) (ChatResponse, error) {
	for _, item := range contents {
		var text string
		switch v := item.(type) {
		case string:
			text = v
		case FunctionCallResult:
			text = fmt.Sprintf("Function call result: %s", v.Result)
		default:
			return nil, fmt.Errorf("unsupported content type: %T", v)
		}
		c.history = append(c.history, &azopenai.ChatRequestUserMessage{
			Content: azopenai.NewChatRequestUserMessageContent(text),
		})
	}

	opts := azopenai.ChatCompletionsOptions{
		DeploymentName: &c.model,
		Messages:       c.history,
		Tools:          c.tools,
	}
	resp, err := c.client.GetChatCompletions(ctx, opts, nil)
	if err != nil {
		return nil, err
	}
	if len(resp.Choices) == 0 {
		return nil, fmt.Errorf("no response from Azure OpenAI: %v", resp)
	}

	return &AzureOpenAIChatResponse{azureOpenAIResponse: resp}, nil
}
// IsRetryableError currently reports every error as non-retryable; err is not
// inspected.
func (c *AzureOpenAIChat) IsRetryableError(err error) bool {
// TODO: Implement this
return false
}
// SendStreaming has the same signature as a streaming send but does not
// stream: it performs one blocking Send and wraps the complete response in a
// single-element iterator. Errors from Send are returned unchanged.
func (c *AzureOpenAIChat) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) {
	// TODO: Implement streaming
	full, err := c.Send(ctx, contents...)
	if err != nil {
		return nil, err
	}

	iter := singletonChatResponseIterator(full)
	return iter, nil
}
// AzureOpenAIChatResponse wraps the SDK response of a chat completion request.
type AzureOpenAIChatResponse struct {
// azureOpenAIResponse is the unmodified response returned by the SDK.
azureOpenAIResponse azopenai.GetChatCompletionsResponse
}
// Compile-time check that *AzureOpenAIChatResponse satisfies ChatResponse.
var _ ChatResponse = (*AzureOpenAIChatResponse)(nil)
// MarshalJSON encodes the response as a RecordChatResponse whose Raw field is
// the underlying SDK response.
func (r *AzureOpenAIChatResponse) MarshalJSON() ([]byte, error) {
	return json.Marshal(&RecordChatResponse{Raw: r.azureOpenAIResponse})
}
// String returns a textual representation of the response that formats the
// Choices of the underlying SDK response with %v.
func (r *AzureOpenAIChatResponse) String() string {
return fmt.Sprintf("AzureOpenAIChatResponse{candidates=%v}", r.azureOpenAIResponse.Choices)
}
// UsageMetadata returns the Usage field of the underlying SDK response.
func (r *AzureOpenAIChatResponse) UsageMetadata() any {
return r.azureOpenAIResponse.Usage
}
// Candidates returns one Candidate per choice in the underlying SDK response,
// in the same order. The result is nil when the response has no choices.
func (r *AzureOpenAIChatResponse) Candidates() []Candidate {
	choices := r.azureOpenAIResponse.Choices

	var out []Candidate
	for i := range choices {
		out = append(out, &AzureOpenAICandidate{candidate: choices[i]})
	}
	return out
}
// AzureOpenAICandidate wraps a single choice from a chat completion response.
type AzureOpenAICandidate struct {
// candidate is the SDK choice this candidate was built from.
candidate azopenai.ChatChoice
}
// String returns a human-readable representation of the candidate: a
// bracketed, comma-separated list of its parts. Text parts are written
// verbatim; function-call parts are written as
// functionCalls=["name"(args=...)].
//
// Every opening bracket is matched by a closing bracket, and multiple calls
// within one part are separated by ", ".
func (r *AzureOpenAICandidate) String() string {
	var b strings.Builder
	b.WriteString("[")
	for i, part := range r.Parts() {
		if i > 0 {
			b.WriteString(", ")
		}
		if text, ok := part.AsText(); ok {
			b.WriteString(text)
		}
		if calls, ok := part.AsFunctionCalls(); ok {
			b.WriteString("functionCalls=[")
			for j, call := range calls {
				if j > 0 {
					b.WriteString(", ")
				}
				fmt.Fprintf(&b, "%q(args=%v)", call.Name, call.Arguments)
			}
			b.WriteString("]")
		}
	}
	b.WriteString("]")
	return b.String()
}
// Parts splits the candidate's message into parts: first one part carrying the
// message content (which may be nil), then one part per function tool call.
//
// It returns nil when the choice has no message. Tool calls that are nil, that
// are not function tool calls, or that carry no function payload are skipped
// rather than causing a panic.
func (r *AzureOpenAICandidate) Parts() []Part {
	msg := r.candidate.Message
	if msg == nil {
		return nil
	}

	parts := []Part{&AzureOpenAIPart{text: msg.Content}}
	for _, tool := range msg.ToolCalls {
		// The comma-ok assertion also rejects a nil interface value.
		fnCall, ok := tool.(*azopenai.ChatCompletionsFunctionToolCall)
		if !ok || fnCall == nil || fnCall.Function == nil {
			continue
		}
		parts = append(parts, &AzureOpenAIPart{functionCall: fnCall.Function})
	}
	return parts
}
// AzureOpenAIPart is one piece of a candidate's message: either text or a
// function call.
type AzureOpenAIPart struct {
// text is the message content, if this part carries text; may be nil.
text *string
// functionCall is the requested function call, if this part carries one;
// may be nil.
functionCall *azopenai.FunctionCall
}
// AsText returns the part's text and true when the part holds non-empty text.
// It returns "" and false when the text is nil or empty.
func (p *AzureOpenAIPart) AsText() (string, bool) {
	if p.text == nil || *p.text == "" {
		return "", false
	}
	return *p.text, true
}
// AsFunctionCalls returns the part's function call as a single-element slice
// and true when the part holds a function call with a name.
//
// The call's arguments are decoded from their JSON object encoding. Missing
// (nil) or empty arguments are treated as a call with no arguments. It returns
// nil and false when the part holds no function call, the call has no name, or
// the arguments are not a valid JSON object.
func (p *AzureOpenAIPart) AsFunctionCalls() ([]FunctionCall, bool) {
	call := p.functionCall
	if call == nil || call.Name == nil {
		return nil, false
	}

	args := map[string]any{}
	if raw := call.Arguments; raw != nil && *raw != "" {
		if err := json.Unmarshal([]byte(*raw), &args); err != nil {
			return nil, false
		}
	}

	return []FunctionCall{
		{
			Name:      *call.Name,
			Arguments: args,
		},
	}, true
}
// SetFunctionDefinitions replaces the chat's tool list with one function tool
// per entry of functionDefinitions, in order. It always returns nil.
func (c *AzureOpenAIChat) SetFunctionDefinitions(functionDefinitions []*FunctionDefinition) error {
	// converted stays nil when functionDefinitions is empty, so c.tools is
	// reset to nil in that case.
	var converted []azopenai.ChatCompletionsToolDefinitionClassification
	for i := range functionDefinitions {
		converted = append(converted, &azopenai.ChatCompletionsFunctionToolDefinition{
			Function: fnDefToAzureOpenAITool(functionDefinitions[i]),
		})
	}

	c.tools = converted
	return nil
}
// fnDefToAzureOpenAITool converts fnDef into the SDK's function tool
// description. The parameters are encoded as a JSON object of the form
// {"type": "object", "properties": {...}}, with a "required" entry added only
// when fnDef.Parameters.Required is non-empty. Each property contributes only
// its type and description.
//
// The returned value points at fnDef's Name and Description fields rather than
// copying them.
func fnDefToAzureOpenAITool(fnDef *FunctionDefinition) *azopenai.ChatCompletionsFunctionToolDefinitionFunction {
	props := make(map[string]any)
	for name, p := range fnDef.Parameters.Properties {
		props[name] = map[string]any{
			"type":        string(p.Type),
			"description": p.Description,
		}
	}

	schema := map[string]any{
		"type":       "object",
		"properties": props,
	}
	if required := fnDef.Parameters.Required; len(required) > 0 {
		schema["required"] = required
	}

	// NOTE(review): the Marshal error is discarded; presumably it cannot fail
	// for a map built from strings and the Required value — confirm.
	encoded, _ := json.Marshal(schema)

	return &azopenai.ChatCompletionsFunctionToolDefinitionFunction{
		Name:        &fnDef.Name,
		Description: &fnDef.Description,
		Parameters:  encoded,
	}
}