pkg/tls/reloader.go (310 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Package tls provides common TLS utilities for HTTP and gRPC servers.
package tls
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"os"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/pkg/errors"
"github.com/apache/skywalking-banyandb/pkg/logger"
)
// Reloader manages dynamic reloading of TLS certificates and keys for servers.
//
// A Reloader watches a certificate file (and, when configured, a key file)
// with fsnotify, reloads them after their content changes, and signals
// consumers through the channel returned by GetUpdateChannel.
//
//nolint:govet
type Reloader struct {
// cert is the certificate handed out by getCertificate. It is set by
// NewReloader and replaced by reloadCertificate; NewClientCertReloader
// leaves it nil.
cert *tls.Certificate
// watcher delivers file system events for certFile and keyFile.
watcher *fsnotify.Watcher
// log receives all diagnostic output of the reloader.
log *logger.Logger
// debounceTimer is created on the first call to scheduleReloadAttempt and
// reset by later calls.
debounceTimer *time.Timer
// updateCh is created with a buffer of one and is signaled by notifyUpdate
// after a successful reload.
updateCh chan struct{}
// certFile is the path of the watched certificate file.
certFile string
// keyFile is the path of the watched key file; it is empty for reloaders
// created by NewClientCertReloader.
keyFile string
// lastCertHash is the SHA-256 hash of certFile recorded at the last
// successful load; checkContentChanged compares against it.
lastCertHash []byte
// lastKeyHash is the SHA-256 hash of keyFile recorded at the last
// successful load; it stays nil when keyFile is empty.
lastKeyHash []byte
// mu is held by getCertificate while reading cert and by the key-pair path
// of reloadCertificate while replacing cert and both hashes.
mu sync.RWMutex
}
// NewReloader creates a Reloader for a server certificate/key pair.
//
// Both file paths must be non-empty and log must be non-nil. The certificate
// and key files are each read exactly once: the same bytes are parsed into the
// initial key pair and hashed with SHA-256, so the recorded hashes always
// describe the certificate that was actually loaded. File watching does not
// begin until Start is called.
//
// If reading or parsing fails, the fsnotify watcher created here is closed and
// a wrapped error is returned.
func NewReloader(certFile, keyFile string, log *logger.Logger) (*Reloader, error) {
	if certFile == "" || keyFile == "" {
		return nil, errors.New("certFile and keyFile must be provided")
	}
	if log == nil {
		return nil, errors.New("logger must not be nil")
	}
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, errors.Wrap(err, "failed to create fsnotify watcher")
	}
	// fail releases the watcher and reports why the initial load did not work.
	fail := func(cause error) (*Reloader, error) {
		// Best-effort cleanup: the load error is what the caller needs to see.
		_ = watcher.Close()
		return nil, errors.Wrap(cause, "failed to load initial TLS certificate")
	}
	// Read + tls.X509KeyPair is what tls.LoadX509KeyPair does internally;
	// doing it by hand keeps the raw bytes available for hashing below.
	certPEM, err := os.ReadFile(certFile)
	if err != nil {
		return fail(err)
	}
	keyPEM, err := os.ReadFile(keyFile)
	if err != nil {
		return fail(err)
	}
	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		return fail(err)
	}
	log.Info().Str("certFile", certFile).Str("keyFile", keyFile).Msg("Successfully loaded initial TLS certificates")
	certSum := sha256.Sum256(certPEM)
	keySum := sha256.Sum256(keyPEM)
	return &Reloader{
		certFile:     certFile,
		keyFile:      keyFile,
		cert:         &cert,
		log:          log,
		watcher:      watcher,
		updateCh:     make(chan struct{}, 1),
		lastCertHash: certSum[:],
		lastKeyHash:  keySum[:],
	}, nil
}
// NewClientCertReloader creates a reloader that only monitors a CA certificate without requiring a key.
// This is useful for client-side certificate verification where only the CA cert is needed.
//
// certFile must be non-empty and log must be non-nil. The file is read once;
// it must contain at least one PEM certificate that x509 can parse, and the
// SHA-256 hash of those same bytes is recorded as the baseline for later
// change detection. The returned Reloader has no key file and no server
// certificate. File watching does not begin until Start is called.
func NewClientCertReloader(certFile string, log *logger.Logger) (*Reloader, error) {
	if certFile == "" {
		return nil, errors.New("certFile must be provided")
	}
	if log == nil {
		return nil, errors.New("logger must not be nil")
	}
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, errors.Wrap(err, "failed to create fsnotify watcher")
	}
	// Read the cert file content to ensure it exists and is valid.
	certPEM, err := os.ReadFile(certFile)
	if err != nil {
		// Best-effort cleanup: the read error is what the caller needs to see.
		_ = watcher.Close()
		return nil, errors.Wrap(err, "failed to read certificate file")
	}
	// Ensure the certificate is parsable; the pool itself is discarded.
	if !x509.NewCertPool().AppendCertsFromPEM(certPEM) {
		_ = watcher.Close()
		return nil, errors.New("failed to parse PEM certificate")
	}
	log.Info().Str("certFile", certFile).Msg("Successfully loaded initial client certificate")
	// Hash the bytes that were just validated instead of re-reading the file,
	// so the baseline cannot diverge from the validated content.
	certSum := sha256.Sum256(certPEM)
	return &Reloader{
		certFile:     certFile,
		keyFile:      "", // No key file for client certs
		log:          log,
		watcher:      watcher,
		updateCh:     make(chan struct{}, 1),
		lastCertHash: certSum[:],
	}, nil
}
// Start registers the certificate file (and the key file, when one is
// configured) with the fsnotify watcher and launches the watchFiles loop in a
// new goroutine. It returns a wrapped error, without starting the loop, if
// either file cannot be added to the watcher.
func (r *Reloader) Start() error {
	r.log.Info().Str("certFile", r.certFile).Str("keyFile", r.keyFile).Msg("Starting TLS file monitoring")
	if certErr := r.watcher.Add(r.certFile); certErr != nil {
		return errors.Wrapf(certErr, "failed to watch cert file: %s", r.certFile)
	}
	// Client-certificate reloaders have no key file to watch.
	hasKey := r.keyFile != ""
	if hasKey {
		if keyErr := r.watcher.Add(r.keyFile); keyErr != nil {
			return errors.Wrapf(keyErr, "failed to watch key file: %s", r.keyFile)
		}
	}
	go r.watchFiles()
	return nil
}
// isFileStable reports whether filePath exists and kept the same size across
// three consecutive os.Stat calls spaced 200ms apart.
//
// A missing file is retried after a one second pause, up to five times in a
// row, without consuming one of the three checks. The function returns false
// when the size changes between checks, when the file is still missing after
// the retries, or when os.Stat fails for any other reason (for example a
// permission error): such a failure yields no size to compare, so it must not
// count as a successful check.
func (r *Reloader) isFileStable(filePath string) bool {
	const (
		checks     = 3
		delay      = 200 * time.Millisecond
		retryDelay = time.Second
		maxRetries = 5
	)
	var lastSize int64
	retryCount := 0
	// i only advances after a successful stat, so retries do not use up checks.
	for i := 0; i < checks; {
		fi, err := os.Stat(filePath)
		switch {
		case err == nil:
			retryCount = 0 // reset retry count if file exists
			if i > 0 && fi.Size() != lastSize {
				return false
			}
			lastSize = fi.Size()
			i++
			time.Sleep(delay)
		case os.IsNotExist(err):
			retryCount++
			if retryCount > maxRetries {
				r.log.Error().Str("file", filePath).Msgf("File does not exist after %d retries", maxRetries)
				return false
			}
			r.log.Debug().Str("file", filePath).Msg("File does not exist, retrying")
			time.Sleep(retryDelay)
		default:
			r.log.Error().Err(err).Str("file", filePath).Msg("Failed to stat file during stability check")
			return false
		}
	}
	return true
}
// computeFileHash reads the whole file at filePath and returns the SHA-256
// digest of its contents. A read failure is returned wrapped with the path.
func (r *Reloader) computeFileHash(filePath string) ([]byte, error) {
	data, readErr := os.ReadFile(filePath)
	if readErr != nil {
		return nil, errors.Wrapf(readErr, "failed to read file for hashing: %s", filePath)
	}
	digest := sha256.Sum256(data)
	return digest[:], nil
}
// scheduleReloadAttempt debounces reload attempts. The first call creates a
// 500ms timer whose callback checks whether the watched files changed and, if
// so, reloads the certificate; every later call only resets that timer to
// 500ms.
func (r *Reloader) scheduleReloadAttempt() {
	const debounce = 500 * time.Millisecond
	if r.debounceTimer != nil {
		r.debounceTimer.Reset(debounce)
		return
	}
	r.debounceTimer = time.AfterFunc(debounce, func() {
		changed, certHash, keyHash, checkErr := r.checkContentChanged()
		switch {
		case checkErr != nil:
			r.log.Error().Err(checkErr).Msg("Error checking if certificate content changed")
		case !changed:
			r.log.Debug().Msg("Certificate content unchanged, skipping reload")
		default:
			// Content differs from the recorded hashes: load it.
			if reloadErr := r.reloadCertificate(certHash, keyHash); reloadErr != nil {
				r.log.Error().Err(reloadErr).Msg("Failed to reload TLS certificate")
				return
			}
			r.log.Info().Msg("Successfully updated TLS certificate after content change")
		}
	})
}
// checkContentChanged hashes the watched files and compares the results with
// the hashes recorded at the last successful load.
//
// It returns whether either file differs, the current certificate hash, the
// current key hash (nil when no key file is configured), and any error hit
// while hashing. On error the other results are false/nil.
//
// The recorded hashes are read under r.mu because reloadCertificate replaces
// them under the same mutex from timer-callback goroutines; the file reads
// themselves happen outside the lock.
func (r *Reloader) checkContentChanged() (bool, []byte, []byte, error) {
	currentCertHash, err := r.computeFileHash(r.certFile)
	if err != nil {
		return false, nil, nil, errors.Wrap(err, "failed to compute current cert hash")
	}
	// Without a key file only the certificate is compared and the key hash
	// stays nil.
	var currentKeyHash []byte
	hasKey := r.keyFile != ""
	if hasKey {
		currentKeyHash, err = r.computeFileHash(r.keyFile)
		if err != nil {
			return false, nil, nil, errors.Wrap(err, "failed to compute current key hash")
		}
	}
	r.mu.RLock()
	changed := !bytes.Equal(r.lastCertHash, currentCertHash)
	if hasKey && !bytes.Equal(r.lastKeyHash, currentKeyHash) {
		changed = true
	}
	r.mu.RUnlock()
	return changed, currentCertHash, currentKeyHash, nil
}
// watchFiles is the event loop started by Start. It runs until either watcher
// channel is closed.
//
// Write and rename events wait 200ms and then schedule a debounced reload.
// Remove and create events drop the watch on the affected path, wait one
// second, and then try up to five times (500ms apart) to re-add the watch once
// isFileStable accepts the file; if every attempt fails, logger.Panicf is
// called. A debounced reload is scheduled afterwards. Errors from the watcher
// are logged only.
func (r *Reloader) watchFiles() {
	r.log.Info().Msg("TLS file watcher loop started")
	const maxRetries = 5
	handledOps := fsnotify.Remove | fsnotify.Create | fsnotify.Write | fsnotify.Rename
	replaceOps := fsnotify.Remove | fsnotify.Create
	for {
		select {
		case watchErr, open := <-r.watcher.Errors:
			if !open {
				r.log.Info().Msg("Watcher errors channel closed")
				return
			}
			r.log.Error().Err(watchErr).Msg("Error in file watcher")
		case ev, open := <-r.watcher.Events:
			if !open {
				r.log.Info().Msg("Watcher events channel closed")
				return
			}
			r.log.Debug().Str("file", ev.Name).Str("op", ev.Op.String()).Msg("Detected file event")
			if ev.Op&handledOps == 0 {
				continue
			}
			if ev.Op&replaceOps == 0 {
				// Plain write/rename: give the writer a moment to finish.
				r.log.Info().Str("file", ev.Name).Msg("Detected certificate modification")
				time.Sleep(200 * time.Millisecond)
				r.scheduleReloadAttempt()
				continue
			}
			r.log.Info().Str("file", ev.Name).Msg("File removed or created, performing stability checks")
			// Drop the old watch first to avoid duplicate watches.
			_ = r.watcher.Remove(ev.Name)
			// Wait for file operations to complete.
			time.Sleep(1 * time.Second)
			for attempt := 1; attempt <= maxRetries; attempt++ {
				rewatched := false
				switch ev.Name {
				case r.certFile:
					if r.isFileStable(r.certFile) {
						if addErr := r.watcher.Add(r.certFile); addErr != nil {
							r.log.Error().Err(addErr).Str("file", r.certFile).Msg("Failed to re-add cert file to watcher")
						} else {
							r.log.Debug().Str("file", r.certFile).Msg("Re-added cert file to watcher")
							rewatched = true
						}
					}
				case r.keyFile:
					if r.isFileStable(r.keyFile) {
						if addErr := r.watcher.Add(r.keyFile); addErr != nil {
							r.log.Error().Err(addErr).Str("file", r.keyFile).Msg("Failed to re-add key file to watcher")
						} else {
							r.log.Debug().Str("file", r.keyFile).Msg("Re-added key file to watcher")
							rewatched = true
						}
					}
				}
				if rewatched {
					break
				}
				if attempt == maxRetries {
					logger.Panicf("Failed to re-add file to watcher after %d attempts", maxRetries)
				} else {
					time.Sleep(500 * time.Millisecond)
				}
			}
			// Every handled event ends with a debounced reload attempt.
			r.scheduleReloadAttempt()
		}
	}
}
// notifyUpdate performs a non-blocking send on the update channel. When the
// channel cannot accept the value right now, the notification is dropped and a
// warning is logged instead of blocking the caller.
func (r *Reloader) notifyUpdate() {
	delivered := false
	select {
	case r.updateCh <- struct{}{}:
		delivered = true
	default:
	}
	if !delivered {
		r.log.Warn().Msg("Update channel is full, notification skipped")
		return
	}
	r.log.Debug().Msg("Sent certificate update notification")
}
// reloadCertificate loads the watched files from disk and, on success,
// records the supplied hashes and sends an update notification.
//
// newCertHash and newKeyHash are the hashes the caller computed for the
// current file contents; they are stored only after the files were loaded
// successfully, so a failed reload leaves the previous certificate and hashes
// in place.
//
// Without a key file the certificate is only checked for being parsable PEM
// and nothing but the hash is stored. With a key file the key pair replaces
// r.cert. In both cases the fields are written under r.mu.
func (r *Reloader) reloadCertificate(newCertHash, newKeyHash []byte) error {
	r.log.Debug().Msg("Reloading TLS certificate")
	// For client certificates (no key file), just verify the certificate is valid.
	if r.keyFile == "" {
		certPEM, err := os.ReadFile(r.certFile)
		if err != nil {
			return errors.Wrap(err, "failed to read certificate file")
		}
		if !x509.NewCertPool().AppendCertsFromPEM(certPEM) {
			return errors.New("failed to parse PEM certificate")
		}
		// Take the lock here too, so both paths write the hash the same way.
		r.mu.Lock()
		r.lastCertHash = newCertHash
		r.mu.Unlock()
		r.log.Debug().Msg("Client certificate updated in memory")
		r.notifyUpdate()
		return nil
	}
	// For server certificates with key files, load the key pair.
	newCert, err := tls.LoadX509KeyPair(r.certFile, r.keyFile)
	if err != nil {
		return errors.Wrap(err, "failed to reload TLS certificate")
	}
	// Swap certificate and hashes together so readers never see a mix.
	r.mu.Lock()
	r.cert = &newCert
	r.lastCertHash = newCertHash
	r.lastKeyHash = newKeyHash
	r.mu.Unlock()
	r.log.Debug().Msg("TLS certificate updated in memory")
	r.notifyUpdate()
	return nil
}
// getCertificate returns the certificate currently held by the reloader, read
// under the read lock. The ClientHelloInfo argument is ignored and the error
// result is always nil. Its signature matches tls.Config.GetCertificate.
func (r *Reloader) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	r.mu.RLock()
	current := r.cert
	r.mu.RUnlock()
	return current, nil
}
// GetUpdateChannel exposes the reloader's update channel as receive-only. A
// value is sent on it by notifyUpdate after a certificate reload succeeds.
func (r *Reloader) GetUpdateChannel() <-chan struct{} {
	var updates <-chan struct{} = r.updateCh
	return updates
}
// Stop closes the fsnotify watcher. A failure to close is logged rather than
// returned.
func (r *Reloader) Stop() {
	r.log.Info().Msg("Stopping TLS Reloader")
	closeErr := r.watcher.Close()
	if closeErr == nil {
		return
	}
	r.log.Error().Err(closeErr).Msg("Failed to close fsnotify watcher")
}
// GetTLSConfig builds a new tls.Config whose GetCertificate callback is the
// reloader's getCertificate method. The config requires TLS 1.2 or newer and
// advertises "h2" as its only ALPN protocol.
func (r *Reloader) GetTLSConfig() *tls.Config {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}
	cfg.NextProtos = []string{"h2"}
	cfg.GetCertificate = r.getCertificate
	return cfg
}
// GetClientTLSConfig builds a client-side tls.Config that trusts the
// certificates found in the reloader's certificate file.
//
// The file is read on every call. The PEM content becomes the RootCAs pool,
// serverName is used as the ServerName, and TLS 1.2 is the minimum version. An
// error is returned when the file cannot be read or holds no parsable PEM
// certificate.
func (r *Reloader) GetClientTLSConfig(serverName string) (*tls.Config, error) {
	pemBytes, readErr := os.ReadFile(r.certFile)
	if readErr != nil {
		return nil, errors.Wrap(readErr, "failed to read certificate file")
	}
	roots := x509.NewCertPool()
	if parsed := roots.AppendCertsFromPEM(pemBytes); !parsed {
		return nil, errors.New("failed to parse PEM certificate")
	}
	clientCfg := &tls.Config{
		RootCAs:    roots,
		ServerName: serverName,
		MinVersion: tls.VersionTLS12,
	}
	return clientCfg, nil
}