bedrock/bedrock.go (200 lines of code) (raw):

// Package bedrock adapts this SDK to Amazon Bedrock: it registers a decoder for the
// "application/vnd.amazon.eventstream" content type and provides request options that rewrite
// and SigV4-sign outgoing requests.
package bedrock

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"github.com/anthropics/anthropic-sdk-go/internal/requestconfig"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)

// DefaultVersion is the value written to the "anthropic_version" field of a request body that
// does not already contain that field.
const DefaultVersion = "bedrock-2023-05-31"

// DefaultEndpoints is the set of request paths whose POST requests the middleware rewrites to
// "/model/<model>/invoke" or "/model/<model>/invoke-with-response-stream".
var DefaultEndpoints = map[string]bool{
	"/v1/complete": true,
	"/v1/messages": true,
}

// eventstreamChunk is the JSON payload of a "chunk" event. Bytes carries the base64-encoded
// event body; P is decoded but not read anywhere in this file.
type eventstreamChunk struct {
	Bytes string `json:"bytes"`
	P     string `json:"p"`
}

// eventstreamDecoder implements ssestream.Decoder on top of the AWS event stream framing.
type eventstreamDecoder struct {
	eventstream.Decoder

	rc  io.ReadCloser   // underlying response body
	evt ssestream.Event // most recently decoded chunk
	err error           // first error encountered; once set, Next always returns false
}

// Close closes the underlying response body.
func (e *eventstreamDecoder) Close() error { return e.rc.Close() }

// Err returns the error that stopped the stream, or nil if none occurred.
func (e *eventstreamDecoder) Err() error { return e.err }

// Next reads messages until it has decoded a "chunk" event, then stores it for Event and
// returns true. It returns false when the stream ends or an error occurs.
//
// io.EOF from the frame decoder is treated as the normal end of the stream and is not recorded
// as an error. Event messages that are not chunks, and messages of an unrecognized type, are
// skipped so that a true result always means Event holds a freshly decoded chunk rather than
// the previous one.
func (e *eventstreamDecoder) Next() bool {
	if e.err != nil {
		return false
	}

	for {
		msg, err := e.Decoder.Decode(e.rc, nil)
		if err != nil {
			if err != io.EOF {
				e.err = err
			}
			return false
		}

		evt, ok, err := eventFromMessage(msg)
		if err != nil {
			e.err = err
			return false
		}
		if ok {
			e.evt = evt
			return true
		}
	}
}

// Event returns the chunk decoded by the most recent successful call to Next.
func (e *eventstreamDecoder) Event() ssestream.Event { return e.evt }

// eventFromMessage converts one event stream message into an ssestream.Event.
//
// The boolean result reports whether an event was produced; it is false, with a nil error, for
// messages that carry nothing to deliver. Exception and error messages are returned as errors.
func eventFromMessage(msg eventstream.Message) (ssestream.Event, bool, error) {
	messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
	if messageType == nil {
		return ssestream.Event{}, false, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
	}

	switch messageType.String() {
	case eventstreamapi.EventMessageType:
		eventType := msg.Headers.Get(eventstreamapi.EventTypeHeader)
		if eventType == nil {
			return ssestream.Event{}, false, fmt.Errorf("%s event header not present", eventstreamapi.EventTypeHeader)
		}
		if eventType.String() != "chunk" {
			return ssestream.Event{}, false, nil
		}
		evt, err := decodeChunk(msg.Payload)
		if err != nil {
			return ssestream.Event{}, false, err
		}
		return evt, true, nil

	case eventstreamapi.ExceptionMessageType:
		return ssestream.Event{}, false, exceptionError(msg)

	case eventstreamapi.ErrorMessageType:
		return ssestream.Event{}, false, errorMessageError(msg)
	}

	return ssestream.Event{}, false, nil
}

// decodeChunk unwraps a "chunk" payload: it parses the JSON envelope, base64-decodes its
// "bytes" field, and uses the decoded document's "type" field as the event type.
func decodeChunk(payload []byte) (ssestream.Event, error) {
	var chunk eventstreamChunk
	if err := json.Unmarshal(payload, &chunk); err != nil {
		return ssestream.Event{}, fmt.Errorf("parsing chunk payload: %w", err)
	}

	decoded, err := base64.StdEncoding.DecodeString(chunk.Bytes)
	if err != nil {
		return ssestream.Event{}, fmt.Errorf("decoding chunk bytes: %w", err)
	}

	return ssestream.Event{
		Type: gjson.GetBytes(decoded, "type").String(),
		Data: decoded,
	}, nil
}

// exceptionError builds the error for an exception message. The error code is taken from the
// exception-type header, falling back to the payload's Code and then __type fields.
//
// See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/service/iotsitewise/deserializers.go#L15511C1-L15567C2
func exceptionError(msg eventstream.Message) error {
	exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader)
	if exceptionType == nil {
		return fmt.Errorf("%s event header not present", eventstreamapi.ExceptionTypeHeader)
	}

	// See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/aws/protocol/restjson/decoder_util.go#L15-L48
	var errInfo struct {
		Code    string
		Type    string `json:"__type"`
		Message string
	}
	// An empty payload is tolerated: the header alone still identifies the exception.
	if len(bytes.TrimSpace(msg.Payload)) > 0 {
		if err := json.Unmarshal(msg.Payload, &errInfo); err != nil {
			return fmt.Errorf("received exception %s: parsing exception payload failed: %w", exceptionType.String(), err)
		}
	}

	errorCode := "UnknownError"
	errorMessage := errorCode
	switch {
	case len(exceptionType.String()) > 0:
		errorCode = exceptionType.String()
	case len(errInfo.Code) > 0:
		errorCode = errInfo.Code
	case len(errInfo.Type) > 0:
		errorCode = errInfo.Type
	}
	if len(errInfo.Message) > 0 {
		errorMessage = errInfo.Message
	}

	return fmt.Errorf("received exception %s: %s", errorCode, errorMessage)
}

