internal/namespaceinpath/responsewriter.go (70 lines of code) (raw):
package namespaceinpath
import (
"net"
"net/http"
"net/url"
"strings"
"gitlab.com/gitlab-org/labkit/log"
)
// responseWriter wraps an http.ResponseWriter so that a redirect's Location
// header can be rewritten into the namespace-in-path form before the status
// line is written. Every method it does not define itself is promoted from
// the embedded writer.
type responseWriter struct {
	http.ResponseWriter

	// pagesDomain is the Pages domain handed to the URL conversion and used
	// when deciding whether a Location points at an auth endpoint.
	pagesDomain string
	// authRedirectURI is the configured auth redirect URI; when empty, no
	// Location is treated as an auth URL.
	authRedirectURI string
}
// newResponseWriter returns a responseWriter that delegates to w and keeps
// pagesDomain and authRedirectURI for use when rewriting Location headers.
func newResponseWriter(w http.ResponseWriter, pagesDomain, authRedirectURI string) *responseWriter {
	rw := new(responseWriter)
	rw.ResponseWriter = w
	rw.pagesDomain = pagesDomain
	rw.authRedirectURI = authRedirectURI
	return rw
}
// WriteHeader rewrites the Location header, when one is set, into the
// namespace-in-path form via customURL.convertToNamespaceInPath, and then
// forwards statusCode unchanged to the wrapped ResponseWriter.
//
// The Location header is left as-is when it cannot be parsed, when isAuthURL
// reports it as an auth URL, or when the conversion returns an error (the
// error is logged).
//
// NOTE(review): Write is not overridden in this file, so a handler that
// writes a body without calling WriteHeader first sends the implicit status
// through the wrapped writer and skips this rewrite.
func (w *responseWriter) WriteHeader(statusCode int) {
	header := w.ResponseWriter.Header()

	if location := header.Get("Location"); location != "" {
		// A Location that fails to parse is passed through untouched.
		parsedLocationURL, err := url.Parse(location)
		if err == nil && !w.isAuthURL(parsedLocationURL) {
			// Capture the original host/path before converting: newURL wraps
			// the same *url.URL, so an in-place conversion would otherwise make
			// the "orig_*" log fields report the rewritten values.
			origHost, origPath := parsedLocationURL.Host, parsedLocationURL.Path

			newURL := &customURL{URL: parsedLocationURL, pagesDomain: w.pagesDomain}
			if convErr := newURL.convertToNamespaceInPath(); convErr != nil {
				log.WithFields(log.Fields{
					"orig_host":    origHost,
					"orig_path":    origPath,
					"pages_domain": w.pagesDomain,
				}).WithError(convErr).Error("while writing location header, couldn't convert URL")
			} else {
				newLocation := newURL.URL.String()
				log.WithFields(log.Fields{
					"orig_host":    origHost,
					"orig_path":    origPath,
					"new_location": newLocation,
				}).Debug("while writing location header, converted URL")
				header.Set("Location", newLocation)
			}
		}
	}

	w.ResponseWriter.WriteHeader(statusCode)
}
// isAuthURL reports whether reqURL points at the Pages "auth" endpoint, in
// which case WriteHeader leaves the Location header unconverted.
//
// It always returns false when no auth redirect URI is configured. Otherwise
// reqURL is an auth URL when isAuthRedirectURL matches it against
// w.authRedirectURI, or when its path is:
//   - "/<namespace>/auth" and its host is exactly the pages domain
//     (namespace in path), or
//   - "/auth" and its host is a subdomain of the pages domain
//     (namespace in host).
//
// If reqURL.Host carries a port, the same port is appended to the pages
// domain before the host comparisons.
func (w *responseWriter) isAuthURL(reqURL *url.URL) bool {
	if w.authRedirectURI == "" {
		return false
	}
	if isAuthRedirectURL(reqURL, w.authRedirectURI) {
		return true
	}

	// SplitHostPort fails when Host has no port; an empty port is exactly
	// what we want in that case, so the error is deliberately ignored.
	_, port, _ := net.SplitHostPort(reqURL.Host)
	pagesDomainWithPort := w.pagesDomain
	if port != "" {
		pagesDomainWithPort = w.pagesDomain + ":" + port
	}

	segments := strings.Split(strings.TrimPrefix(reqURL.Path, "/"), "/")

	switch {
	case reqURL.Host == pagesDomainWithPort:
		// Namespace in path: /<namespace>/auth
		return len(segments) == 2 && segments[1] == "auth"
	case strings.HasSuffix(reqURL.Host, "."+pagesDomainWithPort):
		// Namespace in host: <namespace>.<pages domain>/auth. The leading dot
		// restricts the match to real subdomains, so a host such as
		// "not<pages domain>" is not mistaken for one.
		return len(segments) == 1 && segments[0] == "auth"
	default:
		return false
	}
}