in internal/tls/tls.go [34:95]
// GetTLSConfig builds the server-side *tls.Config used for TLS termination.
//
// The root certificate/key pair in cfg.General serves as the default
// (wildcard) certificate. Certificate lookups first try
// getCertificateByServerName and fall back to that default. Every lookup is
// wrapped by two rate limiters: one keyed by the requested hostname and one
// keyed by the client IP. Both limiters honour cfg.RateLimit.RateLimitBypassCIDRs.
// The returned config's GetConfigForClient hook forwards to getOptionalConfig,
// passing it the same rate-limited lookup and getConfigByServerName.
//
// An error is returned if the root key pair cannot be parsed or if
// getTLSConfig fails; in either case the *tls.Config result is nil.
func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc, getConfigByServerName GetConfigFunc) (*tls.Config, error) {
	defaultCertificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey)
	if err != nil {
		return nil, err
	}

	domainLimiter := ratelimiter.New(
		"tls_connections_by_domain",
		ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey),
		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
		ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
		ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond),
		ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst),
		ratelimiter.WithBypassCIDRs(cfg.RateLimit.RateLimitBypassCIDRs),
	)
	sourceIPLimiter := ratelimiter.New(
		"tls_connections_by_source_ip",
		ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey),
		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
		ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
		ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond),
		ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst),
		ratelimiter.WithBypassCIDRs(cfg.RateLimit.RateLimitBypassCIDRs),
	)

	// crypto/tls only invokes GetCertificate when the client sent SNI or when
	// tls.Config.Certificates is empty. Rate limiting is enforced inside the
	// GetCertificate callback, so it has to run on every handshake. That is why
	// Certificates is left unset and the standard lookup order (named
	// certificate first, default certificate second) is reproduced by hand.
	certificateFor := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		named, lookupErr := getCertificateByServerName(hello)
		if named == nil && lookupErr == nil {
			return &defaultCertificate, nil
		}
		return named, lookupErr
	}

	// Wrap innermost-first, so the source-IP limiter becomes the outer layer.
	certificateFor = domainLimiter.GetCertificateMiddleware(certificateFor)
	certificateFor = sourceIPLimiter.GetCertificateMiddleware(certificateFor)

	tlsConfig, err := getTLSConfig(cfg.TLS, certificateFor)
	if err != nil {
		return nil, err
	}

	// Per-client configs must reuse the rate-limited lookup as well.
	tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		return getOptionalConfig(cfg, hello, certificateFor, getConfigByServerName)
	}

	return tlsConfig, nil
}