plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go (812 lines of code) (raw):
package provider
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash"
"hash/crc32"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// Constants used to address and sign requests sent to the Bedrock runtime API.
const (
// httpPostMethod is the HTTP method placed in the canonical request when signing.
httpPostMethod = "POST"
// awsService is the service name used in the credential scope and signing key.
awsService = "bedrock"
// Host name template: bedrock-runtime.{awsRegion}.amazonaws.com
bedrockDefaultDomain = "bedrock-runtime.%s.amazonaws.com"
// Converse path: /model/{modelId}/converse
bedrockChatCompletionPath = "/model/%s/converse"
// ConverseStream path: /model/{modelId}/converse-stream
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
// InvokeModel path: /model/{modelId}/invoke
bedrockInvokeModelPath = "/model/%s/invoke"
// bedrockSignedHeaders lists the headers covered by the request signature.
bedrockSignedHeaders = "host;x-amz-date"
// requestIdHeader is the response header whose value is reused as the response id;
// it also serves as the context key under which that value is stored.
requestIdHeader = "X-Amzn-Requestid"
)
// bedrockProviderInitializer validates Bedrock configuration and creates
// bedrockProvider instances. It carries no state of its own.
type bedrockProviderInitializer struct {
}
// ValidateConfig reports an error when the configuration lacks the AWS access
// key, the AWS secret key, or the AWS region. Credentials are checked before
// the region.
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
	hasCredentials := len(config.awsAccessKey) != 0 && len(config.awsSecretKey) != 0
	switch {
	case !hasCredentials:
		return errors.New("missing bedrock access authentication parameters")
	case len(config.awsRegion) == 0:
		return errors.New("missing bedrock region parameters")
	default:
		return nil
	}
}
// DefaultCapabilities maps each API name handled by this provider to the
// request path template used for it: chat completion goes to the converse
// path, image generation to the invoke path.
func (b *bedrockProviderInitializer) DefaultCapabilities() map[string]string {
	capabilities := make(map[string]string, 2)
	capabilities[string(ApiNameChatCompletion)] = bedrockChatCompletionPath
	capabilities[string(ApiNameImageGeneration)] = bedrockInvokeModelPath
	return capabilities
}
// CreateProvider installs the default capabilities on config and returns a
// bedrockProvider holding that configuration together with a context cache
// created from it. The returned error is always nil.
func (b *bedrockProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	config.setDefaultCapabilities(b.DefaultCapabilities())
	provider := &bedrockProvider{config: config}
	provider.contextCache = createContextCache(&config)
	return provider, nil
}
// bedrockProvider translates requests and responses between the OpenAI-style
// API exposed by the proxy and the Bedrock runtime API, and signs outgoing
// requests.
type bedrockProvider struct {
// config is the provider configuration; it supplies the AWS credentials and
// region used for the host name and the request signature.
config ProviderConfig
// contextCache is handed to the shared request-body handler in OnRequestBody.
contextCache *contextCache
}
// OnStreamingResponseBody converts one chunk of the Bedrock event stream into
// OpenAI-style server-sent events.
//
// When no event can be extracted from the chunk, or when converting any event
// fails, the original chunk is returned together with an error. Otherwise the
// concatenation of all converted events is returned.
func (b *bedrockProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
	events := extractAmazonEventStreamEvents(ctx, chunk)
	if len(events) == 0 {
		return chunk, fmt.Errorf("No events are extracted ")
	}

	var sse strings.Builder
	for i := range events {
		converted, err := b.convertEventFromBedrockToOpenAI(ctx, events[i])
		if err != nil {
			log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
			return chunk, err
		}
		sse.Write(converted)
	}
	return []byte(sse.String()), nil
}
// convertEventFromBedrockToOpenAI renders one decoded Bedrock stream event as
// a single OpenAI-style SSE frame (ssePrefix + JSON + blank line).
//
// The event's role, text delta and stop reason, when present, populate the
// single choice of the chunk. An event carrying usage produces a chunk with
// an empty choices list and the token counts filled in.
//
// It returns an error if the chunk cannot be serialized to JSON.
func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContext, bedrockEvent ConverseStreamEvent) ([]byte, error) {
	delta := &chatMessage{}
	if bedrockEvent.Role != nil {
		delta.Role = *bedrockEvent.Role
	}
	if bedrockEvent.Delta != nil {
		// Fill in the content on the existing delta so that a role set on the
		// same event is not discarded.
		delta.Content = bedrockEvent.Delta.Text
	}

	choice := chatCompletionChoice{Delta: delta}
	if bedrockEvent.StopReason != nil {
		choice.FinishReason = stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)
	}
	choices := []chatCompletionChoice{choice}

	openAIChunk := &chatCompletionResponse{
		Id:                ctx.GetStringContext(requestIdHeader, ""),
		Created:           time.Now().Unix(),
		Model:             ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		SystemFingerprint: "",
		Object:            objectChatCompletion,
		Choices:           choices,
	}
	if bedrockEvent.Usage != nil {
		// Usage chunks carry an empty (non-nil) choices list.
		openAIChunk.Choices = choices[:0]
		openAIChunk.Usage = usage{
			CompletionTokens: bedrockEvent.Usage.OutputTokens,
			PromptTokens:     bedrockEvent.Usage.InputTokens,
			TotalTokens:      bedrockEvent.Usage.TotalTokens,
		}
	}

	payload, err := json.Marshal(openAIChunk)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal streaming chunk: %v", err)
	}

	frame := make([]byte, 0, len(ssePrefix)+len(payload)+2)
	frame = append(frame, ssePrefix...)
	frame = append(frame, payload...)
	frame = append(frame, "\n\n"...)
	return frame, nil
}
// ConverseStreamEvent is the JSON payload of one event-stream message, as
// decoded by extractAmazonEventStreamEvents. All fields are optional; which
// ones are set depends on the kind of event received.
type ConverseStreamEvent struct {
// ContentBlockIndex is decoded from "contentBlockIndex"; the visible
// conversion code does not read it.
ContentBlockIndex int `json:"contentBlockIndex,omitempty"`
// Delta carries incremental content; its Text becomes the chunk's delta content.
Delta *converseStreamEventContentBlockDelta `json:"delta,omitempty"`
// Role, when set, becomes the role of the chunk's delta message.
Role *string `json:"role,omitempty"`
// StopReason, when set, is mapped to an OpenAI finish reason.
StopReason *string `json:"stopReason,omitempty"`
// Usage, when set, is reported as the chunk's token usage.
Usage *tokenUsage `json:"usage,omitempty"`
// Start is decoded from "start"; the visible conversion code does not read it.
Start *contentBlockStart `json:"start,omitempty"`
}
// converseStreamEventContentBlockDelta is the "delta" object of a stream
// event. Only Text is consumed by the visible conversion code.
type converseStreamEventContentBlockDelta struct {
Text *string `json:"text,omitempty"`
ToolUse *toolUseBlockDelta `json:"toolUse,omitempty"`
}
// toolUseBlockStart is the "toolUse" object inside a content block start.
type toolUseBlockStart struct {
Name string `json:"name"`
ToolUseID string `json:"toolUseId"`
}
// contentBlockStart is the "start" object of a stream event.
type contentBlockStart struct {
ToolUse *toolUseBlockStart `json:"toolUse,omitempty"`
}
// toolUseBlockDelta is the "toolUse" object inside a content block delta.
type toolUseBlockDelta struct {
Input string `json:"input"`
}
// bedrockImageGenerationResponse is the JSON body decoded from an image
// generation response.
type bedrockImageGenerationResponse struct {
// Images holds the generated images; each entry is passed through unchanged
// as the B64 field of the OpenAI-style response.
Images []string `json:"images"`
// Error is decoded from "error"; the visible response builder does not read it.
Error string `json:"error"`
}
// bedrockImageGenerationTextToImageParams is the "textToImageParams" object
// of an image generation request. Only Text is populated by the visible
// request builder; the remaining fields are omitted from the JSON when empty.
type bedrockImageGenerationTextToImageParams struct {
	Text           string `json:"text"`
	NegativeText   string `json:"negativeText,omitempty"`
	ConditionImage string `json:"conditionImage,omitempty"`
	ControlMode    string `json:"controlMode,omitempty"`
	// ControlStrength is serialized as "controlStrength"; the previous tag
	// "controlLength" did not match the field it describes.
	ControlStrength float32 `json:"controlStrength,omitempty"`
}
// bedrockImageGenerationConfig is the "imageGenerationConfig" object of an
// image generation request. The visible request builder fills Width, Height,
// Quality and NumberOfImages; CfgScale and Seed are left at their zero values
// and therefore omitted from the JSON.
type bedrockImageGenerationConfig struct {
Width int `json:"width"`
Height int `json:"height"`
Quality string `json:"quality,omitempty"`
CfgScale float32 `json:"cfgScale,omitempty"`
Seed int `json:"seed,omitempty"`
NumberOfImages int `json:"numberOfImages"`
}
// The following types describe the per-task parameter objects of
// bedrockImageGenerationRequest. None of them is populated by the visible
// request builder, which only issues text-to-image requests.

