vertex/vertex.go (81 lines of code) (raw):
package vertex
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/api/transport"
"github.com/anthropics/anthropic-sdk-go/internal/requestconfig"
sdkoption "github.com/anthropics/anthropic-sdk-go/option"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// DefaultVersion is the value written to the "anthropic_version" field of a request body that does not already
// contain that field.
const DefaultVersion = "vertex-2023-10-16"
// WithGoogleAuth returns a request option which loads the [Application Default Credentials] for Google Vertex AI and
// registers middleware that intercepts requests to the Messages API.
//
// region must be non-empty. scopes are passed through to [google.FindDefaultCredentials]; the credentials it finds are
// handed to [WithCredentials] together with ctx, region and projectID.
//
// WithGoogleAuth panics if region is empty or if the default credentials cannot be found. In the latter case the panic
// value is an error that wraps the underlying cause.
//
// If you already have a [*google.Credentials], it is recommended that you instead call [WithCredentials] directly.
//
// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
func WithGoogleAuth(ctx context.Context, region string, projectID string, scopes ...string) sdkoption.RequestOption {
	if region == "" {
		panic("region must be provided")
	}

	creds, err := google.FindDefaultCredentials(ctx, scopes...)
	if err != nil {
		// %w (rather than %v) keeps the cause reachable via errors.Is / errors.As for callers that recover.
		panic(fmt.Errorf("failed to find default credentials: %w", err))
	}

	return WithCredentials(ctx, region, projectID, creds)
}
// WithCredentials returns a request option which uses the provided credentials for Google Vertex AI and registers
// middleware that intercepts requests to the Messages API.
//
// The returned option sets the base URL to the Vertex AI endpoint for region, installs the Vertex middleware, and
// replaces the HTTP client with one that authenticates through creds.TokenSource. The "global" region is served from
// the location-less host aiplatform.googleapis.com; every other region uses <region>-aiplatform.googleapis.com.
//
// If projectID is empty, creds.ProjectID is used instead. When that is empty as well, rewritten Messages API requests
// fail with an error at request time.
//
// WithCredentials panics if region is empty, if creds is nil, or if the authenticated HTTP client cannot be created.
func WithCredentials(ctx context.Context, region string, projectID string, creds *google.Credentials) sdkoption.RequestOption {
	// Same precondition (and message) as WithGoogleAuth: an empty region would yield an unusable host name.
	if region == "" {
		panic("region must be provided")
	}
	if creds == nil {
		panic("credentials must be provided")
	}

	// Fall back to the project the credentials were issued for when the caller did not name one.
	if projectID == "" {
		projectID = creds.ProjectID
	}

	client, _, err := transport.NewHTTPClient(ctx, option.WithTokenSource(creds.TokenSource))
	if err != nil {
		panic(fmt.Errorf("failed to create HTTP client: %w", err))
	}

	baseURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/", region)
	if region == "global" {
		// The global endpoint has no region prefix in its host name; the location still appears in the URL path.
		baseURL = "https://aiplatform.googleapis.com/"
	}

	middleware := vertexMiddleware(region, projectID)
	return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
		return rc.Apply(
			sdkoption.WithBaseURL(baseURL),
			sdkoption.WithMiddleware(middleware),
			sdkoption.WithHTTPClient(client),
		)
	})
}
// vertexMiddleware returns middleware that rewrites Anthropic Messages API requests into their Vertex AI form.
//
// Requests without a body are passed to next untouched. For every request that carries a body, the middleware:
//   - reads the whole body and adds "anthropic_version" (set to DefaultVersion) when that field is absent;
//   - for POST /v1/messages, removes "model" from the body and moves it into the URL path, using the
//     "streamRawPredict" specifier when "stream" is true and "rawPredict" otherwise;
//   - for POST /v1/messages/count_tokens, redirects to the "count-tokens:rawPredict" model path;
//   - installs the rewritten body, a matching GetBody, and the new ContentLength on the request.
//
// Both rewritten routes fail with an error when projectID is empty. Failures while reading or editing the body are
// returned as errors rather than forwarded as a truncated request.
func vertexMiddleware(region, projectID string) sdkoption.Middleware {
	return func(r *http.Request, next sdkoption.MiddlewareNext) (*http.Response, error) {
		if r.Body == nil {
			return next(r)
		}

		body, err := io.ReadAll(r.Body)
		// Close unconditionally so the original body is released even when the read fails.
		r.Body.Close()
		if err != nil {
			return nil, err
		}

		if !gjson.GetBytes(body, "anthropic_version").Exists() {
			// On failure sjson returns a nil slice; surface the error instead of sending an empty body.
			body, err = sjson.SetBytes(body, "anthropic_version", DefaultVersion)
			if err != nil {
				return nil, fmt.Errorf("setting anthropic_version in request body: %w", err)
			}
		}

		if r.Method == http.MethodPost {
			switch r.URL.Path {
			case "/v1/messages":
				if projectID == "" {
					return nil, fmt.Errorf("no projectId was given and it could not be resolved from credentials")
				}

				// Vertex takes the model from the URL, so it is lifted out of the body.
				model := gjson.GetBytes(body, "model").String()
				stream := gjson.GetBytes(body, "stream").Bool()
				body, err = sjson.DeleteBytes(body, "model")
				if err != nil {
					return nil, fmt.Errorf("removing model from request body: %w", err)
				}

				specifier := "rawPredict"
				if stream {
					specifier = "streamRawPredict"
				}
				r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", projectID, region, model, specifier)

			case "/v1/messages/count_tokens":
				if projectID == "" {
					return nil, fmt.Errorf("no projectId was given and it could not be resolved from credentials")
				}
				r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", projectID, region)
			}
		}

		r.Body = io.NopCloser(bytes.NewReader(body))
		// GetBody must hand out an independent copy each time: the transport may call it for a retry or redirect
		// while a previously returned body is still being read, so a shared reader position would corrupt both.
		r.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(body)), nil
		}
		r.ContentLength = int64(len(body))

		return next(r)
	}
}