helpers/jwt.go (121 lines of code) (raw):

package helpers

import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/MicahParks/keyfunc"
	"github.com/golang-jwt/jwt/v4"
)

// authHeaderPattern splits an Authorization header value into its scheme (the
// first run of non-whitespace) and everything after the separating whitespace.
// It is compiled once at package scope instead of on every request.
var authHeaderPattern = regexp.MustCompile(`^([^\s]+)\s+(.*)$`)

// extractAuth returns the token part of a "Bearer <token>" Authorization header.
// It returns an error when the header is missing, does not start with "Bearer",
// or carries no token after the scheme.
func extractAuth(h *http.Request) (string, error) {
	authHeader := h.Header.Get("Authorization")
	if authHeader == "" {
		return "", errors.New("no authorization header provided")
	}
	if !strings.HasPrefix(authHeader, "Bearer") {
		return "", errors.New("expecting token auth")
	}
	matches := authHeaderPattern.FindStringSubmatch(authHeader)
	// "Bearer   " matches the pattern with an empty capture; reject it here
	// rather than handing an empty string to the JWT parser.
	if matches == nil || matches[2] == "" {
		return "", errors.New("no token in auth header")
	}
	return matches[2], nil
}

// extractPublicKey parses a PEM-encoded X.509 certificate and returns its RSA
// public key. It returns an error, rather than panicking, when the input holds
// no PEM block or the certificate's key is not RSA.
func extractPublicKey(certPem string) (*rsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(certPem))
	if block == nil {
		log.Print("ERROR helpers.extractPublicKey no PEM data found in incoming certificate")
		return nil, errors.New("no PEM data found in certificate")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		log.Print("ERROR helpers.extractPublicKey could not parse incoming certificate: ", err)
		return nil, err
	}
	rsaPublicKey, ok := cert.PublicKey.(*rsa.PublicKey)
	if !ok {
		log.Printf("ERROR helpers.extractPublicKey certificate public key is %T, not RSA", cert.PublicKey)
		return nil, fmt.Errorf("certificate public key is %T, not RSA", cert.PublicKey)
	}
	return rsaPublicKey, nil
}

// LoadPublicKey reads the entire file at filePath and returns its contents as a
// string. Despite the name it returns raw file text (the PEM certificate that
// extractPublicKey later parses), not a parsed key. Failures are logged and the
// underlying error is returned.
func LoadPublicKey(filePath string) (string, error) {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		log.Printf("ERROR helpers.LoadPublicKey could not load cert from '%s': %s", filePath, openErr)
		return "", openErr
	}
	defer f.Close()
	data, readErr := ioutil.ReadAll(f)
	if readErr != nil {
		log.Printf("ERROR helpers.LoadPublicKey could not read in all data from '%s': %s", filePath, readErr)
		return "", readErr
	}
	return string(data), nil
}

// jwksCache holds one long-lived JWKS per URL, guarded by jwksMu.
// With RefreshInterval set, keyfunc.Get starts a background refresh goroutine.
// Creating a JWKS per request would leak that goroutine and refetch the key set
// on every call. Reusing one JWKS lets RefreshInterval and RefreshUnknownKID
// work as the options intend.
var (
	jwksMu    sync.Mutex
	jwksCache = make(map[string]*keyfunc.JWKS)
)

// getJWKS returns the cached JWKS for url, fetching and caching it on first use.
// A failed fetch is not cached, so a later request retries it.
func getJWKS(url string) (*keyfunc.JWKS, error) {
	jwksMu.Lock()
	defer jwksMu.Unlock()
	if jwks, ok := jwksCache[url]; ok {
		return jwks, nil
	}
	options := keyfunc.Options{
		RefreshErrorHandler: func(err error) {
			log.Printf("ERROR There was an error with the jwt.Keyfunc\nError: %s", err.Error())
		},
		RefreshInterval:   time.Hour,
		RefreshRateLimit:  time.Minute * 5,
		RefreshTimeout:    time.Second * 10,
		RefreshUnknownKID: true,
	}
	jwks, err := keyfunc.Get(url, options)
	if err != nil {
		return nil, err
	}
	jwksCache[url] = jwks
	return jwks, nil
}

