internal/plugin/plugin.go (169 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plugin import ( "errors" "fmt" "net" "net/http" "sync" hreqc "github.com/api7/ext-plugin-proto/go/A6/HTTPReqCall" hrespc "github.com/api7/ext-plugin-proto/go/A6/HTTPRespCall" flatbuffers "github.com/google/flatbuffers/go" inHTTP "github.com/apache/apisix-go-plugin-runner/internal/http" "github.com/apache/apisix-go-plugin-runner/internal/util" pkgHTTP "github.com/apache/apisix-go-plugin-runner/pkg/http" "github.com/apache/apisix-go-plugin-runner/pkg/log" ) type ParseConfFunc func(in []byte) (conf interface{}, err error) type RequestFilterFunc func(conf interface{}, w http.ResponseWriter, r pkgHTTP.Request) type ResponseFilterFunc func(conf interface{}, w pkgHTTP.Response) type pluginOpts struct { ParseConf ParseConfFunc RequestFilter RequestFilterFunc ResponseFilter ResponseFilterFunc } type pluginRegistries struct { sync.Mutex opts map[string]*pluginOpts } type ErrPluginRegistered struct { name string } func (err ErrPluginRegistered) Error() string { return fmt.Sprintf("plugin %s registered", err.name) } var ( pluginRegistry = pluginRegistries{opts: map[string]*pluginOpts{}} ErrMissingName = errors.New("missing name") ErrMissingParseConfMethod = errors.New("missing ParseConf method") ErrMissingRequestFilterMethod = errors.New("missing RequestFilter method") ErrMissingResponseFilterMethod = errors.New("missing ResponseFilter method") RequestPhase = requestPhase{} ResponsePhase = responsePhase{} ) func RegisterPlugin(name string, pc ParseConfFunc, sv RequestFilterFunc, rsv ResponseFilterFunc) error { log.Infof("register plugin %s", name) if name == "" { return ErrMissingName } if pc == nil { return ErrMissingParseConfMethod } if sv == nil { return ErrMissingRequestFilterMethod } if rsv == nil { return ErrMissingResponseFilterMethod } opt := &pluginOpts{ ParseConf: pc, RequestFilter: sv, ResponseFilter: rsv, } pluginRegistry.Lock() defer pluginRegistry.Unlock() if _, found := pluginRegistry.opts[name]; found { return ErrPluginRegistered{name} } pluginRegistry.opts[name] = opt return nil } func findPlugin(name string) *pluginOpts { if opt, found := pluginRegistry.opts[name]; found { return opt } return nil } type requestPhase struct { } func (ph *requestPhase) filter(conf RuleConf, w *inHTTP.ReqResponse, r *inHTTP.Request) error { for _, c := range conf { plugin := findPlugin(c.Name) if plugin == nil { log.Warnf("can't find plugin %s, skip", c.Name) continue } log.Infof("run plugin %s", c.Name) plugin.RequestFilter(c.Value, w, r) if w.HasChange() { // response is generated, no need to continue break } } return nil } func (ph *requestPhase) builder(id uint32, resp *inHTTP.ReqResponse, req *inHTTP.Request) *flatbuffers.Builder { builder := util.GetBuilder() if resp != nil && resp.FetchChanges(id, builder) { return builder } if req != nil && req.FetchChanges(id, builder) { return builder } hreqc.RespStart(builder) hreqc.RespAddId(builder, id) res := hreqc.RespEnd(builder) builder.Finish(res) return builder } func HTTPReqCall(buf []byte, conn net.Conn) (*flatbuffers.Builder, error) { req := inHTTP.CreateRequest(buf) req.BindConn(conn) defer inHTTP.ReuseRequest(req) resp := inHTTP.CreateReqResponse() defer inHTTP.ReuseReqResponse(resp) token := req.ConfToken() conf, err := GetRuleConf(token) if err != nil { return nil, err } err = RequestPhase.filter(conf, resp, req) if err != nil { return nil, err } id := req.ID() builder := RequestPhase.builder(id, resp, req) return builder, nil } type responsePhase struct { } func (ph *responsePhase) filter(conf RuleConf, w *inHTTP.Response) error { for _, c := range conf { plugin := findPlugin(c.Name) if plugin == nil { log.Warnf("can't find plugin %s, skip", c.Name) continue } log.Infof("run plugin %s", c.Name) plugin.ResponseFilter(c.Value, w) if w.HasChange() { // response is generated, no need to continue break } } return nil } func (ph *responsePhase) builder(id uint32, resp *inHTTP.Response) *flatbuffers.Builder { builder := util.GetBuilder() if resp != nil && resp.FetchChanges(builder) { return builder } hrespc.RespStart(builder) hrespc.RespAddId(builder, id) res := hrespc.RespEnd(builder) builder.Finish(res) return builder } func HTTPRespCall(buf []byte, conn net.Conn) (*flatbuffers.Builder, error) { resp := inHTTP.CreateResponse(buf) resp.BindConn(conn) defer inHTTP.ReuseResponse(resp) token := resp.ConfToken() conf, err := GetRuleConf(token) if err != nil { return nil, err } err = ResponsePhase.filter(conf, resp) if err != nil { return nil, err } id := resp.ID() return ResponsePhase.builder(id, resp), nil }