pkg/filter/csrf/csrf.go (96 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package csrf import ( "encoding/base64" "encoding/json" "fmt" stdHttp "net/http" ) import ( "github.com/apache/dubbo-go-pixiu/pkg/common/constant" "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter" "github.com/apache/dubbo-go-pixiu/pkg/context/http" ) const ( // Kind is the kind of Fallback. Kind = constant.HTTPCsrfFilter ) const ( csrfSecret = "csrfSecret" csrfSalt = "csrfSalt" ) func init() { filter.RegisterHttpFilter(&Plugin{}) } type ( // Plugin is http filter plugin. Plugin struct { } // FilterFactory is http filter instance FilterFactory struct { cfg *Config } Filter struct { cfg *Config } // Config describe the config of FilterFactory Config struct { Key string `yaml:"key" json:"key" mapstructure:"key"` // get request key Secret string `yaml:"secret" json:"secret" mapstructure:"secret"` // private key ErrorMsg string `yaml:"error_msg" json:"error_msg" mapstructure:"error_msg"` // hint error info IgnoreMethods []string `yaml:"ignore_methods" json:"ignore_methods" mapstructure:"ignore_methods"` // ignore request method } ) func (p *Plugin) Kind() string { return Kind } func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) { return &FilterFactory{cfg: &Config{}}, nil } func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain filter.FilterChain) error { f := &Filter{cfg: factory.cfg} chain.AppendDecodeFilters(f) return nil } func (f *Filter) Decode(ctx *http.HttpContext) filter.FilterStatus { ctx.Request.Header.Set(csrfSecret, f.cfg.Secret) if inMethod(f.cfg.IgnoreMethods, ctx.Request.Method) { return filter.Continue } salt := ctx.Request.Header.Get(csrfSalt) if salt == "" { bt, _ := json.Marshal(http.ErrResponse{Message: f.cfg.ErrorMsg}) ctx.SendLocalReply(stdHttp.StatusForbidden, bt) return filter.Stop } token := tokenize(f.cfg.Secret, salt) if token != tokenGetter(ctx, f.cfg.Key) { bt, _ := json.Marshal(http.ErrResponse{Message: f.cfg.ErrorMsg}) ctx.SendLocalReply(stdHttp.StatusForbidden, bt) return filter.Stop } return filter.Continue } func tokenGetter(ctx *http.HttpContext, key string) string { req := ctx.Request if t := req.Form.Get(key); t != "" { return t } else if t := req.URL.Query().Get(key); t != "" { return t } else if t := req.Header.Get(key); t != "" { return t } return "" } func inMethod(methods []string, method string) bool { for _, v := range methods { if v == method { return true } } return false } func tokenize(secret, salt string) string { return base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s-%s", salt, secret))) } func (factory *FilterFactory) Apply() error { return nil } func (factory *FilterFactory) Config() any { return factory.cfg }