internal/tls/tls.go (108 lines of code) (raw):
package tls
import (
"crypto/tls"
"crypto/x509"
"os"
"gitlab.com/gitlab-org/gitlab-pages/internal/config"
"gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter"
"gitlab.com/gitlab-org/gitlab-pages/metrics"
)
// preferredCipherSuites is the list getTLSConfig assigns to
// tls.Config.CipherSuites. Every entry is an AEAD suite (ChaCha20-Poly1305 or
// AES-GCM); for the TLS 1.2 suites the RSA and ECDSA variants are listed as
// adjacent pairs.
//
// NOTE(review): the crypto/tls documentation states that TLS 1.3 suites are
// not configurable via Config.CipherSuites and that (since Go 1.17) the order
// of this list is not used as a preference; the trailing TLS 1.3 entries and
// the ordering are therefore presumably informational — confirm before
// relying on them.
var preferredCipherSuites = []uint16{
	// TLS 1.2, ECDHE, ChaCha20-Poly1305.
	tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
	// TLS 1.2, ECDHE, AES-128-GCM.
	tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
	// TLS 1.2, ECDHE, AES-256-GCM.
	tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
	// TLS 1.3 suites.
	tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256,
}
// Callback signatures through which callers supply per-domain TLS material.
type (
	// GetCertificateFunc returns the certificate to present for the domain
	// named in a ClientHello.
	GetCertificateFunc func(hello *tls.ClientHelloInfo) (*tls.Certificate, error)

	// GetConfigFunc returns a tls.Config, populated with client-authentication
	// values, to use for a ClientHello.
	GetConfigFunc func(hello *tls.ClientHelloInfo) (*tls.Config, error)
)
// GetTLSConfig builds the server tls.Config described by cfg.
//
// The wildcard (default) certificate is parsed from cfg.General.RootCertificate
// and cfg.General.RootKey; a parse failure is returned as the error.
// getCertificateByServerName is asked first for a domain-specific certificate,
// and when it yields neither a certificate nor an error the wildcard
// certificate is served. Every certificate lookup passes through two rate
// limiters, one keyed by client source IP and one keyed by hostname.
// getConfigByServerName is consulted per handshake through GetConfigForClient
// (see getOptionalConfig) to supply a per-domain tls.Config.
func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc, getConfigByServerName GetConfigFunc) (*tls.Config, error) {
	defaultCert, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey)
	if err != nil {
		return nil, err
	}

	// crypto/tls only calls Config.GetCertificate when the ClientHello carries a
	// ServerName or when Config.Certificates is empty. Rate limiting is attached
	// to GetCertificate, so it has to run on every handshake: Certificates is
	// deliberately left empty and the usual "named certificate, otherwise the
	// wildcard one" fallback is reproduced by hand here.
	certificate := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		if cert, err := getCertificateByServerName(hello); cert != nil || err != nil {
			return cert, err
		}
		return &defaultCert, nil
	}

	domainLimiter := ratelimiter.New(
		"tls_connections_by_domain",
		ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey),
		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
		ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
		ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond),
		ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst),
		ratelimiter.WithBypassCIDRs(cfg.RateLimit.RateLimitBypassCIDRs),
	)
	sourceIPLimiter := ratelimiter.New(
		"tls_connections_by_source_ip",
		ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey),
		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
		ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
		ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond),
		ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst),
		ratelimiter.WithBypassCIDRs(cfg.RateLimit.RateLimitBypassCIDRs),
	)

	// The source-IP middleware is the outer layer: it wraps the domain
	// middleware, which in turn wraps the certificate lookup above.
	certificate = sourceIPLimiter.GetCertificateMiddleware(domainLimiter.GetCertificateMiddleware(certificate))

	serverConfig, err := getTLSConfig(cfg.TLS, certificate)
	if err != nil {
		return nil, err
	}

	serverConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		return getOptionalConfig(cfg, hello, certificate, getConfigByServerName)
	}

	return serverConfig, nil
}
// getOptionalConfig backs tls.Config.GetConfigForClient. It decides whether a
// handshake should use a tls.Config other than the server's base config.
//
// Resolution order:
//  1. If getConfigByServerName fails, its error is returned.
//  2. If it returns a config, that config is used after its GetCertificate is
//     set to getCertificate.
//  3. Otherwise, if client authentication is enabled (cfg.TLS.ClientAuth is
//     not tls.NoClientCert) and info.ServerName appears in
//     cfg.TLS.ClientAuthDomains, a fresh config with ClientAuth set to
//     cfg.TLS.ClientAuth is built via getTLSConfig.
//  4. In all other cases (nil, nil) is returned, which crypto/tls treats as
//     "keep the original Config".
func getOptionalConfig(cfg *config.Config, info *tls.ClientHelloInfo, getCertificate GetCertificateFunc, getConfigByServerName GetConfigFunc) (*tls.Config, error) {
	customConfig, err := getConfigByServerName(info)
	if err != nil {
		// Check the error before touching customConfig: a (nil, err) result
		// must never reach the field assignment below (nil dereference).
		return nil, err
	}
	if customConfig != nil {
		// NOTE(review): this writes into the config returned by
		// getConfigByServerName; if that config is cached and shared between
		// concurrent handshakes, confirm the write is safe.
		customConfig.GetCertificate = getCertificate
		return customConfig, nil
	}

	if cfg.TLS.ClientAuth == tls.NoClientCert {
		return nil, nil
	}

	for _, domain := range cfg.TLS.ClientAuthDomains {
		if domain != info.ServerName {
			continue
		}
		// NOTE(review): getTLSConfig re-reads the client CA file from disk when
		// ClientAuth > tls.RequestClientCert, so this happens on every matching
		// handshake; consider caching if it shows up in profiles.
		tlsConfig, err := getTLSConfig(cfg.TLS, getCertificate)
		if err != nil {
			return nil, err
		}
		tlsConfig.ClientAuth = cfg.TLS.ClientAuth
		return tlsConfig, nil
	}

	return nil, nil
}
// getTLSConfig builds a server tls.Config from cfgTLS, installing
// getCertificateByServerName as GetCertificate.
//
// MinVersion starts at TLS 1.2, which also satisfies gosec G402. It is
// replaced by cfgTLS.MinVersion only when that value is non-zero, so an unset
// value no longer silently discards the explicit floor. MaxVersion is copied
// from cfgTLS as is. ClientAuth is copied only when cfgTLS.ClientAuthDomains is
// empty; for listed domains getOptionalConfig sets it on the config it builds.
//
// When cfgTLS.ClientAuth is stricter than tls.RequestClientCert, the CA bundle
// at cfgTLS.ClientCert is read into ClientCAs. A read error is returned as is.
// A bundle that contains no parseable PEM certificate is reported as an
// *os.PathError.
func getTLSConfig(cfgTLS config.TLS, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) {
	tlsConfig := &tls.Config{
		GetCertificate: getCertificateByServerName,
		// GoLang 1.23 no longer supports insecure ciphers.
		// Therefore, starting from GitLab 17.10, the insecure-ciphers flag is ignored.
		CipherSuites: preferredCipherSuites,
		// Explicit floor (gosec G402); kept when cfgTLS leaves MinVersion unset.
		MinVersion: tls.VersionTLS12,
		MaxVersion: cfgTLS.MaxVersion,
	}
	if cfgTLS.MinVersion != 0 {
		tlsConfig.MinVersion = cfgTLS.MinVersion
	}

	if len(cfgTLS.ClientAuthDomains) == 0 {
		tlsConfig.ClientAuth = cfgTLS.ClientAuth
	}

	if cfgTLS.ClientAuth > tls.RequestClientCert {
		caCert, err := os.ReadFile(cfgTLS.ClientCert)
		if err != nil {
			return nil, err
		}
		certPool := x509.NewCertPool()
		if !certPool.AppendCertsFromPEM(caCert) {
			// An empty pool would make every presented client certificate fail
			// verification with no hint about the cause; fail loudly instead.
			return nil, &os.PathError{Op: "parse client CA certificates", Path: cfgTLS.ClientCert, Err: os.ErrInvalid}
		}
		tlsConfig.ClientCAs = certPool
	}

	return tlsConfig, nil
}