service/middleware/middleware.go (93 lines of code) (raw):
package middleware
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
)
func init() {
// Unfortunatell AWS sends invalid JWT tokens and this flag is required to
// ensure they are parsed correctly. See e.g.
// https://github.com/golang-jwt/jwt/pull/117.
jwt.DecodePaddingAllowed = true
}
type InvalidEmail struct {
Email string
}
func (ie InvalidEmail) Error() string {
return "unsupported email: " + ie.Email
}
func WithRequestLog(h http.Handler) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
log.Printf("%s %s %s", req.Method, req.Host, req.URL.Path)
h.ServeHTTP(resp, req)
}
}
// withDomainPrefix adds the value of req.Host as a prefix to the path. Also
// rewrites '/' to '/index.html'. The idea is that files for a specific static
// site (host) like in the same S3 bucket but prefixed by host. E.g.
// s3://the-static-bucket/example.gutools.co.uk/index.html.
func WithDomainPrefix(h http.Handler) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
host, _, _ := strings.Cut(req.Host, ":")
path := req.URL.Path
req.URL.Path = host + path
if strings.HasSuffix(path, "/") {
req.URL.Path += "index.html"
}
log.Printf("updated req path is: %s", req.URL.Path)
h.ServeHTTP(resp, req)
}
}
// withAuth handles token validation and ensures the contained email is a
// @guardian.co.uk one. Note, the actual Open Auth flow is handled at the ALB,
// but this is required both to confirm the email domain and as an extra
// security step in case the EC2 instance is somehow accessed not via the ALB.
func WithAuth(h http.Handler) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
if strings.HasSuffix(req.URL.Path, "/_prout") {
// https://github.com/guardian/prout needs no auth, so we skip it for **/_prout
h.ServeHTTP(resp, req)
return
}
// See https://docs.aws.amazon.com/elasticloadbalancing/latest/application/listener-authenticate-users.html#user-claims-encoding
tokenString := req.Header.Get("x-amzn-oidc-data")
err := auth(tokenString, keyFunc, []string{"ES256"})
if err != nil {
statusForbidden(resp, err)
return
}
h.ServeHTTP(resp, req)
}
}
func statusForbidden(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusForbidden)
log.Printf("User failed authentication with: %v", err)
fmt.Fprintln(w, "Status Forbidden (403) - you are not authorised to access this site.")
}
func auth(tokenString string, keyFunc func(token *jwt.Token) (interface{}, error), validMethods []string) error {
token, err := jwt.Parse(tokenString, keyFunc, jwt.WithValidMethods(validMethods))
if err != nil {
return fmt.Errorf("unable to parse token: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return errors.New("jwt token is invalid")
}
email := fmt.Sprintf("%v", claims["email"])
if !strings.HasSuffix(email, "@guardian.co.uk") {
return InvalidEmail{email}
}
return nil
}
// Takes a token and gets the signature key.
func keyFunc(token *jwt.Token) (interface{}, error) {
region := "eu-west-1"
kid := fmt.Sprintf("%v", token.Header["kid"])
resp, err := http.Get(fmt.Sprintf("https://public-keys.auth.elb.%s.amazonaws.com/%s", region, kid))
if err != nil {
log.Printf("request for public key failed: %v", err)
return nil, err
}
defer resp.Body.Close()
pem, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("unable to read public key: %v", err)
return nil, err
}
publicKey, err := jwt.ParseECPublicKeyFromPEM(pem)
if err != nil {
log.Printf("unable to parse pem into public key: %v", err)
return nil, err
}
return publicKey, nil
}