// bedrockImageGenerationColorGuidedGenerationParams is the
// "colorGuidedGenerationParams" object.
type bedrockImageGenerationColorGuidedGenerationParams struct {
Colors []string `json:"colors"`
ReferenceImage string `json:"referenceImage"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
// bedrockImageGenerationImageVariationParams is the "imageVariationParams" object.
type bedrockImageGenerationImageVariationParams struct {
Images []string `json:"images"`
SimilarityStrength float32 `json:"similarityStrength"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
// bedrockImageGenerationInPaintingParams is the "inPaintingParams" object.
type bedrockImageGenerationInPaintingParams struct {
Image string `json:"image"`
MaskPrompt string `json:"maskPrompt"`
MaskImage string `json:"maskImage"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
// bedrockImageGenerationOutPaintingParams is the "outPaintingParams" object.
type bedrockImageGenerationOutPaintingParams struct {
Image string `json:"image"`
MaskPrompt string `json:"maskPrompt"`
MaskImage string `json:"maskImage"`
OutPaintingMode string `json:"outPaintingMode"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
// bedrockImageGenerationBackgroundRemovalParams is the "backgroundRemovalParams" object.
type bedrockImageGenerationBackgroundRemovalParams struct {
Image string `json:"image"`
}
// bedrockImageGenerationRequest is the JSON body sent to the invoke path for
// image generation. TaskType selects the task; each optional *Params pointer
// is omitted from the JSON when nil. The visible request builder sets
// TaskType, ImageGenerationConfig and TextToImageParams only.
type bedrockImageGenerationRequest struct {
TaskType string `json:"taskType"`
ImageGenerationConfig *bedrockImageGenerationConfig `json:"imageGenerationConfig"`
TextToImageParams *bedrockImageGenerationTextToImageParams `json:"textToImageParams,omitempty"`
ColorGuidedGenerationParams *bedrockImageGenerationColorGuidedGenerationParams `json:"colorGuidedGenerationParams,omitempty"`
ImageVariationParams *bedrockImageGenerationImageVariationParams `json:"imageVariationParams,omitempty"`
InPaintingParams *bedrockImageGenerationInPaintingParams `json:"inPaintingParams,omitempty"`
OutPaintingParams *bedrockImageGenerationOutPaintingParams `json:"outPaintingParams,omitempty"`
BackgroundRemovalParams *bedrockImageGenerationBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"`
}
// extractAmazonEventStreamEvents decodes every complete event-stream message
// found in chunk (prefixed by any bytes buffered from earlier chunks) and
// returns the JSON events they carry. Messages whose payload is not valid
// JSON for ConverseStreamEvent are skipped.
//
// If decoding stops because the data ends in the middle of a message, the
// unread tail is stored under ctxKeyStreamingBody so that it can be completed
// by the next chunk. In every other case (clean end, checksum mismatch or
// malformed message) the buffer is cleared.
func extractAmazonEventStreamEvents(ctx wrapper.HttpContext, chunk []byte) []ConverseStreamEvent {
	body := chunk
	if buffered, ok := ctx.GetContext(ctxKeyStreamingBody).([]byte); ok {
		body = append(buffered, chunk...)
	}

	r := bytes.NewReader(body)
	var events []ConverseStreamEvent
	// consumed is the offset just past the last fully decoded message.
	var consumed int64
	truncated := false
	messageBuffer := make([]byte, 1024)

	for {
		msg, err := decodeMessage(r, messageBuffer)
		if err != nil {
			if err == io.EOF || err == io.ErrUnexpectedEOF {
				// The reader ran dry; anything after the last complete
				// message is the beginning of a message still in flight.
				truncated = consumed < r.Size()
			} else {
				log.Errorf("failed to decode message: %v", err)
			}
			break
		}
		consumed = r.Size() - int64(r.Len())

		// msg.Payload may alias messageBuffer, so decode it before the next read.
		var event ConverseStreamEvent
		if err = json.Unmarshal(msg.Payload, &event); err == nil {
			events = append(events, event)
		}
	}

	log.Infof("extractAmazonEventStreamEvents: lastRead=%d, r.Size=%d", consumed, r.Size())
	if truncated {
		// Store a private copy so the buffer does not alias the caller's chunk.
		remainder := make([]byte, len(body)-int(consumed))
		copy(remainder, body[consumed:])
		ctx.SetContext(ctxKeyStreamingBody, remainder)
	} else {
		ctx.SetContext(ctxKeyStreamingBody, nil)
	}
	return events
}
// bedrockStreamMessage is one decoded event-stream message as produced by
// decodeMessage.
type bedrockStreamMessage struct {
// Headers holds the decoded message headers; empty when the message has none.
Headers headers
// Payload holds the raw message payload; it may share storage with the
// scratch buffer passed to decodeMessage.
Payload []byte
}
// EventFrame describes the parts of an event-stream frame: the two prelude
// lengths and their CRC, the headers, the payload and the trailing CRC.
// NOTE(review): this type is not referenced anywhere in the visible source —
// confirm whether it is still needed.
type EventFrame struct {
TotalLength uint32
HeadersLength uint32
PreludeCRC uint32
Headers map[string]interface{}
Payload []byte
PayloadCRC uint32
}
// headers is an ordered list of message headers with unique names
// (uniqueness is maintained by Set).
type headers []header
// header is a single name/value pair of an event-stream message.
type header struct {
Name string
Value Value
}
// Set stores value under name. An existing header with the same name is
// overwritten in place; otherwise a new header is appended to the list.
func (hs *headers) Set(name string, value Value) {
	list := *hs
	for idx := range list {
		if list[idx].Name == name {
			list[idx].Value = value
			return
		}
	}
	*hs = append(list, header{Name: name, Value: value})
}
// decodeMessage reads one event-stream message from reader and returns its
// headers and payload.
//
// Everything read through hashReader is also written into crc, so the running
// checksum covers the prelude, the headers and the payload. The trailing
// 4-byte message CRC is read directly from reader (bypassing the tee) and
// compared against that running checksum.
//
// payloadBuf is scratch space offered to decodePayload; the returned Payload
// may share its backing array, so it should be consumed before payloadBuf is
// reused. On any error an empty message is returned together with the error.
func decodeMessage(reader io.Reader, payloadBuf []byte) (m bedrockStreamMessage, err error) {
crc := crc32.New(crc32.MakeTable(crc32.IEEE))
// Tee all reads into the checksum.
hashReader := io.TeeReader(reader, crc)
prelude, err := decodePrelude(hashReader, crc)
if err != nil {
return bedrockStreamMessage{}, err
}
if prelude.HeadersLen > 0 {
// Limit header parsing to exactly the announced number of header bytes.
lr := io.LimitReader(hashReader, int64(prelude.HeadersLen))
m.Headers, err = decodeHeaders(lr)
if err != nil {
return bedrockStreamMessage{}, err
}
}
if payloadLen := prelude.PayloadLen(); payloadLen > 0 {
buf, err := decodePayload(payloadBuf, io.LimitReader(hashReader, int64(payloadLen)))
if err != nil {
return bedrockStreamMessage{}, err
}
m.Payload = buf
}
// Snapshot the checksum before reading the CRC itself from the raw reader.
msgCRC := crc.Sum32()
if err := validateCRC(reader, msgCRC); err != nil {
return bedrockStreamMessage{}, err
}
return m, nil
}
// decodeHeaders reads name/value pairs from r until r is exhausted.
//
// Reaching io.EOF while reading a header name marks the normal end of the
// header section. Any other error, including io.EOF while reading a value,
// aborts decoding and is returned with a nil header list.
func decodeHeaders(r io.Reader) (headers, error) {
	decoded := headers{}
	for {
		name, nameErr := decodeHeaderName(r)
		if nameErr == io.EOF {
			return decoded, nil
		}
		if nameErr != nil {
			return nil, nameErr
		}

		value, valueErr := decodeHeaderValue(r)
		if valueErr != nil {
			return nil, valueErr
		}
		decoded.Set(name, value)
	}
}
// decodeHeaderValue reads one typed header value from r: a one-byte type tag
// followed by the value bytes.
//
// String values are decoded and returned as StringValue. Values of the other
// known types are not materialized (a nil Value is returned), but their bytes
// are consumed so that the next header starts at the right offset. An
// unrecognized type tag is reported as an error because its length cannot be
// determined.
//
// Read errors, including io.EOF, are returned to the caller unchanged.
func decodeHeaderValue(r io.Reader) (Value, error) {
	typ, err := decodeUint8(r)
	if err != nil {
		return nil, err
	}

	// skip is the number of value bytes to discard for non-string types.
	var skip int64
	switch valueType(typ) {
	case stringValueType:
		var sv StringValue
		if err := sv.decode(r); err != nil {
			return nil, err
		}
		return sv, nil
	case trueValueType, falseValueType:
		// Booleans are fully encoded by the type tag.
		return nil, nil
	case int8ValueType:
		skip = 1
	case int16ValueType:
		skip = 2
	case int32ValueType:
		skip = 4
	case int64ValueType, timestampValueType:
		skip = 8
	case uuidValueType:
		skip = 16
	case bytesValueType:
		// Byte arrays carry a two-byte length prefix, like strings.
		n, err := decodeUint16(r)
		if err != nil {
			return nil, err
		}
		skip = int64(n)
	default:
		return nil, fmt.Errorf("unknown value type %d", typ)
	}

	if _, err := io.CopyN(io.Discard, r, skip); err != nil {
		return nil, err
	}
	return nil, nil
}
// Value is a decoded header value.
type Value interface {
// Get returns the underlying Go value.
Get() interface{}
}
// StringValue is a header value holding a string.
type StringValue string

// Get returns the value as a plain Go string.
func (v StringValue) Get() interface{} {
	text := string(v)
	return text
}
// decode reads a length-prefixed string from r into v. On error v is left
// untouched and the error is returned.
func (v *StringValue) decode(r io.Reader) error {
	decoded, err := decodeStringValue(r)
	if err == nil {
		*v = StringValue(decoded)
	}
	return err
}
func decodeBytesValue(r io.Reader) ([]byte, error) {
var raw rawValue
var err error
raw.Len, err = decodeUint16(r)
if err != nil {
return nil, err
}
buf := make([]byte, raw.Len)
_, err = io.ReadFull(r, buf)
if err != nil {
return nil, err
}
return buf, nil
}
// decodeUint16 reads exactly two bytes from r and interprets them as a
// big-endian unsigned integer. It returns 0 and the read error when fewer
// than two bytes are available.
func decodeUint16(r io.Reader) (uint16, error) {
	var raw [2]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint16(raw[:]), nil
}
// decodeStringValue reads a length-prefixed byte sequence from r and returns
// it as a string. On error the string is empty.
func decodeStringValue(r io.Reader) (string, error) {
	raw, err := decodeBytesValue(r)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
// rawValue is scratch storage used while decoding a header value: the type
// tag read from the wire and, for variable-length values, the length prefix.
type rawValue struct {
Type valueType
Len uint16 // Only set for variable length slices
Value []byte // byte representation of value, BigEndian encoding.
}
// valueType is the one-byte tag that precedes every header value on the wire.
type valueType uint8
// Header value type tags, numbered from 0 in declaration order. Of these,
// only stringValueType is decoded into a Value by decodeHeaderValue.
const (
trueValueType valueType = iota
falseValueType
int8ValueType // Byte
int16ValueType // Short
int32ValueType // Integer
int64ValueType // Long
bytesValueType
stringValueType
timestampValueType
uuidValueType
)
func decodeHeaderName(r io.Reader) (string, error) {
var n headerName
var err error
n.Len, err = decodeUint8(r)
if err != nil {
return "", err
}
name := n.Name[:n.Len]
if _, err := io.ReadFull(r, name); err != nil {
return "", err
}
return string(name), nil
}
// decodeUint8 reads a single byte from r. Readers that offer ReadByte are
// asked directly; all others go through io.ReadFull with a one-byte buffer.
func decodeUint8(r io.Reader) (uint8, error) {
	if br, ok := r.(io.ByteReader); ok {
		return br.ReadByte()
	}

	one := make([]byte, 1)
	_, err := io.ReadFull(r, one)
	return one[0], err
}
// maxHeaderNameLen is the size of the scratch array for header names; it
// equals the largest length a one-byte length prefix can express.
const maxHeaderNameLen = 255
// headerName is scratch storage used by decodeHeaderName: the length read
// from the wire and a fixed-size array receiving the name bytes.
type headerName struct {
Len uint8
Name [maxHeaderNameLen]byte
}
// decodePayload copies everything readable from r into a buffer that starts
// out backed by buf (truncated to length zero) and returns the collected
// bytes. The result shares storage with buf as long as the data fits; the
// buffer grows on its own otherwise.
func decodePayload(buf []byte, r io.Reader) ([]byte, error) {
	sink := bytes.NewBuffer(buf[:0])
	if _, err := io.Copy(sink, r); err != nil {
		return sink.Bytes(), err
	}
	return sink.Bytes(), nil
}
// messagePrelude is the fixed-size start of an event-stream message: the
// total message length, the length of the header section, and the checksum
// of those two fields.
type messagePrelude struct {
	Length     uint32
	HeadersLen uint32
	PreludeCRC uint32
}

// ValidateLens checks that the announced lengths are consistent.
//
// A message is at least 16 bytes long (12 bytes of prelude plus the 4-byte
// message checksum), and the header section must fit in what remains.
// Rejecting anything else keeps PayloadLen from wrapping around.
func (p messagePrelude) ValidateLens() error {
	const overhead = 16
	if p.Length < overhead {
		return fmt.Errorf("message prelude want: 16, have: %v", int(p.Length))
	}
	if p.HeadersLen > p.Length-overhead {
		return fmt.Errorf("message headers length %d exceeds message length %d", p.HeadersLen, p.Length)
	}
	return nil
}

// PayloadLen returns the number of payload bytes: the total length minus the
// header section and the 16 bytes of framing. It returns 0 for inconsistent
// lengths instead of wrapping around.
func (p messagePrelude) PayloadLen() uint32 {
	const overhead = 16
	if p.Length < overhead || p.HeadersLen > p.Length-overhead {
		return 0
	}
	return p.Length - p.HeadersLen - overhead
}
// decodePrelude reads the message prelude from r: the total length and the
// headers length (big-endian uint32 each), followed by the prelude checksum.
//
// The checksum is taken from crc immediately after the two length fields have
// been read, then compared with the next four bytes of r. decodeMessage calls
// this with r being a tee reader that feeds crc, which is what makes the sum
// cover the two length fields.
//
// On any error an empty prelude is returned together with the error.
func decodePrelude(r io.Reader, crc hash.Hash32) (messagePrelude, error) {
var p messagePrelude
var err error
p.Length, err = decodeUint32(r)
if err != nil {
return messagePrelude{}, err
}
p.HeadersLen, err = decodeUint32(r)
if err != nil {
return messagePrelude{}, err
}
// Lengths are sanity-checked before the checksum is verified.
if err := p.ValidateLens(); err != nil {
return messagePrelude{}, err
}
// Snapshot the checksum of the bytes read so far, before reading the CRC field.
preludeCRC := crc.Sum32()
if err := validateCRC(r, preludeCRC); err != nil {
return messagePrelude{}, err
}
p.PreludeCRC = preludeCRC
return p, nil
}
// decodeUint32 reads exactly four bytes from r and interprets them as a
// big-endian unsigned integer. It returns 0 and the read error when fewer
// than four bytes are available.
func decodeUint32(r io.Reader) (uint32, error) {
	var raw [4]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint32(raw[:]), nil
}
// validateCRC reads a big-endian uint32 checksum from r and compares it with
// expect. It returns the read error if the checksum cannot be read, and a
// mismatch error if the values differ.
func validateCRC(r io.Reader, expect uint32) error {
	actual, err := decodeUint32(r)
	switch {
	case err != nil:
		return err
	case actual != expect:
		return fmt.Errorf("message checksum mismatch")
	}
	return nil
}
// TransformResponseHeaders remembers the upstream request id in the context,
// relabels an event-stream response as text/event-stream, and drops
// Content-Length.
func (b *bedrockProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	requestID := headers.Get(requestIdHeader)
	ctx.SetContext(requestIdHeader, requestID)

	const contentTypeKey = "Content-Type"
	if contentType := headers.Get(contentTypeKey); contentType == "application/vnd.amazon.eventstream" {
		headers.Set(contentTypeKey, "text/event-stream; charset=utf-8")
	}
	headers.Del("Content-Length")
}
// GetProviderType returns the provider type identifier for Bedrock.
func (b *bedrockProvider) GetProviderType() string {
	providerType := providerTypeBedrock
	return providerType
}
// OnRequestHeaders delegates request-header processing to the shared handler
// on the provider configuration. It always returns nil.
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
	cfg := &b.config
	cfg.handleRequestHeaders(b, ctx, apiName)
	return nil
}
// TransformRequestHeaders points the request at the Bedrock runtime host of
// the configured region.
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	host := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
	util.OverwriteRequestHostHeader(headers, host)
}
// OnRequestBody hands the request body to the shared body handler when the
// API is supported by the configuration; otherwise it continues the request
// and reports errUnsupportedApiName.
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	if b.config.isSupportedAPI(apiName) {
		return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body)
	}
	return types.ActionContinue, errUnsupportedApiName
}
// insertHttpContextMessage appends content as an additional system block to
// an already converted Bedrock request body, re-serializes the request and
// refreshes the signature headers on the live request (no header map is
// passed to setAuthHeaders).
//
// It returns an error when body cannot be parsed; the marshal error, if any,
// is returned alongside the serialized bytes.
func (b *bedrockProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
	var request bedrockTextGenRequest
	if err := json.Unmarshal(body, &request); err != nil {
		return nil, fmt.Errorf("unable to unmarshal request: %v", err)
	}

	// Appending to an empty list yields the same single-element list as creating one.
	request.System = append(request.System, systemContentBlock{Text: content})

	requestBytes, err := json.Marshal(&request)
	b.setAuthHeaders(requestBytes, nil)
	return requestBytes, err
}
// TransformRequestBodyHeaders converts the request body (and adjusts headers)
// for the given API: chat completion and image generation have dedicated
// converters, everything else uses the configuration's default transform.
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
	if apiName == ApiNameChatCompletion {
		return b.onChatCompletionRequestBody(ctx, body, headers)
	}
	if apiName == ApiNameImageGeneration {
		return b.onImageGenerationRequestBody(ctx, body, headers)
	}
	return b.config.defaultTransformRequestBody(ctx, apiName, body)
}
// TransformResponseBody converts a complete Bedrock response body for the
// given API. APIs other than chat completion and image generation yield
// errUnsupportedApiName.
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
	if apiName == ApiNameChatCompletion {
		return b.onChatCompletionResponseBody(ctx, body)
	}
	if apiName == ApiNameImageGeneration {
		return b.onImageGenerationResponseBody(ctx, body)
	}
	return nil, errUnsupportedApiName
}
// onImageGenerationResponseBody parses a Bedrock image generation response
// and returns it re-encoded in the OpenAI-style image response format. A
// body that cannot be parsed is logged and reported as an error.
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	var bedrockResponse bedrockImageGenerationResponse
	err := json.Unmarshal(body, &bedrockResponse)
	if err != nil {
		log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
		return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
	}
	return json.Marshal(b.buildBedrockImageGenerationResponse(ctx, &bedrockResponse))
}
// onImageGenerationRequestBody parses an OpenAI-style image generation
// request (applying model mapping), sets the Accept header, rewrites the path
// to the invoke path of the mapped model and returns the Bedrock request body.
func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := new(imageGenerationRequest)
	if err := b.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}

	headers.Set("Accept", "*/*")
	invokePath := fmt.Sprintf(bedrockInvokeModelPath, request.Model)
	util.OverwriteRequestPathHeader(headers, invokePath)
	return b.buildBedrockImageGenerationRequest(request, headers)
}
// buildBedrockImageGenerationRequest converts an OpenAI-style image request
// into a Bedrock text-to-image request, rewrites the request path to the
// invoke path of the requested model, and signs the request.
//
// The size is expected as "<width>x<height>". When it is missing or either
// dimension is not a positive integer, the default of 1024x1024 is used
// instead of sending a zero dimension upstream.
//
// It returns the serialized request body, or an error if serialization
// fails; in that case the headers are not signed.
func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageGenerationRequest, headers http.Header) ([]byte, error) {
	width, height := 1024, 1024
	if dims := strings.Split(origRequest.Size, "x"); len(dims) == 2 {
		w, wErr := strconv.Atoi(dims[0])
		h, hErr := strconv.Atoi(dims[1])
		if wErr == nil && hErr == nil && w > 0 && h > 0 {
			width, height = w, h
		}
	}

	request := &bedrockImageGenerationRequest{
		TaskType: "TEXT_IMAGE",
		TextToImageParams: &bedrockImageGenerationTextToImageParams{
			Text: origRequest.Prompt,
		},
		ImageGenerationConfig: &bedrockImageGenerationConfig{
			NumberOfImages: origRequest.N,
			Width:          width,
			Height:         height,
			Quality:        origRequest.Quality,
		},
	}
	util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model))

	requestBytes, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}
	// The signature covers the body hash, so sign only after the body is final.
	b.setAuthHeaders(requestBytes, headers)
	return requestBytes, nil
}
// buildBedrockImageGenerationResponse wraps each image of the Bedrock
// response as the B64 field of one OpenAI-style data entry and stamps the
// response with the current time in seconds.
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
	images := bedrockResponse.Images
	data := make([]imageGenerationData, 0, len(images))
	for idx := range images {
		data = append(data, imageGenerationData{B64: images[idx]})
	}

	response := &imageGenerationResponse{
		Created: time.Now().UnixMilli() / 1000,
		Data:    data,
	}
	return response
}
// onChatCompletionResponseBody parses a Bedrock converse response and returns
// it re-encoded as an OpenAI-style chat completion. A body that cannot be
// parsed is logged and reported as an error.
func (b *bedrockProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	var bedrockResponse bedrockConverseResponse
	err := json.Unmarshal(body, &bedrockResponse)
	if err != nil {
		log.Errorf("unable to unmarshal bedrock response: %v", err)
		return nil, fmt.Errorf("unable to unmarshal bedrock response: %v", err)
	}
	return json.Marshal(b.buildChatCompletionResponse(ctx, &bedrockResponse))
}
// onChatCompletionRequestBody parses an OpenAI-style chat request (applying
// model mapping), sets the Accept header, rewrites the path to the streaming
// or non-streaming converse path of the mapped model, and returns the Bedrock
// request body.
func (b *bedrockProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := new(chatCompletionRequest)
	if err := b.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}

	pathTemplate := bedrockChatCompletionPath
	if request.Stream {
		pathTemplate = bedrockStreamChatCompletionPath
	}
	headers.Set("Accept", "*/*")
	util.OverwriteRequestPathHeader(headers, fmt.Sprintf(pathTemplate, request.Model))
	return b.buildBedrockTextGenerationRequest(request, headers)
}
// buildBedrockTextGenerationRequest converts an OpenAI-style chat request
// into a Bedrock converse request: every message is converted individually,
// max tokens, temperature and top-p are copied into the inference
// configuration, and the performance latency is set to "standard". The
// serialized body is signed via setAuthHeaders and returned together with
// any marshal error.
func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCompletionRequest, headers http.Header) ([]byte, error) {
	converted := make([]bedrockMessage, len(origRequest.Messages))
	for idx, msg := range origRequest.Messages {
		converted[idx] = chatMessage2BedrockMessage(msg)
	}

	var request bedrockTextGenRequest
	request.Messages = converted
	request.InferenceConfig = bedrockInferenceConfig{
		MaxTokens:   origRequest.MaxTokens,
		Temperature: origRequest.Temperature,
		TopP:        origRequest.TopP,
	}
	request.AdditionalModelRequestFields = map[string]interface{}{}
	request.PerformanceConfig = PerformanceConfiguration{Latency: "standard"}

	requestBytes, err := json.Marshal(&request)
	b.setAuthHeaders(requestBytes, headers)
	return requestBytes, err
}
// buildChatCompletionResponse converts a Bedrock converse response into an
// OpenAI-style chat completion with a single choice. The choice's content is
// the text of the first content block (empty when there is none). The
// response id is the request id stored in the context and the model is the
// final request model stored in the context.
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
	reply := bedrockResponse.Output.Message
	text := ""
	if blocks := reply.Content; len(blocks) > 0 {
		text = blocks[0].Text
	}

	choice := chatCompletionChoice{
		Index: 0,
		Message: &chatMessage{
			Role:    reply.Role,
			Content: text,
		},
		FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason),
	}

	response := &chatCompletionResponse{
		Id:                ctx.GetStringContext(requestIdHeader, ""),
		Created:           time.Now().UnixMilli() / 1000,
		Model:             ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		SystemFingerprint: "",
		Object:            objectChatCompletion,
		Choices:           []chatCompletionChoice{choice},
		Usage: usage{
			PromptTokens:     bedrockResponse.Usage.InputTokens,
			CompletionTokens: bedrockResponse.Usage.OutputTokens,
			TotalTokens:      bedrockResponse.Usage.TotalTokens,
		},
	}
	return response
}
// stopReasonBedrock2OpenAI maps a Bedrock stop reason to the OpenAI finish
// reason: "end_turn" and "stop_sequence" become the stop reason,
// "max_tokens" becomes the length reason, and any other value is passed
// through unchanged.
func stopReasonBedrock2OpenAI(reason string) string {
	switch reason {
	case "end_turn", "stop_sequence":
		return finishReasonStop
	case "max_tokens":
		return finishReasonLength
	}
	return reason
}
// bedrockTextGenRequest is the JSON body sent to the converse and
// converse-stream paths.
type bedrockTextGenRequest struct {
// Messages is the converted conversation.
Messages []bedrockMessage `json:"messages"`
// System holds system blocks; insertHttpContextMessage appends to it.
System []systemContentBlock `json:"system,omitempty"`
// InferenceConfig carries the sampling parameters copied from the original request.
InferenceConfig bedrockInferenceConfig `json:"inferenceConfig,omitempty"`
// AdditionalModelRequestFields is set to an empty map by the visible
// request builder and is then omitted from the JSON.
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
// PerformanceConfig is set to latency "standard" by the visible request builder.
PerformanceConfig PerformanceConfiguration `json:"performanceConfig,omitempty"`
}
// PerformanceConfiguration is the "performanceConfig" object of a converse request.
type PerformanceConfiguration struct {
Latency string `json:"latency,omitempty"`
}
// bedrockMessage is one conversation message of a converse request.
type bedrockMessage struct {
Role string `json:"role"`
Content []bedrockMessageContent `json:"content"`
}
// bedrockMessageContent is one content block of a message. The visible
// conversion code only ever fills Text.
type bedrockMessageContent struct {
Text string `json:"text,omitempty"`
Image *imageBlock `json:"image,omitempty"`
}
// systemContentBlock is one entry of the request's "system" list.
type systemContentBlock struct {
Text string `json:"text,omitempty"`
}
// imageBlock is the "image" object of a content block.
type imageBlock struct {
Format string `json:"format,omitempty"`
Source imageSource `json:"source,omitempty"`
}
// imageSource is the "source" object of an image block.
type imageSource struct {
Bytes string `json:"bytes,omitempty"`
}
// bedrockInferenceConfig is the "inferenceConfig" object of a converse
// request. Zero-valued fields are omitted from the JSON. The visible request
// builder fills MaxTokens, Temperature and TopP; StopSequences is left unset.
type bedrockInferenceConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// bedrockConverseResponse is the JSON body decoded from a non-streaming
// converse response.
type bedrockConverseResponse struct {
// Metrics is decoded from "metrics"; the visible response builder does not read it.
Metrics converseMetrics `json:"metrics"`
// Output holds the generated message.
Output converseOutputMemberMessage `json:"output"`
// StopReason is mapped to an OpenAI finish reason.
StopReason string `json:"stopReason"`
// Usage holds the token counts reported in the response.
Usage tokenUsage `json:"usage"`
}
// converseMetrics is the "metrics" object of a converse response.
type converseMetrics struct {
LatencyMs int `json:"latencyMs"`
}
// converseOutputMemberMessage is the "output" object of a converse response.
type converseOutputMemberMessage struct {
Message message `json:"message"`
}
// message is the generated message: its role and its content blocks.
type message struct {
Content []contentBlockMemberText `json:"content"`
Role string `json:"role"`
}
// contentBlockMemberText is one content block of a generated message; only
// its text is decoded.
type contentBlockMemberText struct {
Text string `json:"text"`
}
// tokenUsage holds the token counts reported by Bedrock, both in complete
// responses and in stream events. They are copied to the OpenAI-style usage
// as prompt, completion and total tokens respectively.
type tokenUsage struct {
InputTokens int `json:"inputTokens,omitempty"`
OutputTokens int `json:"outputTokens,omitempty"`
TotalTokens int `json:"totalTokens"`
}
// chatMessage2BedrockMessage converts one OpenAI-style chat message into a
// Bedrock message with the same role.
//
// A message with plain string content becomes a single text block. For
// structured content, every text part becomes a text block; parts of any
// other type are skipped with a warning.
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
	if !chatMessage.IsStringContent() {
		var blocks []bedrockMessageContent
		for _, part := range chatMessage.ParseContent() {
			if part.Type != contentTypeText {
				log.Warnf("imageUrl is not supported: %s", part.Type)
				continue
			}
			blocks = append(blocks, bedrockMessageContent{Text: part.Text})
		}
		return bedrockMessage{Role: chatMessage.Role, Content: blocks}
	}

	textBlock := bedrockMessageContent{Text: chatMessage.StringContent()}
	return bedrockMessage{
		Role:    chatMessage.Role,
		Content: []bedrockMessageContent{textBlock},
	}
}
// setAuthHeaders signs the request for body with the current UTC time and
// publishes the resulting X-Amz-Date and Authorization headers.
//
// When headers is non-nil, the request path is read from it and the signed
// headers are written into it. When headers is nil, the path is read from,
// and the headers are replaced on, the live request through the proxy-wasm
// host API; failures of those host calls are logged.
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
	now := time.Now().UTC()
	amzDate := now.Format("20060102T150405Z")
	dateStamp := now.Format("20060102")

	var path string
	if headers != nil {
		path = headers.Get(":path")
	} else {
		var err error
		// Only ask the host for the path when no header map was supplied.
		path, err = proxywasm.GetHttpRequestHeader(":path")
		if err != nil {
			log.Errorf("failed to read request path for signing: %v", err)
		}
	}

	signature := b.generateSignature(path, amzDate, dateStamp, body)
	authorization := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature)

	if headers != nil {
		headers.Set("X-Amz-Date", amzDate)
		headers.Set("Authorization", authorization)
		return
	}
	if err := proxywasm.ReplaceHttpRequestHeader("X-Amz-Date", amzDate); err != nil {
		log.Errorf("failed to set X-Amz-Date header: %v", err)
	}
	if err := proxywasm.ReplaceHttpRequestHeader("Authorization", authorization); err != nil {
		log.Errorf("failed to set Authorization header: %v", err)
	}
}
// generateSignature computes the hex-encoded AWS4-HMAC-SHA256 signature for a
// POST request to path with the given body.
//
// The canonical request consists of the method, the encoded path, an empty
// query string, the host and x-amz-date headers, the signed-header list and
// the SHA-256 of the body. Its hash is embedded in the string to sign, which
// is then signed with the key derived from the secret key, date, region and
// service.
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
	host := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
	canonicalRequest := strings.Join([]string{
		httpPostMethod,
		urlEncoding(path),
		"", // empty canonical query string
		"host:" + host,
		"x-amz-date:" + amzDate,
		"", // blank line terminating the canonical headers
		bedrockSignedHeaders,
		sha256Hex(body),
	}, "\n")

	credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService)
	stringToSign := strings.Join([]string{
		"AWS4-HMAC-SHA256",
		amzDate,
		credentialScope,
		sha256Hex([]byte(canonicalRequest)),
	}, "\n")

	signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService)
	return hmacHex(signingKey, stringToSign)
}
// urlEncoding percent-encodes the characters ':', '+', '=', '&', '$' and '@'
// in rawStr and leaves every other byte, including '%', untouched.
func urlEncoding(rawStr string) string {
	var encoded strings.Builder
	encoded.Grow(len(rawStr))
	// All targets are single ASCII bytes, so a byte-wise scan is safe for UTF-8 input.
	for i := 0; i < len(rawStr); i++ {
		switch c := rawStr[i]; c {
		case ':':
			encoded.WriteString("%3A")
		case '+':
			encoded.WriteString("%2B")
		case '=':
			encoded.WriteString("%3D")
		case '&':
			encoded.WriteString("%26")
		case '$':
			encoded.WriteString("%24")
		case '@':
			encoded.WriteString("%40")
		default:
			encoded.WriteByte(c)
		}
	}
	return encoded.String()
}
// getSignatureKey derives the signing key by chaining HMAC-SHA256 over the
// date stamp, the region, the service and the literal "aws4_request",
// starting from "AWS4" followed by the secret key.
func getSignatureKey(key, dateStamp, region, service string) []byte {
	signingKey := []byte("AWS4" + key)
	for _, part := range [...]string{dateStamp, region, service, "aws4_request"} {
		signingKey = hmacSha256(signingKey, part)
	}
	return signingKey
}
// hmacSha256 returns the raw HMAC-SHA256 of data under key.
func hmacSha256(key []byte, data string) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(data))
	digest := mac.Sum(nil)
	return digest
}
// sha256Hex returns the lowercase hex encoding of the SHA-256 digest of data.
func sha256Hex(data []byte) string {
	digest := sha256.Sum256(data)
	return hex.EncodeToString(digest[:])
}
// hmacHex returns the lowercase hex encoding of the HMAC-SHA256 of data
// under key.
func hmacHex(key []byte, data string) string {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(data))
	digest := mac.Sum(make([]byte, 0, sha256.Size))
	return hex.EncodeToString(digest)
}