plugins/wasm-go/extensions/ai-proxy/provider/failover.go (533 lines of code) (raw):
package provider
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/http"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// failover holds the apiToken failover settings read from the plugin
// configuration (see FromJson) together with the key names, filled in by
// ProviderConfig.initVariable, under which per-request context values and
// shared data are stored.
type failover struct {
// Whether the apiToken failover mechanism is enabled.
// @Title zh-CN 是否启用 apiToken 的 failover 机制
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
// Number of consecutive failed requests that triggers failover.
// FromJson substitutes 3 when the configured value is 0.
// @Title zh-CN 触发 failover 连续请求失败的阈值
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
// Number of successful health checks required to restore an apiToken.
// FromJson substitutes 1 when the configured value is 0.
// @Title zh-CN 健康检测的成功阈值
successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
// Interval between health checks, in milliseconds.
// FromJson substitutes 5000 when the configured value is 0.
// @Title zh-CN 健康检测的间隔时间,单位毫秒
healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
// Timeout of a single health check, in milliseconds.
// FromJson substitutes 5000 when the configured value is 0.
// @Title zh-CN 健康检测的超时时间,单位毫秒
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// Model name placed in the health check request body. Validate rejects an
// empty value.
// @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
// Status codes of the original request that require failover; regular
// expression matching is supported. FromJson defaults to "4.*" and "5.*".
// @Title zh-CN 需要进行 failover 的原始请求的状态码,支持正则表达式匹配
failoverOnStatus []string `required:"false" yaml:"failoverOnStatus" json:"failoverOnStatus"`
// Context key for the apiToken used by the current request.
// @Title zh-CN 本次请求使用的 apiToken
ctxApiTokenInUse string
// Context key for the list of apiTokens that were available when the
// current request started.
// @Title zh-CN 记录本次请求时所有可用的 apiToken
ctxAvailableApiTokensInRequest string
// Shared-data key for the per-apiToken request failure counters
// (apiToken -> number of failures).
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
ctxApiTokenRequestFailureCount string
// Shared-data key for the per-apiToken health check success counters
// (apiToken -> number of successes).
// @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数
ctxApiTokenRequestSuccessCount string
// Shared-data key for the list of all available apiTokens.
// @Title zh-CN 记录所有可用的 apiToken 列表
ctxApiTokens string
// Shared-data key for the list of all unavailable apiTokens.
// @Title zh-CN 记录所有不可用的 apiToken 列表
ctxUnavailableApiTokens string
// Shared-data key for the request cluster, host and path, which are used
// to build the health check request.
// @Title zh-CN 记录请求的 cluster, host 和 path,用于在健康检测时构建请求
ctxHealthCheckEndpoint string
// Shared-data key for the health check leader lease; only the Wasm VM
// that holds the lease performs health checks.
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
ctxVmLease string
}
// Lease is the JSON record written under the ctxVmLease shared-data key by
// setLease. It names the Wasm VM that currently performs health checks.
type Lease struct {
// VMID identifies the VM holding the lease (see generateVMID).
VMID string `json:"vmID"`
// Timestamp is the Unix time in seconds at which the lease was last
// acquired or renewed; tryAcquireOrRenewLease compares it against
// time.Now().Unix().
Timestamp int64 `json:"timestamp"`
}
// HealthCheckEndpoint is the JSON record written under the
// ctxHealthCheckEndpoint shared-data key by setHealthCheckEndpoint and read
// back by generateRequestHeadersAndBody to build health check requests.
type HealthCheckEndpoint struct {
// Host is used as the ":authority" header of the health check request.
Host string `json:"host"`
// Path is used as the ":path" header of the health check request.
Path string `json:"path"`
// Cluster is the value of the "cluster_name" property captured from the
// failed request; the health check client targets this cluster.
Cluster string `json:"cluster"`
}
const (
// casMaxRetries is the maximum number of read-modify-write attempts made by
// modifyApiToken and modifyApiTokenRequestCount before giving up.
casMaxRetries = 10
// Operation selectors accepted by modifyApiToken.
addApiTokenOperation = "addApiToken"
removeApiTokenOperation = "removeApiToken"
// Operation selectors accepted by modifyApiTokenRequestCount.
addApiTokenRequestCountOperation = "addApiTokenRequestCount"
resetApiTokenRequestCountOperation = "resetApiTokenRequestCount"
// Request-context keys. setHealthCheckEndpoint reads CtxRequestHost and
// CtxRequestPath; CtxRequestBody is not referenced in this file and is
// presumably consumed elsewhere in the package.
CtxRequestHost = "requestHost"
CtxRequestPath = "requestPath"
CtxRequestBody = "requestBody"
)
var (
// healthCheckClient is the HTTP client used to send health check requests.
// It is (re)assigned inside the tick function registered by
// SetApiTokensFailover.
healthCheckClient wrapper.HttpClient
)
// FromJson fills the failover settings from the plugin's JSON configuration.
//
// Numeric settings that are absent or zero fall back to their defaults:
// failureThreshold 3, successThreshold 1, healthCheckInterval 5000 ms and
// healthCheckTimeout 5000 ms. Entries of failoverOnStatus are appended to the
// existing slice; when the slice is still empty afterwards it defaults to
// matching every 4xx and 5xx status.
func (f *failover) FromJson(json gjson.Result) {
	// intOr reads an integer field, substituting fallback for a zero value.
	intOr := func(field string, fallback int64) int64 {
		if v := json.Get(field).Int(); v != 0 {
			return v
		}
		return fallback
	}

	f.enabled = json.Get("enabled").Bool()
	f.failureThreshold = intOr("failureThreshold", 3)
	f.successThreshold = intOr("successThreshold", 1)
	f.healthCheckInterval = intOr("healthCheckInterval", 5000)
	f.healthCheckTimeout = intOr("healthCheckTimeout", 5000)
	f.healthCheckModel = json.Get("healthCheckModel").String()

	statuses := json.Get("failoverOnStatus").Array()
	for i := range statuses {
		f.failoverOnStatus = append(f.failoverOnStatus, statuses[i].String())
	}
	// If failoverOnStatus is empty, default to retry on 4xx and 5xx
	if len(f.failoverOnStatus) == 0 {
		f.failoverOnStatus = []string{"4.*", "5.*"}
	}
}
// Validate reports an error when the failover configuration lacks the model
// name that health check requests need; it returns nil otherwise.
func (f *failover) Validate() error {
	if f.healthCheckModel != "" {
		return nil
	}
	return errors.New("missing healthCheckModel in failover config")
}
// initVariable derives every context and shared-data key used by the failover
// logic from the provider type and id, so that data belonging to different
// providers does not collide.
//
// ctxAvailableApiTokensInRequest is initialised here as well; it was
// previously left as the empty string, which made GetAvailableApiToken and
// SetAvailableApiTokens use "" as their context key.
func (c *ProviderConfig) initVariable() {
	// Set provider name as prefix to differentiate shared data
	prefix := c.GetType() + "-" + c.GetId() + "-"
	c.failover.ctxApiTokenInUse = prefix + "apiTokenInUse"
	c.failover.ctxAvailableApiTokensInRequest = prefix + "availableApiTokensInRequest"
	c.failover.ctxApiTokenRequestFailureCount = prefix + "apiTokenRequestFailureCount"
	c.failover.ctxApiTokenRequestSuccessCount = prefix + "apiTokenRequestSuccessCount"
	c.failover.ctxApiTokens = prefix + "apiTokens"
	c.failover.ctxUnavailableApiTokens = prefix + "unavailableApiTokens"
	c.failover.ctxHealthCheckEndpoint = prefix + "requestHostAndPath"
	c.failover.ctxVmLease = prefix + "vmLease"
}
// parseConfig is a no-op configuration parser: it ignores both arguments and
// always succeeds. createHttpContext passes it to wrapper.ParseConfig when
// building the context used for health checks.
func parseConfig(_ gjson.Result, _ *any) error {
	return nil
}
// SetApiTokensFailover initialises the failover state for this provider.
//
// It always (re)derives the key names and clears the shared data, since it is
// expected to run whenever the plugin configuration changes. When failover is
// enabled it also seeds the shared list of available apiTokens and registers a
// tick function, fired every healthCheckInterval, that health-checks the
// apiTokens currently marked unavailable. Only the Wasm VM that holds the
// lease performs the checks.
//
// It returns an error when the available apiToken list cannot be stored.
func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
	c.initVariable()
	// Reset shared data in case plugin configuration is updated
	log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
	c.resetSharedData()

	if !c.isFailoverEnabled() {
		return nil
	}

	log.Debugf("ai-proxy plugin failover is enabled")
	vmID := generateVMID()
	if err := c.initApiTokens(); err != nil {
		return fmt.Errorf("failed to init apiTokens: %v", err)
	}

	wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
		// Only the Wasm VM that successfully acquires the lease will perform health check
		if !c.isFailoverEnabled() || !c.tryAcquireOrRenewLease(vmID) {
			return
		}
		log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
		c.healthCheckUnavailableApiTokens(activeProvider)
	})
	return nil
}

