proxy/handler/oauth2/oauth2_handler.go (92 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package oauth2 import ( "context" "errors" "github.com/go-chassis/go-chassis/v2/core/handler" "github.com/go-chassis/go-chassis/v2/core/invocation" "github.com/go-chassis/openlog" "net/http" "time" ) // errors var ( ErrInvalidState = errors.New("invalid state") ErrInvalidCode = errors.New("invalid code") ErrInvalidToken = errors.New("invalid authorization") ErrInvalidAuth = errors.New("invalid authentication") ErrExpiredToken = errors.New("expired token") ) // AuthName is the auth style const AuthName = "oauth2" // Random is a state value const Random = "random" // Handler is is a oauth2 pre process raw data in handler type Handler struct { } // Handle is provider func (oa *Handler) Handle(chain *handler.Chain, inv *invocation.Invocation, cb invocation.ResponseCallBack) { if auth != nil && auth.GrantType == "authorization_code" { if req, ok := inv.Args.(*http.Request); ok { state := req.FormValue("state") if state != Random && state != "" { WriteBackErr(ErrInvalidState, http.StatusUnauthorized, cb) return } code := req.FormValue("code") if code == "" { WriteBackErr(ErrInvalidCode, http.StatusUnauthorized, cb) return } accessToken, err := getToken(code, cb) if err != nil { openlog.Error("authorization error: " + err.Error()) WriteBackErr(ErrInvalidToken, http.StatusUnauthorized, cb) return } if auth.Authenticate != nil { err = auth.Authenticate(accessToken, req) if err != nil { openlog.Error("authentication error: " + err.Error()) WriteBackErr(ErrInvalidAuth, http.StatusUnauthorized, cb) return } } } } chain.Next(inv, func(r *invocation.Response) { cb(r) }) } // getToken deal with the authorization code and return the token func getToken(code string, cb invocation.ResponseCallBack) (accessToken string, err error) { if auth.UseConfig != nil { config := auth.UseConfig token, err := config.Exchange(context.Background(), code) if err != nil { openlog.Error("get token failed, errors: " + err.Error()) WriteBackErr(ErrInvalidCode, http.StatusUnauthorized, cb) return "", err } // set the expiry token in 30 minutes token.Expiry = time.Now().Add(30 * 60 * time.Second) if time.Now().After(token.Expiry) { return "", ErrExpiredToken } accessToken = token.AccessToken return accessToken, nil } return "", nil } // Name returns router string func (oa *Handler) Name() string { return AuthName } // NewOAuth2 returns new auth handler func NewOAuth2() handler.Handler { return &Handler{} } func init() { err := handler.RegisterHandler(AuthName, NewOAuth2) if err != nil { openlog.Error("register handler error: " + err.Error()) return } } // WriteBackErr write err and callback func WriteBackErr(err error, status int, cb invocation.ResponseCallBack) { r := &invocation.Response{ Err: err, Status: status, } cb(r) }