plugins/wasm-go/extensions/cors/main.go (119 lines of code) (raw):
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"cors/config"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// main registers the plugin with the wasm-go wrapper: it wires up the
// configuration parser and the request/response header callbacks defined in
// this file under the plugin name "cors".
func main() {
	const pluginName = "cors"

	wrapper.SetCtx(
		pluginName,
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
	)
}
// parseConfig populates corsConfig from the plugin's JSON configuration.
//
// It reads the array fields "allow_origins", "allow_origin_patterns",
// "allow_methods", "allow_headers" and "expose_headers", the boolean
// "allow_credentials" and the integer "max_age", then asks corsConfig to fill
// in defaults for anything left unset.
//
// Errors reported while adding an origin or setting the credentials flag are
// logged as warnings and do not abort parsing; the function always returns nil.
func parseConfig(json gjson.Result, corsConfig *config.CorsConfig, log wrapper.Log) error {
	log.Debug("parseConfig()")

	// Origins are registered before the credentials flag is applied; the
	// statement order of the original implementation is kept on purpose.
	for _, item := range json.Get("allow_origins").Array() {
		err := corsConfig.AddAllowOrigin(item.String())
		if err != nil {
			log.Warnf("failed to AddAllowOrigin:%s, error:%v", item, err)
		}
	}

	for _, item := range json.Get("allow_origin_patterns").Array() {
		corsConfig.AddAllowOriginPattern(item.String())
	}

	for _, item := range json.Get("allow_methods").Array() {
		corsConfig.AddAllowMethod(item.String())
	}

	for _, item := range json.Get("allow_headers").Array() {
		corsConfig.AddAllowHeader(item.String())
	}

	for _, item := range json.Get("expose_headers").Array() {
		corsConfig.AddExposeHeader(item.String())
	}

	err := corsConfig.SetAllowCredentials(json.Get("allow_credentials").Bool())
	if err != nil {
		log.Warnf("failed to set AllowCredentials error: %v", err)
	}

	corsConfig.SetMaxAge(int(json.Get("max_age").Int()))

	// Let the config apply its defaults for any option that was not provided.
	corsConfig.FillDefaultValues()

	log.Debugf("corsConfig:%+v", corsConfig)
	return nil
}
// onHttpRequestHeaders runs the request through corsConfig.Process and decides
// how to proceed based on the resulting CORS context:
//
//   - if processing fails, a warning is logged and the request continues;
//   - if the context is marked invalid, a 403 local reply is sent;
//   - if the context is marked as a preflight, a 200 local reply is sent;
//   - otherwise the request continues to the next filter.
//
// The CORS context is stored on ctx under config.HttpContextKey so that
// onHttpResponseHeaders can read it back. Both local replies carry the
// config.HeaderPluginTrace header, which the response phase strips again.
func onHttpRequestHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig, log wrapper.Log) types.Action {
	log.Debug("onHttpRequestHeaders()")

	// Lookup errors are deliberately ignored here: a missing pseudo-header
	// simply yields an empty string.
	requestUrl, _ := proxywasm.GetHttpRequestHeader(":path")
	method, _ := proxywasm.GetHttpRequestHeader(":method")
	host := ctx.Host()
	scheme := ctx.Scheme()
	log.Debugf("scheme:%s, host:%s, method:%s, request:%s", scheme, host, method, requestUrl)

	// Get headers. A failure is reported but processing still goes ahead with
	// whatever was returned, matching the previous best-effort behaviour.
	headers, headersErr := proxywasm.GetHttpRequestHeaders()
	if headersErr != nil {
		log.Warnf("failed to get request headers of %s : %v", requestUrl, headersErr)
	}

	// Process request
	httpCorsContext, err := corsConfig.Process(scheme, host, method, headers)
	if err != nil {
		log.Warnf("failed to process %s : %v", requestUrl, err)
		return types.ActionContinue
	}
	log.Debugf("Process httpCorsContext:%+v", httpCorsContext)

	// Keep the result for the response phase.
	ctx.SetContext(config.HttpContextKey, httpCorsContext)

	// Respond with 403 when the request is not a valid CORS request.
	if !httpCorsContext.IsValid {
		replyHeaders := [][2]string{{config.HeaderPluginTrace, "trace"}}
		proxywasm.SendHttpResponseWithDetail(http.StatusForbidden, "cors.forbidden", replyHeaders, []byte("Invalid CORS request"), -1)
		return types.ActionPause
	}

	// Respond directly to CORS preflight requests.
	if httpCorsContext.IsPreFlight {
		replyHeaders := [][2]string{{config.HeaderPluginTrace, "trace"}}
		// The detail string uses the same "cors." prefix as the forbidden reply.
		proxywasm.SendHttpResponseWithDetail(http.StatusOK, "cors.preflight", replyHeaders, nil, -1)
		return types.ActionPause
	}

	return types.ActionContinue
}
// onHttpResponseHeaders rewrites the CORS-related response headers.
//
// It first strips the plugin trace header and every CORS response header this
// plugin manages, then adds the plugin debug header carrying
// corsConfig.GetVersion(). If a CORS context was stored on ctx during the
// request phase and it is both valid and a CORS request, the corresponding
// Access-Control-* headers are added from that context. The function always
// returns types.ActionContinue.
func onHttpResponseHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig, log wrapper.Log) types.Action {
	log.Debug("onHttpResponseHeaders()")

	// Remove trace header if it exists.
	proxywasm.RemoveHttpResponseHeader(config.HeaderPluginTrace)

	// Remove upstream CORS response headers if they exist. Every header that
	// may be added below is removed here, including Allow-Headers, so the
	// response never carries both an upstream value and the plugin's value.
	proxywasm.RemoveHttpResponseHeader(config.HeaderAccessControlAllowOrigin)
	proxywasm.RemoveHttpResponseHeader(config.HeaderAccessControlAllowMethods)
	proxywasm.RemoveHttpResponseHeader(config.HeaderAccessControlAllowHeaders)
	proxywasm.RemoveHttpResponseHeader(config.HeaderAccessControlExposeHeaders)
	proxywasm.RemoveHttpResponseHeader(config.HeaderAccessControlAllowCredentials)
	proxywasm.RemoveHttpResponseHeader(config.HeaderAccessControlMaxAge)

	// Add debug header
	proxywasm.AddHttpResponseHeader(config.HeaderPluginDebug, corsConfig.GetVersion())

	// Restore httpCorsContext from HttpContext. The assertion fails when no
	// value of this type was stored under the key during the request phase.
	httpCorsContext, ok := ctx.GetContext(config.HttpContextKey).(config.HttpCorsContext)
	if !ok {
		log.Debug("restore httpCorsContext from HttpContext error")
		return types.ActionContinue
	}
	log.Debugf("Restore httpCorsContext:%+v", httpCorsContext)

	// Nothing to add when the request was invalid or was not a CORS request.
	if !httpCorsContext.IsValid || !httpCorsContext.IsCorsRequest {
		return types.ActionContinue
	}

	// Add CORS headers for a valid CORS request; empty values are skipped.
	if len(httpCorsContext.AllowOrigin) > 0 {
		proxywasm.AddHttpResponseHeader(config.HeaderAccessControlAllowOrigin, httpCorsContext.AllowOrigin)
	}
	if len(httpCorsContext.AllowMethods) > 0 {
		proxywasm.AddHttpResponseHeader(config.HeaderAccessControlAllowMethods, httpCorsContext.AllowMethods)
	}
	if len(httpCorsContext.AllowHeaders) > 0 {
		proxywasm.AddHttpResponseHeader(config.HeaderAccessControlAllowHeaders, httpCorsContext.AllowHeaders)
	}
	if len(httpCorsContext.ExposeHeaders) > 0 {
		proxywasm.AddHttpResponseHeader(config.HeaderAccessControlExposeHeaders, httpCorsContext.ExposeHeaders)
	}
	if httpCorsContext.AllowCredentials {
		proxywasm.AddHttpResponseHeader(config.HeaderAccessControlAllowCredentials, "true")
	}
	if httpCorsContext.MaxAge > 0 {
		proxywasm.AddHttpResponseHeader(config.HeaderAccessControlMaxAge, fmt.Sprintf("%d", httpCorsContext.MaxAge))
	}

	return types.ActionContinue
}