func GetTLSConfig()

in internal/tls/tls.go [34:95]


// GetTLSConfig assembles the server-side *tls.Config.
//
// The root certificate/key pair from cfg.General is parsed once and serves as
// the fallback certificate whenever getCertificateByServerName yields neither
// a certificate nor an error. The certificate callback is wrapped first by
// the middleware of a limiter keyed by ratelimiter.TLSHostnameKey and then by
// the middleware of a limiter keyed by ratelimiter.TLSClientIPKey, so the
// source-IP middleware is the outermost wrapper. The wrapped callback is
// passed to getTLSConfig together with cfg.TLS, and the resulting config gets
// a GetConfigForClient hook that delegates to getOptionalConfig.
//
// An error is returned if the root key pair cannot be parsed or if
// getTLSConfig fails.
func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc, getConfigByServerName GetConfigFunc) (*tls.Config, error) {
	rootCert, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey)
	if err != nil {
		return nil, err
	}

	// crypto/tls only invokes Config.GetCertificate when it is set and either
	// the client sent a ServerName or Config.Certificates is empty. Rate
	// limiting is hooked into GetCertificate, so it has to run on every
	// handshake. Config.Certificates is therefore left unset and the standard
	// behaviour is reproduced by hand: look the certificate up by name and,
	// if nothing is found, serve the default (wildcard) certificate.
	selectCert := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		named, lookupErr := getCertificateByServerName(hello)
		if named == nil && lookupErr == nil {
			return &rootCert, nil
		}

		return named, lookupErr
	}

	domainLimiter := ratelimiter.New(
		"tls_connections_by_domain",
		ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey),
		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
		ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
		ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond),
		ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst),
		ratelimiter.WithBypassCIDRs(cfg.RateLimit.RateLimitBypassCIDRs),
	)

	sourceIPLimiter := ratelimiter.New(
		"tls_connections_by_source_ip",
		ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey),
		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
		ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
		ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond),
		ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst),
		ratelimiter.WithBypassCIDRs(cfg.RateLimit.RateLimitBypassCIDRs),
	)

	// Domain middleware wraps the lookup; source-IP middleware wraps that.
	selectCert = domainLimiter.GetCertificateMiddleware(selectCert)
	selectCert = sourceIPLimiter.GetCertificateMiddleware(selectCert)

	serverConfig, err := getTLSConfig(cfg.TLS, selectCert)
	if err != nil {
		return nil, err
	}

	// Per-connection config resolution reuses the same wrapped certificate
	// callback.
	serverConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		return getOptionalConfig(cfg, hello, selectCert, getConfigByServerName)
	}

	return serverConfig, nil
}