// errorMessageError builds the error for an error message from its error-code and
// error-message headers, substituting "UnknownError" for whichever is absent.
func errorMessageError(msg eventstream.Message) error {
	errorCode := "UnknownError"
	errorMessage := errorCode
	if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
		errorCode = header.String()
	}
	if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
		errorMessage = header.String()
	}
	return fmt.Errorf("received error %s: %s", errorCode, errorMessage)
}

var _ ssestream.Decoder = (*eventstreamDecoder)(nil)

// init registers eventstreamDecoder as the decoder for "application/vnd.amazon.eventstream"
// response bodies.
func init() {
	ssestream.RegisterDecoder("application/vnd.amazon.eventstream", func(rc io.ReadCloser) ssestream.Decoder {
		return &eventstreamDecoder{rc: rc}
	})
}

// WithLoadDefaultConfig returns a request option which loads the default config for Amazon and registers
// middleware that intercepts requests to the Messages API so that this SDK can be used with Amazon Bedrock.
//
// It panics if the default config cannot be loaded.
//
// If you already have an [aws.Config], it is recommended that you instead call [WithConfig] directly.
func WithLoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) option.RequestOption {
	cfg, err := config.LoadDefaultConfig(ctx, optFns...)
	if err != nil {
		panic(err)
	}
	return WithConfig(cfg)
}

// WithConfig returns a request option which uses the provided config and registers middleware that
// intercepts requests to the Messages API so that this SDK can be used with Amazon Bedrock.
//
// The option sets the base URL to the bedrock-runtime endpoint for cfg.Region.
func WithConfig(cfg aws.Config) option.RequestOption {
	signer := v4.NewSigner()
	middleware := bedrockMiddleware(signer, cfg)

	return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
		return rc.Apply(
			option.WithBaseURL(fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", cfg.Region)),
			option.WithMiddleware(middleware),
		)
	})
}

// bedrockMiddleware returns middleware that rewrites each request for Bedrock and signs it
// with signer for the "bedrock" service in cfg.Region, using credentials from cfg.Credentials.
// The signature covers the SHA-256 hash of the (possibly rewritten) body; a request without a
// body is signed over the hash of empty input.
func bedrockMiddleware(signer *v4.Signer, cfg aws.Config) option.Middleware {
	return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) {
		var body []byte
		if r.Body != nil {
			var err error
			if body, err = rewriteRequest(r); err != nil {
				return nil, err
			}
		}

		if cfg.Credentials == nil {
			return nil, fmt.Errorf("bedrock: aws config has no credentials provider")
		}

		ctx := r.Context()
		credentials, err := cfg.Credentials.Retrieve(ctx)
		if err != nil {
			return nil, err
		}

		hash := sha256.Sum256(body)
		err = signer.SignHTTP(ctx, credentials, r, hex.EncodeToString(hash[:]), "bedrock", cfg.Region, time.Now())
		if err != nil {
			return nil, err
		}

		return next(r)
	}
}

// rewriteRequest consumes r.Body, adapts the JSON document and URL path for Bedrock, and
// installs the result back on r (Body, GetBody and ContentLength). It returns the final body
// bytes so the caller can hash them for signing. r.Body must be non-nil.
//
// The "anthropic_version" field is set to DefaultVersion when absent. For POST requests to a
// path in DefaultEndpoints, the "model" and "stream" fields are removed from the body and used
// to select the invocation path instead.
func rewriteRequest(r *http.Request) ([]byte, error) {
	body, err := io.ReadAll(r.Body)
	// Close on both the success and failure paths; the bytes already read are unaffected by
	// a close failure, so its result is intentionally ignored.
	_ = r.Body.Close()
	if err != nil {
		return nil, fmt.Errorf("reading request body: %w", err)
	}

	if !gjson.GetBytes(body, "anthropic_version").Exists() {
		body, err = sjson.SetBytes(body, "anthropic_version", DefaultVersion)
		if err != nil {
			return nil, fmt.Errorf("setting anthropic_version: %w", err)
		}
	}

	if r.Method == http.MethodPost && DefaultEndpoints[r.URL.Path] {
		model := gjson.GetBytes(body, "model").String()
		stream := gjson.GetBytes(body, "stream").Bool()

		if body, err = sjson.DeleteBytes(body, "model"); err != nil {
			return nil, fmt.Errorf("removing model from request body: %w", err)
		}
		if body, err = sjson.DeleteBytes(body, "stream"); err != nil {
			return nil, fmt.Errorf("removing stream from request body: %w", err)
		}

		action := "invoke"
		if stream {
			action = "invoke-with-response-stream"
		}
		// NOTE(review): model is interpolated unescaped; identifiers containing "/" (for
		// example ARNs) may need path escaping — confirm against the signer's expectations.
		r.URL.Path = fmt.Sprintf("/model/%s/%s", model, action)
	}

	r.Body = io.NopCloser(bytes.NewReader(body))
	// Hand out an independent reader on every call so a replayed request never shares a read
	// offset with a body that is still being consumed.
	r.GetBody = func() (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(body)), nil
	}
	r.ContentLength = int64(len(body))

	return body, nil
}