internal/namespaceinpath/responsewriter.go (70 lines of code) (raw):

package namespaceinpath import ( "net" "net/http" "net/url" "strings" "gitlab.com/gitlab-org/labkit/log" ) // responseWriter is a wrapper around http.ResponseWriter that captures the response type responseWriter struct { http.ResponseWriter pagesDomain string authRedirectURI string } // newResponseWriter creates a new responseWriter func newResponseWriter(w http.ResponseWriter, pagesDomain string, authRedirectURI string) *responseWriter { return &responseWriter{ ResponseWriter: w, pagesDomain: pagesDomain, authRedirectURI: authRedirectURI, } } // WriteHeader captures the status code and rewrites the location url to namespace in path func (w *responseWriter) WriteHeader(statusCode int) { header := w.ResponseWriter.Header() if header.Get("Location") != "" { parsedLocationURL, err := url.Parse(header.Get("Location")) if err == nil && !w.isAuthURL(parsedLocationURL) { newURL := &customURL{URL: parsedLocationURL, pagesDomain: w.pagesDomain} if err = newURL.convertToNamespaceInPath(); err != nil { log.WithFields(log.Fields{ "orig_host": parsedLocationURL.Host, "orig_path": parsedLocationURL.Path, "pages_domain": w.pagesDomain, }).WithError(err).Error("while writing location header, couldn't convert URL") } else { log.WithFields(log.Fields{ "orig_host": parsedLocationURL.Host, "orig_path": parsedLocationURL.Path, "new_location": newURL.URL.String(), }).Debug("while writing location header, converted URL") w.Header().Set("Location", newURL.URL.String()) } } } w.ResponseWriter.WriteHeader(statusCode) } func (w *responseWriter) isAuthURL(reqURL *url.URL) bool { if w.authRedirectURI == "" { return false } if isAuthRedirectURL(reqURL, w.authRedirectURI) { return true } _, port, _ := net.SplitHostPort(reqURL.Host) pagesDomainWithPort := w.pagesDomain if port != "" { pagesDomainWithPort = w.pagesDomain + ":" + port } authSegment := "" if reqURL.Host == pagesDomainWithPort { // if namespace in path segments := strings.Split(strings.TrimPrefix(reqURL.Path, "/"), "/") if len(segments) == 2 { authSegment = segments[1] } } else if strings.HasSuffix(reqURL.Host, pagesDomainWithPort) { // if namespace in host segments := strings.Split(strings.TrimPrefix(reqURL.Path, "/"), "/") if len(segments) == 1 { authSegment = segments[0] } } return authSegment == "auth" }