pkg/portal/middleware/aad.go (302 lines of code) (raw):

package middleware

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"errors"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/adal"
	"github.com/gorilla/mux"
	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
	"github.com/sirupsen/logrus"
	"golang.org/x/oauth2"

	"github.com/Azure/ARO-RP/pkg/env"
	"github.com/Azure/ARO-RP/pkg/util/log/audit"
	"github.com/Azure/ARO-RP/pkg/util/oidc"
	"github.com/Azure/ARO-RP/pkg/util/roundtripper"
	"github.com/Azure/ARO-RP/pkg/util/stringutils"
	"github.com/Azure/ARO-RP/pkg/util/uuid"
)

const (
	// SessionName is the name of the session cookie.
	SessionName = "session"

	// SessionKeyExpires holds the session expiry time as Unix seconds (int64).
	SessionKeyExpires = "expires"
	// SessionKeyUsername holds the preferred_username claim of the logged-in user.
	SessionKeyUsername = "user_name"
	// SessionKeyGroups holds the user's groups that are also in the allowed set.
	SessionKeyGroups = "groups"

	// Internal session keys used only during the login round trip.
	sessionKeyState       = "state"
	sessionKeyRedirectUri = "redirect_uri"
)

// AAD is responsible for ensuring that we have a valid login session with AAD.
type AAD interface {
	// AAD populates the request context from a valid session, if any.
	AAD(http.Handler) http.Handler
	// CheckAuthentication rejects or redirects requests without a session user.
	CheckAuthentication(http.Handler) http.Handler
	// Login starts the AAD authorization-code flow.
	Login(http.ResponseWriter, *http.Request)
	// Logout returns a handler that clears the session and redirects to the given URL.
	Logout(string) http.Handler
}

// oauther is the subset of *oauth2.Config used by aad; it exists as a seam so
// the OAuth2 client can be substituted.
type oauther interface {
	AuthCodeURL(string, ...oauth2.AuthCodeOption) string
	Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error)
}

// claims is the subset of ID token claims read during the login callback.
type claims struct {
	Groups            []string `json:"groups,omitempty"`
	PreferredUsername string   `json:"preferred_username,omitempty"`
}

// aad implements AAD.
type aad struct {
	log *logrus.Entry
	env env.Core
	// now is the clock, kept as a field so it can be overridden.
	now func() time.Time
	// rt is the transport clientAssertion forwards token requests to.
	rt http.RoundTripper

	tenantID    string
	clientID    string
	clientKey   *rsa.PrivateKey
	clientCerts []*x509.Certificate

	store    *sessions.CookieStore
	oauther  oauther
	verifier oidc.Verifier

	// allGroups are the groups allowed into the portal; login requires
	// membership in at least one of them.
	allGroups      []string
	sessionTimeout time.Duration
}

// NewAAD builds the AAD middleware for the portal served at hostname and
// registers its unauthenticated endpoints on unauthenticatedRouter:
// GET /callback, GET /api/login and POST /api/logout, each wrapped by the Log
// middleware.
//
// sessionKey must be exactly 32 bytes; it is the key of the cookie store that
// holds the session. Sessions are valid for one hour after login.
func NewAAD(log *logrus.Entry,
	audit *logrus.Entry,
	outelAuditClient audit.Client,
	env env.Core,
	baseAccessLog *logrus.Entry,
	hostname string,
	sessionKey []byte,
	clientID string,
	clientKey *rsa.PrivateKey,
	clientCerts []*x509.Certificate,
	allGroups []string,
	unauthenticatedRouter *mux.Router,
	verifier oidc.Verifier) (*aad, error) {
	if len(sessionKey) != 32 {
		return nil, errors.New("invalid sessionKey")
	}

	tenantID := env.TenantID()
	authority := env.Environment().ActiveDirectoryEndpoint + tenantID

	store := sessions.NewCookieStore(sessionKey)
	// MaxAge(0): no Max-Age on the cookie, so it lives for the browser
	// session; the server-side expiry is enforced through SessionKeyExpires.
	store.MaxAge(0)
	store.Options.Secure = true
	store.Options.HttpOnly = true
	store.Options.SameSite = http.SameSiteLaxMode

	a := &aad{
		log:         log,
		env:         env,
		now:         time.Now,
		rt:          http.DefaultTransport,
		tenantID:    tenantID,
		clientID:    clientID,
		clientKey:   clientKey,
		clientCerts: clientCerts,
		store:       store,
		oauther: &oauth2.Config{
			ClientID: clientID,
			Endpoint: oauth2.Endpoint{
				AuthURL:  authority + "/oauth2/v2.0/authorize",
				TokenURL: authority + "/oauth2/v2.0/token",
			},
			RedirectURL: "https://" + hostname + "/callback",
			Scopes:      []string{"openid", "profile"},
		},
		verifier:       verifier,
		allGroups:      allGroups,
		sessionTimeout: time.Hour,
	}

	routes := []struct {
		method  string
		path    string
		handler http.Handler
	}{
		{http.MethodGet, "/callback", http.HandlerFunc(a.callback)},
		{http.MethodGet, "/api/login", http.HandlerFunc(a.Login)},
		{http.MethodPost, "/api/logout", a.Logout("/")},
	}
	for _, route := range routes {
		unauthenticatedRouter.NewRoute().
			Methods(route.method).
			Path(route.path).
			Handler(Log(env, audit, baseAccessLog, outelAuditClient)(route.handler))
	}

	return a, nil
}

