serve.go (233 lines of code) (raw):
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ezcx
import (
"context"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
)
var (
ServerDefaultSignals []os.Signal = []os.Signal{
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGHUP,
}
)
// HandlerFunc is an adapter that converts a given ezcx.HandlerFunc into an http.Handler.
type HandlerFunc func(*WebhookResponse, *WebhookRequest) error
// Implementing ServeHTTP allows the ezcx.HandlerFunc to satisfy the http.Handler interface.
//
// Error handling is an area of future improvement. For instance, if a required parameter
// is missing, it should be up to the developer to handle that i.e.: return an HTTP error (400, 500)
// or return a ResponseMessage indicating something went wrong...
func (h HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
req, err := WebhookRequestFromRequest(r)
if err != nil {
log.Println("Error during WebhookRequestFromRequest")
log.Println(err)
return
}
req.ctx = r.Context // flowing down the requests's Context added..
res := req.InitializeResponse()
err = h(res, req)
if err != nil {
log.Println("Error during HandlerFunc execution")
log.Println(err)
return
}
err = res.WriteResponse(w)
if err != nil {
log.Println("Error during WebhookResponse.WriteResponse")
return
}
}
func DefaultHealthCheck(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
type Server struct {
signals []os.Signal
signal chan os.Signal
errs chan error
server *http.Server
mux *http.ServeMux
lg *log.Logger
hc http.HandlerFunc
}
func NewServer(ctx context.Context, addr string, lg *log.Logger, signals ...os.Signal) *Server {
ctx = context.WithValue(ctx, Logger, lg)
return new(Server).Init(ctx, addr, lg, signals...)
}
func (s *Server) Init(ctx context.Context, addr string, lg *log.Logger, signals ...os.Signal) *Server {
if len(signals) == 0 {
s.signals = ServerDefaultSignals
} else {
// rethink this later on. We need to make sure there at least
// the right group of signals!
s.signals = signals
}
s.signal = make(chan os.Signal, 1)
signal.Notify(s.signal, s.signals...)
if lg == nil {
lg = log.Default()
}
s.lg = lg
s.errs = make(chan error)
s.mux = http.NewServeMux()
s.hc = DefaultHealthCheck
s.mux.HandleFunc("/health", s.hc)
s.server = &http.Server{
Addr: addr,
Handler: s.mux,
BaseContext: func(l net.Listener) context.Context { return ctx },
}
return s
}
// SetHandler allows the user to set a custom mux or handler.
func (s *Server) SetHandler(h http.Handler) {
s.server.Handler = h
if s.isMux(h) {
s.mux = h.(*http.ServeMux)
} else {
s.mux = nil
}
}
// ServeMux returns a pointer to the currently set mux.
func (s *Server) ServeMux() *http.ServeMux {
return s.mux
}
func (s *Server) isMux(h http.Handler) bool {
_, ok := h.(*http.ServeMux)
return ok
}
var (
reservedPaths = map[string]struct{}{
"admin": {},
"health": {},
}
)
// HandleCx registers the handler for the given pattern. While the HandleCx method itself
// isn't safe for concurrent usage, the underlying method it wraps (*ServeMux).Handle IS guarded
// by a mutex.
func (s *Server) HandleCx(pattern string, handler HandlerFunc) {
pathParts := strings.Split(pattern, "/")
if len(pathParts) >= 2 {
pathPrefix := pathParts[1]
_, ok := reservedPaths[pathPrefix]
if ok {
s.lg.Fatal("admin, health are reserved path prefixes")
}
}
s.mux.Handle(pattern, handler)
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
// to handle requests on incoming connections. ListenAndServe is responsible for handling signals
// and managing graceful shutdown(s) whenever the right signals are intercepted.
func (s *Server) ListenAndServe(ctx context.Context) {
defer func() {
close(s.errs)
close(s.signal)
}()
// Run ListenAndServe on a separate goroutine.
s.lg.Printf("EZCX server listening and serving on %s\n", s.server.Addr)
go func() {
err := s.server.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
s.lg.Println(err)
s.errs <- err
close(s.errs)
}
}()
for {
select {
// If the context is done, we need to return.
case <-ctx.Done():
s.lg.Println("EZCX server context is done")
err := ctx.Err()
if err != nil {
s.lg.Print("EZCX server context error...")
s.lg.Println(err)
}
return
// If there's a non-nil error, we need to return
case err := <-s.errs:
if err != nil {
s.lg.Print("EZCX server non-nil error...")
s.lg.Println(err)
return
}
case sig := <-s.signal:
s.lg.Printf("EZCX server signal %s received...", sig)
switch sig {
case syscall.SIGHUP:
s.lg.Println("EZCX reconfigure", sig)
err := s.Reconfigure()
if err != nil {
s.errs <- err
}
default:
s.lg.Printf("EZCX graceful shutdown initiated...")
err := s.Shutdown(ctx)
if err != nil {
s.lg.Println(err)
} else {
s.lg.Println("EZCX shutdown SUCCESS")
}
return
}
}
}
}
func (s *Server) ListenAndServeTLS(ctx context.Context, certFile, keyFile string) {
defer func() {
close(s.errs)
close(s.signal)
}()
// Run ListenAndServe on a separate goroutine.
s.lg.Printf("EZCX server listening and serving on %s\n", s.server.Addr)
go func() {
err := s.server.ListenAndServeTLS(certFile, keyFile)
if err != nil && err != http.ErrServerClosed {
s.lg.Println(err)
s.errs <- err
close(s.errs)
}
}()
for {
select {
// If the context is done, we need to return.
case <-ctx.Done():
s.lg.Println("EZCX server context is done")
err := ctx.Err()
if err != nil {
s.lg.Print("EZCX server context error...")
s.lg.Println(err)
}
return
// If there's a non-nil error, we need to return
case err := <-s.errs:
if err != nil {
s.lg.Print("EZCX server non-nil error...")
s.lg.Println(err)
return
}
case sig := <-s.signal:
s.lg.Printf("EZCX server signal %s received...", sig)
switch sig {
case syscall.SIGHUP:
s.lg.Println("EZCX reconfigure", sig)
err := s.Reconfigure()
if err != nil {
s.errs <- err
}
default:
s.lg.Printf("EZCX graceful shutdown initiated...")
err := s.Shutdown(ctx)
if err != nil {
s.lg.Println(err)
} else {
s.lg.Println("EZCX shutdown SUCCESS")
}
return
}
}
}
}
// Omitted for now.
func (s *Server) Reconfigure() error {
return nil
}
// Shutdown provides graceful shutdown for the entire ezcx Server
func (s *Server) Shutdown(ctx context.Context) error {
timeout, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
err := s.server.Shutdown(timeout)
if err != nil {
return err
}
return nil
}