// parseWithJWKS verifies rawToken against the key set served at jwksURL.
func parseWithJWKS(rawToken, jwksURL string) (*jwt.Token, error) {
	jwks, err := getJWKS(jwksURL)
	if err != nil {
		log.Printf("ERROR Failed to create JWKS from resource at the given URL.\nError: %s", err.Error())
		return nil, errors.New("Error loading JWKS from given URL")
	}
	token, err := jwt.Parse(rawToken, jwks.Keyfunc)
	if err != nil {
		log.Printf("ERROR Failed to parse the JWT.\nError: %s", err.Error())
		return nil, errors.New("Error parsing token using JWKS")
	}
	return token, nil
}

// parseWithCertFile verifies an RSA-signed rawToken against the PEM certificate
// stored at certFile. An expired token yields "Token is expired". Any other
// parse or validation failure yields "Internal validation error".
func parseWithCertFile(rawToken, certFile string) (*jwt.Token, error) {
	publicCertData, loadErr := LoadPublicKey(certFile)
	if loadErr != nil {
		return nil, errors.New("Server setup problem, see logs")
	}
	token, tokErr := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) {
		// Only RSA signing methods are accepted, so the certificate's key is
		// never used to verify a token declaring a different algorithm.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
		}
		return extractPublicKey(publicCertData)
	})
	if tokErr == nil {
		return token, nil
	}
	var validationErr *jwt.ValidationError
	if errors.As(tokErr, &validationErr) {
		log.Print("ERROR helpers.jwt.ValidateLogin could not validate: ", validationErr.Error())
		// Test the expired bit with AND. OR is always non-zero and would report
		// every validation failure as an expiry.
		if validationErr.Errors&jwt.ValidationErrorExpired != 0 {
			return nil, errors.New("Token is expired")
		}
	} else {
		// The raw token is a bearer credential, so it is kept out of the log.
		log.Printf("ERROR helpers.jwt.ValidateLogin could not validate token: %s", tokErr)
	}
	return nil, errors.New("Internal validation error")
}

// ValidateLogin authenticates the request's bearer JWT. It returns the value of
// the first claim named in config.JWT.UserNameClaims whose value is a string.
//
// If config.JWT.CertFile starts with "http", it is treated as a JWKS URL.
// Otherwise it is a path to a PEM certificate whose RSA public key verifies the
// token. Claims that are present but not strings are skipped rather than
// causing a panic.
func ValidateLogin(h *http.Request, config *Config) (string, error) {
	rawData, rawErr := extractAuth(h)
	if rawErr != nil {
		return "", rawErr
	}

	var token *jwt.Token
	var parseErr error
	if strings.HasPrefix(config.JWT.CertFile, "http") {
		token, parseErr = parseWithJWKS(rawData, config.JWT.CertFile)
	} else {
		token, parseErr = parseWithCertFile(rawData, config.JWT.CertFile)
	}
	if parseErr != nil {
		return "", parseErr
	}

	if !token.Valid {
		// Do not log the raw token: it is a bearer credential.
		log.Print("ERROR token is not valid")
		return "", errors.New("token is not valid")
	}

	claims, claimsOk := token.Claims.(jwt.MapClaims)
	if !claimsOk {
		log.Printf("ERROR helpers.jwt.ValidateLogin claims data was not present or incorrect")
		return "", errors.New("incorrect claims data")
	}
	for _, claimName := range config.JWT.UserNameClaims {
		// A missing claim yields a nil interface and a non-string claim fails the
		// checked assertion; both move on to the next candidate name.
		if username, isString := claims[claimName].(string); isString {
			return username, nil
		}
	}
	log.Printf("ERROR helpers.jwt.ValidateLogin token validated but could not get a username from any of %v", config.JWT.UserNameClaims)
	return "", errors.New("no username claim")
}