// AAD is the early stage handler which, when the request carries an unexpired
// session, adds the session's username and groups to the request context. It
// lets the request through regardless (this is so that failures can be
// logged); the exception is a session cookie that cannot be decoded, which is
// cleared before the user is redirected to /api/login.
func (a *aad) AAD(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, err := a.store.Get(r, SessionName) if err != nil { cookieError, ok := err.(securecookie.Error) if ok && cookieError != nil && cookieError.IsDecode() { cookie := &http.Cookie{ Name: SessionName, Path: "/", Expires: time.Unix(0, 0), } http.SetCookie(w, cookie) http.Redirect(w, r, "/api/login", http.StatusTemporaryRedirect) } else { a.internalServerError(w, err) } return } expires, ok := session.Values[SessionKeyExpires].(int64) if !ok || time.Unix(expires, 0).Before(a.now()) { h.ServeHTTP(w, r) return } ctx := r.Context() ctx = context.WithValue(ctx, ContextKeyUsername, session.Values[SessionKeyUsername]) ctx = context.WithValue(ctx, ContextKeyGroups, session.Values[SessionKeyGroups]) r = r.WithContext(ctx) h.ServeHTTP(w, r) }) } // CheckAuthentication is the handler which prevents access to requests without // valid authentication. func (a *aad) CheckAuthentication(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if ctx.Value(ContextKeyUsername) == nil { if r.URL != nil { redirect := "/api/login" if r.URL.Path != "" { redirect += "?" + sessionKeyRedirectUri + "=" + r.URL.Path } http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) return } http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } h.ServeHTTP(w, r) }) } // Login will redirect the user to a login page. 
func (a *aad) Login(w http.ResponseWriter, r *http.Request) { a.redirect(w, r) } func (a *aad) Logout(url string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, err := a.store.Get(r, SessionName) if err != nil { a.internalServerError(w, err) return } session.Values = nil err = session.Save(r, w) if err != nil { a.internalServerError(w, err) return } http.Redirect(w, r, url, http.StatusSeeOther) }) } func (a *aad) redirect(w http.ResponseWriter, r *http.Request) { session, err := a.store.Get(r, SessionName) if err != nil { a.internalServerError(w, err) return } state := uuid.DefaultGenerator.Generate() session.Values = map[interface{}]interface{}{ sessionKeyState: state, } if r.URL.Query().Has(sessionKeyRedirectUri) { session.Values[sessionKeyRedirectUri] = r.URL.Query().Get(sessionKeyRedirectUri) } err = session.Save(r, w) if err != nil { a.internalServerError(w, err) return } http.Redirect(w, r, a.oauther.AuthCodeURL(state), http.StatusTemporaryRedirect) } func (a *aad) callback(w http.ResponseWriter, r *http.Request) { ctx := r.Context() session, err := a.store.Get(r, SessionName) if err != nil { a.internalServerError(w, err) return } state, ok := session.Values[sessionKeyState].(string) if !ok { a.redirect(w, r) return } delete(session.Values, sessionKeyState) err = session.Save(r, w) if err != nil { a.internalServerError(w, err) return } if r.FormValue("state") != state { a.internalServerError(w, errors.New("state mismatch")) return } if r.FormValue("error") != "" { err := r.FormValue("error") if r.FormValue("error_description") != "" { err += ": " + r.FormValue("error_description") } a.internalServerError(w, errors.New(err)) return } cliCtx := context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ Transport: roundtripper.RoundTripperFunc(a.clientAssertion), }) token, err := a.oauther.Exchange(cliCtx, r.FormValue("code")) if err != nil { a.internalServerError(w, err) return } rawIDToken, ok := 
token.Extra("id_token").(string) if !ok { a.internalServerError(w, errors.New("id_token not found")) return } idToken, err := a.verifier.Verify(r.Context(), rawIDToken) if err != nil { a.internalServerError(w, err) return } var claims claims err = idToken.Claims(&claims) if err != nil { a.internalServerError(w, err) return } groupsIntersect := stringutils.GroupsIntersect(a.allGroups, claims.Groups) if len(groupsIntersect) == 0 { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) } session.Values[SessionKeyUsername] = claims.PreferredUsername session.Values[SessionKeyGroups] = groupsIntersect session.Values[SessionKeyExpires] = a.now().Add(a.sessionTimeout).Unix() redirectUri := "/" if v, ok := session.Values[sessionKeyRedirectUri]; ok { redirectUri = v.(string) delete(session.Values, sessionKeyRedirectUri) } err = session.Save(r, w) if err != nil { a.internalServerError(w, err) return } http.Redirect(w, r, redirectUri, http.StatusTemporaryRedirect) } // clientAssertion adds a JWT client assertion according to // https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-certificate-credentials // Treating this as a RoundTripper is more hackery -- this is because the // underlying oauth2 library is a little unextensible. 
func (a *aad) clientAssertion(req *http.Request) (*http.Response, error) { oauthConfig, err := adal.NewOAuthConfig(a.env.Environment().ActiveDirectoryEndpoint, a.tenantID) if err != nil { return nil, err } sp, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, a.clientID, a.clientCerts[0], a.clientKey, "unused") if err != nil { return nil, err } s := &adal.ServicePrincipalCertificateSecret{ Certificate: a.clientCerts[0], PrivateKey: a.clientKey, } err = req.ParseForm() if err != nil { return nil, err } err = s.SetAuthenticationValues(sp, &req.Form) if err != nil { return nil, err } form := req.Form.Encode() req.Body = io.NopCloser(strings.NewReader(form)) req.ContentLength = int64(len(form)) return a.rt.RoundTrip(req) } func (a *aad) internalServerError(w http.ResponseWriter, err error) { a.log.Warn(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }