internal/httpserver/httpserver.go (226 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more agreements.
// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
package httpserver
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/mux"
	"go.uber.org/zap"

	"github.com/elastic/stream/internal/output"
)
// Server is an HTTP server that answers requests according to the rules
// loaded from the configuration file named in Options.ConfigPath (see New).
type Server struct {
	logger   *zap.SugaredLogger
	opts     *Options
	listener net.Listener // Bound by Start.
	server   *http.Server
	ctx      context.Context // Context passed to Start.
}
// Options configures a Server. The listen address (Addr), which New requires,
// is promoted from the embedded output.Options.
type Options struct {
	*output.Options
	TLSCertificate      string        // TLS certificate file path; must be set together with TLSKey to serve HTTPS.
	TLSKey              string        // TLS key file path; must be set together with TLSCertificate.
	ReadTimeout         time.Duration // HTTP Server read timeout.
	WriteTimeout        time.Duration // HTTP Server write timeout.
	ConfigPath          string        // Path of the rules configuration file.
	ExitOnUnmatchedRule bool          // If true, a request that matches no rule is answered with 404 and then logged with Fatalf.
}
// New validates opts, loads the configuration file at opts.ConfigPath and
// builds an HTTP server whose handler serves the configured rules. The
// returned Server is not listening yet; call Start to bind and serve.
func New(opts *Options, logger *zap.SugaredLogger) (*Server, error) {
	if opts.Addr == "" {
		return nil, errors.New("a listen address is required")
	}

	// TLS needs both the certificate and the key, or neither of them.
	haveCert, haveKey := opts.TLSCertificate != "", opts.TLSKey != ""
	if haveCert != haveKey {
		return nil, errors.New("both TLS certificate and key files must be defined")
	}

	cfg, err := newConfigFromFile(opts.ConfigPath)
	if err != nil {
		return nil, err
	}

	// Fallback for requests no rule matched: reply 404 and, when requested,
	// terminate via Fatalf.
	unmatched := func(w http.ResponseWriter, r *http.Request) {
		logger.Debugf("request did not match with any rule: %s", strRequest(r))
		w.WriteHeader(404)
		if opts.ExitOnUnmatchedRule {
			logger.Fatalf("--exit-on-unmatched-rule is set, exiting")
		}
	}

	h, err := newHandlerFromConfig(cfg, http.HandlerFunc(unmatched), logger)
	if err != nil {
		return nil, err
	}

	return &Server{
		logger: logger,
		opts:   opts,
		server: &http.Server{
			Handler:        h,
			ReadTimeout:    opts.ReadTimeout,
			WriteTimeout:   opts.WriteTimeout,
			MaxHeaderBytes: 1 << 20,
		},
	}, nil
}
// Start binds o.opts.Addr (trying "tcp" first, then "tcp6") and serves HTTP,
// or HTTPS when both a certificate and a key are configured, on a background
// goroutine. It returns as soon as the listener is bound; the error that
// eventually ends the serve loop is logged at info level.
func (o *Server) Start(ctx context.Context) error {
	o.ctx = ctx

	ln, err := net.Listen("tcp", o.opts.Addr)
	if err != nil {
		// Retry explicitly on IPv6 before giving up.
		ln, err = net.Listen("tcp6", o.opts.Addr)
		if err != nil {
			return fmt.Errorf("failed to listen on address: %w", err)
		}
	}
	o.listener = ln

	useTLS := o.opts.TLSCertificate != "" && o.opts.TLSKey != ""
	go func() {
		var serveErr error
		if useTLS {
			serveErr = o.server.ServeTLS(ln, o.opts.TLSCertificate, o.opts.TLSKey)
		} else {
			serveErr = o.server.Serve(ln)
		}
		o.logger.Info(serveErr.Error())
	}()

	o.logger.Debugf("listening on %s", o.listener.Addr().String())
	return nil
}
// Close gracefully shuts the HTTP server down, waiting up to one second for
// in-flight requests to complete before returning Shutdown's result.
//
// The grace period is deliberately derived from context.Background() rather
// than from the context given to Start. If Close runs after that context has
// been cancelled (e.g. on signal-driven shutdown), a cancelled parent would
// make Shutdown return immediately without draining active connections.
// Not depending on that context also keeps Close from panicking when Start
// was never called.
func (o *Server) Close() error {
	o.logger.Info("shutting down http-server...")
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	return o.server.Shutdown(ctx)
}
// newHandlerFromConfig builds the HTTP handler that serves the rules in config.
//
// Each rule becomes a gorilla/mux route matched on the following:
//   - path and methods;
//   - query parameters, where a parameter configured with no values must be
//     absent;
//   - request-header regexps;
//   - basic-auth user and password, when configured;
//   - optionally the request body, as an exact string or, when written as
//     "/pattern/", a regexp.
//
// A matched rule cycles through its responses in round-robin order. Each
// response's header and body templates are rendered with the per-rule request
// number ("req_num") and the request details ("request": vars, url, headers).
//
// When config.AsSequence is set, rules must be hit in declaration order, each
// one consecutively as many times as it has responses. An out-of-order match
// is reported via logger.Fatalf.
//
// Requests that match no rule go to notFoundHandler. Before routing, form
// parameters are merged into the URL query so they can be matched like query
// parameters.
//
// An error is returned if a request-body regexp does not compile.
func newHandlerFromConfig(config *config, notFoundHandler http.HandlerFunc, logger *zap.SugaredLogger) (http.Handler, error) {
	router := mux.NewRouter()

	// http.Server serves connections on separate goroutines, so the match
	// counters are shared mutable state. Every read-modify-write of currInSeq
	// and of each rule's count happens under mu.
	var (
		mu        sync.Mutex
		currInSeq int // sequence position expected for the next matched request
		posInSeq  int // first sequence position owned by the current rule
	)

	for i, rule := range config.Rules {
		rule := rule
		i := i
		var count int // requests matched by this rule so far (guarded by mu)
		if i > 0 {
			posInSeq += len(config.Rules[i-1].Responses)
		}
		posInSeq := posInSeq // per-rule copy captured by the handler below

		logger.Debugf("Setting up rule #%d for path %q", i, rule.Path)

		route := router.HandleFunc(rule.Path, func(w http.ResponseWriter, r *http.Request) {
			mu.Lock()
			if config.AsSequence && currInSeq != posInSeq+count {
				logger.Fatalf("expecting to match request #%d in sequence, matched rule #%d instead, exiting", currInSeq, posInSeq+count)
			}
			var resp *response
			if n := len(rule.Responses); n > 0 {
				resp = &rule.Responses[count%n] // round-robin over configured responses
			}
			count++
			currInSeq++
			reqNum := count // snapshot: count may change once mu is released
			mu.Unlock()

			logger.Debugf("Rule #%d matched: request #%d => %s", i, reqNum, strRequest(r))

			if resp == nil {
				return
			}

			data := map[string]interface{}{
				"req_num": reqNum,
				"request": map[string]interface{}{
					"vars":    mux.Vars(r),
					"url":     r.URL,
					"headers": r.Header,
				},
			}

			// Allocated per request: a buffer shared between handlers would be
			// clobbered by concurrent requests.
			var buf bytes.Buffer
			for k, tpls := range resp.Headers {
				for _, tpl := range tpls {
					buf.Reset()
					if err := tpl.Execute(&buf, data); err != nil {
						logger.Errorf("executing header template %s: %s, %v", k, tpl.Root.String(), err)
						continue
					}
					w.Header().Add(k, buf.String())
				}
			}
			w.WriteHeader(resp.StatusCode)
			if err := resp.Body.Execute(w, data); err != nil {
				logger.Errorf("executing body template %s: %v", resp.Body.Root.String(), err)
			}
		})

		route.Methods(rule.Methods...)

		// Query parameters configured with an empty value list must be absent.
		exclude := make(map[string]bool)
		for key, vals := range rule.QueryParams {
			if len(vals) == 0 { // Cannot use nil since ucfg interprets null as an empty slice instead of nil.
				exclude[key] = true
				continue
			}
			for _, v := range vals {
				route.Queries(key, v)
			}
		}
		route.MatcherFunc(func(r *http.Request, rm *mux.RouteMatch) bool {
			q := r.URL.Query()
			for key := range exclude {
				if q.Has(key) {
					return false
				}
			}
			return true
		})

		for key, vals := range rule.RequestHeaders {
			for _, v := range vals {
				route.HeadersRegexp(key, v)
			}
		}

		// Basic-auth credentials are only checked when configured.
		route.MatcherFunc(func(r *http.Request, rm *mux.RouteMatch) bool {
			user, password, _ := r.BasicAuth()
			if rule.User != "" && user != rule.User {
				return false
			}
			if rule.Password != "" && password != rule.Password {
				return false
			}
			return true
		})

		// A body written as "/pattern/" is a regexp; anything else is matched
		// exactly. At least two characters are required so that a lone "/" is
		// a literal rather than an empty, match-everything regexp.
		var bodyRE *regexp.Regexp
		if len(rule.RequestBody) >= 2 && strings.HasPrefix(rule.RequestBody, "/") && strings.HasSuffix(rule.RequestBody, "/") {
			expr := rule.RequestBody[1 : len(rule.RequestBody)-1]
			re, err := regexp.Compile(expr)
			if err != nil {
				return nil, fmt.Errorf("compiling body match regexp %q for rule #%d: %w", expr, i, err)
			}
			bodyRE = re
		}
		route.MatcherFunc(func(r *http.Request, rm *mux.RouteMatch) bool {
			if rule.RequestBody == "" {
				return true
			}
			body, err := ioutil.ReadAll(r.Body)
			if err != nil {
				return false
			}
			// Restore the body so other routes and the handler can read it again.
			r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
			if bodyRE != nil {
				return bodyRE.Match(body)
			}
			return rule.RequestBody == string(body)
		})
	}

	router.NotFoundHandler = notFoundHandler
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Merge form params into the URL query so Queries matchers see both.
		// A malformed form is deliberately not treated as fatal here.
		_ = r.ParseForm()
		r.URL.RawQuery = r.Form.Encode()
		router.ServeHTTP(w, r)
	}), nil
}
// strRequest renders r as a single human-readable line for debug logging. The
// line contains the method and URL, each header as 'Name: [values]', and the
// raw body.
//
// The body is read in full and then put back on r, so logging a request does
// not consume it for later readers. A nil body (possible on client-constructed
// requests) is rendered as empty instead of panicking. If reading the body
// fails, whatever was read is logged and r.Body is left as it is.
func strRequest(r *http.Request) string {
	var b strings.Builder
	b.WriteString("Request path: ")
	b.WriteString(r.Method)
	b.WriteString(" ")
	b.WriteString(r.URL.String())

	b.WriteString(", Request Headers: ")
	for k, v := range r.Header {
		fmt.Fprintf(&b, "'%s: %s' ", k, v)
	}

	b.WriteString(", Request Body: ")
	if r.Body != nil {
		body, err := ioutil.ReadAll(r.Body)
		if err == nil {
			r.Body = ioutil.NopCloser(bytes.NewReader(body))
		}
		b.Write(body)
	}
	return b.String()
}