// healthCheckUnavailableApiTokens sends one chat-completion health check
// request per unavailable apiToken and reports every token whose check
// returns HTTP 200 to handleAvailableApiToken.
func (c *ProviderConfig) healthCheckUnavailableApiTokens(activeProvider Provider) {
	unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
	if err != nil {
		log.Errorf("Failed to get unavailable tokens: %v", err)
		return
	}
	if len(unavailableTokens) == 0 {
		return
	}
	// Logged once per tick rather than once per token.
	log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))

	for _, apiToken := range unavailableTokens {
		// The response callback runs after the loop has moved on. Before
		// Go 1.22 the loop variable is shared by all iterations, so take a
		// per-iteration copy to credit the right token.
		apiToken := apiToken

		healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
		healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
			Cluster: healthCheckEndpoint.Cluster,
		})

		ctx := createHttpContext()
		ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)

		modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
		if err != nil {
			log.Errorf("Failed to transform request headers and body: %v", err)
			// Headers and body are nil here; sending them would only
			// produce a malformed request, so skip this token.
			continue
		}

		// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
		err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if statusCode == 200 {
				c.handleAvailableApiToken(apiToken)
			}
		}, uint32(c.failover.healthCheckTimeout))
		if err != nil {
			log.Errorf("Failed to perform health check request: %v", err)
		}
	}
}
// generateUrl builds an https URL from the ":authority" and ":path"
// pseudo-headers of header. Missing headers contribute an empty string.
func generateUrl(header http.Header) string {
	authority := header.Get(":authority")
	path := header.Get(":path")
	return fmt.Sprintf("https://%s%s", authority, path)
}
// transformRequestHeadersAndBody applies the active provider's request
// transformations, for the chat-completion API, to the given headers and body.
//
// Headers are converted to an http.Header and passed to the provider's
// TransformRequestHeaders when it implements TransformRequestHeadersHandler.
// The body is transformed by the first matching option: TransformRequestBody,
// TransformRequestBodyHeaders, or the config's defaultTransformRequestBody.
// A body transformation failure is returned as an error with nil results.
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte) (http.Header, []byte, error) {
	headerMap := util.SliceToHeader(headers)
	if headersHandler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
		headersHandler.TransformRequestHeaders(ctx, ApiNameChatCompletion, headerMap)
	}

	var (
		transformed []byte
		err         error
	)
	switch bodyHandler := activeProvider.(type) {
	case TransformRequestBodyHandler:
		transformed, err = bodyHandler.TransformRequestBody(ctx, ApiNameChatCompletion, body)
	case TransformRequestBodyHeadersHandler:
		transformed, err = bodyHandler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, headerMap)
	default:
		transformed, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body)
	}
	if err != nil {
		return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
	}
	return headerMap, transformed, nil
}
// createHttpContext builds a standalone HTTP context for a health check
// request: a fresh "health-check" VM context with the no-op parseConfig, a
// plugin context and an HTTP context, both created with random ids.
func createHttpContext() *wrapper.CommonHttpCtx[any] {
	vmCtx := wrapper.NewCommonVmCtx[any]("health-check", wrapper.ParseConfig[any](parseConfig))
	httpCtx := vmCtx.NewPluginContext(rand.Uint32()).NewHttpContext(rand.Uint32())
	return httpCtx.(*wrapper.CommonHttpCtx[any])
}
// generateRequestHeadersAndBody builds the health check request from the
// endpoint previously stored in shared data by setHealthCheckEndpoint.
//
// It returns the stored endpoint, the request headers (content type plus the
// ":authority" and ":path" pseudo-headers) and a minimal chat-completion body
// that uses the configured healthCheckModel. When the endpoint cannot be read
// or decoded the failure is logged and a zero-valued endpoint is used.
func (c *ProviderConfig) generateRequestHeadersAndBody() (HealthCheckEndpoint, [][2]string, []byte) {
	var healthCheckEndpoint HealthCheckEndpoint
	data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
	if err != nil {
		log.Errorf("Failed to get request host and path: %v", err)
	} else if err = json.Unmarshal(data, &healthCheckEndpoint); err != nil {
		// Only attempted when the read succeeded, so a missing entry is
		// reported once rather than twice.
		log.Errorf("Failed to unmarshal request host and path: %v", err)
	}

	headers := [][2]string{
		{"content-type", "application/json"},
		{":authority", healthCheckEndpoint.Host},
		{":path", healthCheckEndpoint.Path},
	}

	// Encode the model name as a JSON string so that quotes or backslashes
	// in the configured value cannot produce an invalid request body.
	modelJSON, err := json.Marshal(c.failover.healthCheckModel)
	if err != nil {
		log.Errorf("Failed to marshal healthCheckModel: %v", err)
		modelJSON = []byte(`""`)
	}
	body := []byte(fmt.Sprintf(`{
"model": %s,
"messages": [
{
"role": "user",
"content": "who are you?"
}
]
}`, modelJSON))
	return healthCheckEndpoint, headers, body
}
// tryAcquireOrRenewLease tries to make vmID the holder of the health check
// lease and reports whether it succeeded.
//
// The lease is written when no lease exists yet, when vmID already holds it
// (renewal), or when the stored lease is more than 60 seconds old. The write
// uses the CAS value returned by the read, so a concurrent update by another
// VM makes it fail. Any read or decode error other than "not found" yields
// false.
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string) bool {
	now := time.Now().Unix()
	data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
	switch {
	case errors.Is(err, types.ErrorStatusNotFound):
		return c.setLease(vmID, now, cas)
	case err != nil:
		log.Errorf("Failed to get lease: %v", err)
		return false
	case data == nil:
		return c.setLease(vmID, now, cas)
	}

	var current Lease
	if decodeErr := json.Unmarshal(data, &current); decodeErr != nil {
		log.Errorf("Failed to unmarshal lease data: %v", decodeErr)
		return false
	}

	// If vmID is itself, try to renew the lease directly
	// If the lease is expired (60s), try to acquire the lease
	heldBySelf := current.VMID == vmID
	expired := now-current.Timestamp > 60
	if !heldBySelf && !expired {
		return false
	}
	return c.setLease(vmID, now, cas)
}
// setLease stores a lease for vmID with the given timestamp under the
// ctxVmLease shared-data key, using cas for the compare-and-swap write.
// It returns true on success; marshal and write failures are logged and
// reported as false.
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32) bool {
	payload, err := json.Marshal(Lease{VMID: vmID, Timestamp: timestamp})
	if err != nil {
		log.Errorf("Failed to marshal lease data: %v", err)
		return false
	}
	err = proxywasm.SetSharedData(c.failover.ctxVmLease, payload, cas)
	if err != nil {
		log.Errorf("Failed to set or renew lease: %v", err)
		return false
	}
	return true
}
// generateVMID returns a freshly generated UUID string. SetApiTokensFailover
// uses it as the identity under which the caller competes for the health
// check lease.
func generateVMID() string {
return uuid.New().String()
}
// handleAvailableApiToken records one successful health check for apiToken.
//
// Once the stored success count plus this success reaches successThreshold,
// the apiToken is removed from the unavailable list, added back to the
// available list and its success counter is reset. Otherwise only the success
// counter is incremented.
func (c *ProviderConfig) handleAvailableApiToken(apiToken string) {
	counts, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
	if err != nil {
		log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
		return
	}

	passed := counts[apiToken] + 1
	if passed < c.failover.successThreshold {
		log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, passed)
		addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken)
		return
	}

	log.Infof("healthcheck after failover: apiToken %s is available now, add it back to the apiTokens list", apiToken)
	removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
	addApiToken(c.failover.ctxApiTokens, apiToken)
	resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken)
}
// handleUnavailableApiToken records one failed request for apiToken.
//
// Nothing happens when the apiToken is no longer in the available list. Once
// the stored failure count plus this failure reaches failureThreshold, the
// apiToken is moved from the available list to the unavailable list, its
// failure counter is reset and the request's cluster, host and path are saved
// for later health checks. Otherwise only the failure counter is incremented.
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string) {
	counts, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
	if err != nil {
		log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
		return
	}
	available, _, err := getApiTokens(c.failover.ctxApiTokens)
	if err != nil {
		log.Errorf("Failed to get available apiToken: %v", err)
		return
	}
	// unavailable apiToken has been removed from the available list
	if stillListed := containsElement(available, apiToken); !stillListed {
		return
	}

	failed := counts[apiToken] + 1
	if failed < c.failover.failureThreshold {
		log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failed)
		addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
		return
	}

	log.Infof("failover: apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
	removeApiToken(c.failover.ctxApiTokens, apiToken)
	addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
	resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
	// Set the request host and path to shared data in case they are needed in apiToken health check
	c.setHealthCheckEndpoint(ctx)
}
// addApiToken adds apiToken to the token list stored under the shared-data
// key by delegating to modifyApiToken with addApiTokenOperation.
func addApiToken(key, apiToken string) {
modifyApiToken(key, apiToken, addApiTokenOperation)
}
// removeApiToken removes apiToken from the token list stored under the
// shared-data key by delegating to modifyApiToken with
// removeApiTokenOperation.
func removeApiToken(key, apiToken string) {
modifyApiToken(key, apiToken, removeApiTokenOperation)
}
// modifyApiToken adds apiToken to, or removes it from, the token list stored
// under the shared-data key, depending on op.
//
// The update is a read-modify-write guarded by the CAS value from the read.
// On a CAS mismatch, or when the list cannot be read, it retries up to
// casMaxRetries times in total. It returns without writing when the list is
// already in the requested state, and stops on any other write error.
func modifyApiToken(key, apiToken, op string) {
	for attempt := 1; attempt <= casMaxRetries; attempt++ {
		tokens, cas, err := getApiTokens(key)
		if err != nil {
			log.Errorf("Failed to get %s: %v", key, err)
			continue
		}

		present := containsElement(tokens, apiToken)
		switch {
		case op == addApiTokenOperation && present:
			log.Debugf("%s already exists in %s", apiToken, key)
			return
		case op == removeApiTokenOperation && !present:
			log.Debugf("%s does not exist in %s", apiToken, key)
			return
		case op == addApiTokenOperation:
			tokens = append(tokens, apiToken)
		default:
			tokens = removeElement(tokens, apiToken)
		}

		err = setApiTokens(key, tokens, cas)
		if err == nil {
			log.Debugf("Successfully updated %s in %s", apiToken, key)
			return
		}
		if !errors.Is(err, types.ErrorStatusCasMismatch) {
			log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
			return
		}
		log.Errorf("CAS mismatch when setting %s, retrying...", key)
	}
}
// getApiTokens reads the JSON-encoded token list stored under the shared-data
// key and returns it with the CAS value of the entry.
//
// A missing entry or nil payload yields an empty, non-nil list. Any other
// read error, or a payload that cannot be decoded, is returned as an error
// with a nil list and a zero CAS.
func getApiTokens(key string) ([]string, uint32, error) {
	data, cas, err := proxywasm.GetSharedData(key)
	missing := errors.Is(err, types.ErrorStatusNotFound)
	if err != nil && !missing {
		return nil, 0, err
	}
	if missing || data == nil {
		return []string{}, cas, nil
	}

	var tokens []string
	if decodeErr := json.Unmarshal(data, &tokens); decodeErr != nil {
		return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", decodeErr)
	}
	return tokens, cas, nil
}
// setApiTokens JSON-encodes apiTokens and stores the result under the
// shared-data key using cas for the compare-and-swap write. It returns the
// marshal error (wrapped) or whatever proxywasm.SetSharedData returns.
func setApiTokens(key string, apiTokens []string, cas uint32) error {
	encoded, err := json.Marshal(apiTokens)
	if err == nil {
		return proxywasm.SetSharedData(key, encoded, cas)
	}
	return fmt.Errorf("failed to marshal tokens: %v", err)
}
// removeElement returns slice with every occurrence of s removed, keeping the
// order of the remaining elements. The result reuses the backing array of
// slice, so the caller's original slice contents are overwritten.
func removeElement(slice []string, s string) []string {
	kept := slice[:0]
	for _, item := range slice {
		if item != s {
			kept = append(kept, item)
		}
	}
	return kept
}
// containsElement reports whether s occurs in slice.
func containsElement(slice []string, s string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == s {
			return true
		}
	}
	return false
}
// getApiTokenRequestCount reads the JSON-encoded per-apiToken counters stored
// under the shared-data key and returns them with the CAS value of the entry.
//
// A missing entry or nil payload yields an empty, non-nil map. Any other read
// error, or a payload that cannot be decoded, is returned as an error with a
// nil map and a zero CAS.
func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
	data, cas, err := proxywasm.GetSharedData(key)
	missing := errors.Is(err, types.ErrorStatusNotFound)
	if err != nil && !missing {
		return nil, 0, err
	}
	if missing || data == nil {
		return make(map[string]int64), cas, nil
	}

	var counts map[string]int64
	if decodeErr := json.Unmarshal(data, &counts); decodeErr != nil {
		return nil, 0, decodeErr
	}
	return counts, cas, nil
}
// addApiTokenRequestCount increments the counter of apiToken in the counter
// map stored under the shared-data key by delegating to
// modifyApiTokenRequestCount with addApiTokenRequestCountOperation.
func addApiTokenRequestCount(key, apiToken string) {
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation)
}
// resetApiTokenRequestCount deletes the counter of apiToken from the counter
// map stored under the shared-data key by delegating to
// modifyApiTokenRequestCount with resetApiTokenRequestCountOperation.
func resetApiTokenRequestCount(key, apiToken string) {
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation)
}
// ResetApiTokenRequestFailureCount clears the failure counter of
// apiTokenInUse when failover is enabled and a counter for that apiToken
// exists. It does nothing when failover is disabled.
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string) {
	if !c.isFailoverEnabled() {
		return
	}

	counts, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
	if err != nil {
		// Execution continues; the lookup below is safe on a nil map.
		log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
	}
	_, tracked := counts[apiTokenInUse]
	if !tracked {
		return
	}
	log.Infof("Reset apiToken %s request failure count", apiTokenInUse)
	resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse)
}
// modifyApiTokenRequestCount increments or deletes the counter of apiToken in
// the counter map stored under the shared-data key, depending on op
// (resetApiTokenRequestCountOperation deletes, anything else increments).
//
// The update is a read-modify-write guarded by the CAS value from the read.
// On a CAS mismatch, or when the map cannot be read, it retries up to
// casMaxRetries times in total. It stops on any other write error and when
// the updated map cannot be encoded.
func modifyApiTokenRequestCount(key, apiToken string, op string) {
	for attempt := 1; attempt <= casMaxRetries; attempt++ {
		apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
		if err != nil {
			log.Errorf("Failed to get %s: %v", key, err)
			continue
		}
		// A stored JSON "null" decodes to a nil map, and writing to a nil
		// map panics.
		if apiTokenRequestCount == nil {
			apiTokenRequestCount = make(map[string]int64)
		}

		if op == resetApiTokenRequestCountOperation {
			delete(apiTokenRequestCount, apiToken)
		} else {
			apiTokenRequestCount[apiToken]++
		}

		apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
		if err != nil {
			// Do not fall through: writing the nil payload would erase
			// every stored counter.
			log.Errorf("Failed to marshal apiTokenRequestCount: %v", err)
			return
		}

		err = proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas)
		if err == nil {
			log.Debugf("Successfully updated the count of %s in %s", apiToken, key)
			return
		}
		if !errors.Is(err, types.ErrorStatusCasMismatch) {
			log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
			return
		}
		log.Errorf("CAS mismatch when setting %s, retrying...", key)
	}
}
// initApiTokens stores the configured apiTokens as the shared list of
// available apiTokens, passing a CAS value of 0 to the write.
func (c *ProviderConfig) initApiTokens() error {
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
}
// GetGlobalRandomToken picks a random apiToken from the shared list of
// available apiTokens.
//
// When no apiToken is available it falls back to a random one from the
// unavailable list. It returns the empty string when either list cannot be
// read or when both lists are empty.
func (c *ProviderConfig) GetGlobalRandomToken() string {
	// Each lookup's error is checked on its own; the first one used to be
	// overwritten by the second before being inspected.
	apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
	if err != nil {
		log.Errorf("Failed to get available apiTokens: %v", err)
		return ""
	}
	unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
	if err != nil {
		log.Errorf("Failed to get unavailable apiTokens: %v", err)
		return ""
	}
	log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)

	count := len(apiTokens)
	switch count {
	case 0:
		// rand.Intn panics for a non-positive argument, so an empty
		// fallback list must be handled before drawing from it.
		if len(unavailableApiTokens) == 0 {
			log.Warn("no apiToken is available or unavailable, no token can be selected")
			return ""
		}
		log.Warn("all tokens are unavailable, will use random one of the unavailable tokens")
		return unavailableApiTokens[rand.Intn(len(unavailableApiTokens))]
	case 1:
		return apiTokens[0]
	default:
		return apiTokens[rand.Intn(count)]
	}
}
// GetAvailableApiToken returns the apiToken list that SetAvailableApiTokens
// stored in the request context, or nil when the context holds no []string
// under that key.
func (c *ProviderConfig) GetAvailableApiToken(ctx wrapper.HttpContext) []string {
	if tokens, ok := ctx.GetContext(c.failover.ctxAvailableApiTokensInRequest).([]string); ok {
		return tokens
	}
	return nil
}
// SetAvailableApiTokens stores the apiTokens available to the current request
// in the request context, for later use by retryOnFailure. With failover
// enabled the list comes from shared data; otherwise it is the configured
// apiTokens.
func (c *ProviderConfig) SetAvailableApiTokens(ctx wrapper.HttpContext) {
	var tokens []string
	if !c.isFailoverEnabled() {
		tokens = c.apiTokens
	} else {
		// The CAS value and error are discarded; on failure the stored
		// list is nil.
		tokens, _, _ = getApiTokens(c.failover.ctxApiTokens)
	}
	ctx.SetContext(c.failover.ctxAvailableApiTokensInRequest, tokens)
}
// isFailoverEnabled reports whether apiToken failover is enabled in the
// provider configuration.
func (c *ProviderConfig) isFailoverEnabled() bool {
return c.failover.enabled
}
// resetSharedData overwrites the lease, both apiToken lists and both counter
// maps in shared data with a nil value. The stored health check endpoint is
// not touched. Write errors are deliberately ignored: this is a best-effort
// cleanup.
func (c *ProviderConfig) resetSharedData() {
	keys := []string{
		c.failover.ctxVmLease,
		c.failover.ctxApiTokens,
		c.failover.ctxUnavailableApiTokens,
		c.failover.ctxApiTokenRequestSuccessCount,
		c.failover.ctxApiTokenRequestFailureCount,
	}
	for _, key := range keys {
		_ = proxywasm.SetSharedData(key, nil, 0)
	}
}
// OnRequestFailed reacts to a failed upstream response with the given status.
//
// When failover is enabled and status matches failoverOnStatus, the failure
// is counted against apiTokenInUse. When retry-on-failure is enabled and
// status matches retryOnStatus, the request is retried and header iteration
// is paused (HeaderStopAllIterationAndWatermark). In every other case,
// including a retry that could not be started, it returns ActionContinue.
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {
	if c.isFailoverEnabled() && util.MatchStatus(status, c.failover.failoverOnStatus) {
		log.Warnf("apiToken:%s need failover, error status:%s", apiTokenInUse, status)
		c.handleUnavailableApiToken(ctx, apiTokenInUse)
	}

	shouldRetry := c.IsRetryOnFailureEnabled() && util.MatchStatus(status, c.retryOnFailure.retryOnStatus)
	if !shouldRetry {
		return types.ActionContinue
	}

	log.Warnf("need retry, notice that retry response will be bufferd, error status:%s", status)
	if err := c.retryFailedRequest(activeProvider, ctx, apiTokenInUse, apiTokens); err != nil {
		log.Errorf("retryFailedRequest failed, err:%v", err)
		return types.ActionContinue
	}
	return types.HeaderStopAllIterationAndWatermark
}
// isNotStreamingResponse reports whether the request context explicitly marks
// the response as non-streaming: the ctxKeyIsStreaming value must be set and
// false. An unset value yields false.
func isNotStreamingResponse(ctx wrapper.HttpContext) bool {
	streaming := ctx.GetContext(ctxKeyIsStreaming)
	if streaming == nil {
		return false
	}
	return !streaming.(bool)
}
// GetApiTokenInUse returns the apiToken that SetApiTokenInUse stored in the
// request context, or the empty string when the context holds no string under
// that key.
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
	if token, ok := ctx.GetContext(c.failover.ctxApiTokenInUse).(string); ok {
		return token
	}
	return ""
}
// SetApiTokenInUse selects the apiToken for the current request and stores it
// in the request context.
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) {
	var apiToken string
	switch {
	case c.isFailoverEnabled():
		// if enable apiToken failover, only use available apiToken from global apiTokens list
		apiToken = c.GetGlobalRandomToken()
	default:
		apiToken = c.GetRandomToken()
	}
	log.Debugf("Use apiToken %s to send request", apiToken)
	ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
}
// setHealthCheckEndpoint saves the cluster, host and path of the current
// request in shared data so that later health checks can target the same
// upstream.
//
// Nothing is written when the cluster name cannot be read, when the host or
// path is missing from the request context, or when the endpoint cannot be
// encoded. An endpoint stored earlier is therefore never replaced by an
// incomplete one.
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext) {
	cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
	if err != nil {
		log.Errorf("Failed to get cluster_name: %v", err)
		return
	}

	host := ctx.GetStringContext(CtxRequestHost, "")
	path := ctx.GetStringContext(CtxRequestPath, "")
	if host == "" || path == "" {
		log.Errorf("get host or path failed, host:%s, path:%s", host, path)
		return
	}

	healthCheckEndpointByte, err := json.Marshal(HealthCheckEndpoint{
		Host:    host,
		Path:    path,
		Cluster: string(cluster),
	})
	if err != nil {
		log.Errorf("Failed to marshal request host and path: %v", err)
		return
	}

	if err := proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0); err != nil {
		log.Errorf("Failed to set request host and path: %v", err)
	}
}