pkg/authority/cert/storage.go (157 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cert
import (
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
"math"
"os"
"reflect"
"sync"
"time"
"github.com/apache/dubbo-admin/pkg/authority/config"
"github.com/apache/dubbo-admin/pkg/logger"
)
type storageImpl struct {
Storage
mutex *sync.Mutex
stopChan chan os.Signal
caValidity int64
certValidity int64
rootCert *Cert
authorityCert *Cert
trustedCerts []*Cert
serverNames []string
serverCerts *Cert
}
type Storage interface {
GetServerCert(serverName string) *tls.Certificate
RefreshServerCert()
SetAuthorityCert(*Cert)
GetAuthorityCert() *Cert
SetRootCert(*Cert)
GetRootCert() *Cert
AddTrustedCert(*Cert)
GetTrustedCerts() []*Cert
GetStopChan() chan os.Signal
}
type Cert struct {
Cert *x509.Certificate
CertPem string
PrivateKey *ecdsa.PrivateKey
tlsCert *tls.Certificate
}
func NewStorage(options *config.Options) *storageImpl {
return &storageImpl{
mutex: &sync.Mutex{},
stopChan: make(chan os.Signal, 1),
authorityCert: &Cert{},
trustedCerts: []*Cert{},
certValidity: options.CertValidity,
caValidity: options.CaValidity,
}
}
func (c *Cert) IsValid() bool {
if c.Cert == nil || c.CertPem == "" || c.PrivateKey == nil {
return false
}
if time.Now().Before(c.Cert.NotBefore) || time.Now().After(c.Cert.NotAfter) {
return false
}
if c.tlsCert == nil || !reflect.DeepEqual(c.tlsCert.PrivateKey, c.PrivateKey) {
tlsCert, err := tls.X509KeyPair([]byte(c.CertPem), []byte(EncodePrivateKey(c.PrivateKey)))
if err != nil {
return false
}
c.tlsCert = &tlsCert
}
return true
}
func (c *Cert) NeedRefresh() bool {
if c.Cert == nil || c.CertPem == "" || c.PrivateKey == nil {
return true
}
if time.Now().Before(c.Cert.NotBefore) || time.Now().After(c.Cert.NotAfter) {
return true
}
validity := c.Cert.NotAfter.UnixMilli() - c.Cert.NotBefore.UnixMilli()
if time.Now().Add(time.Duration(math.Floor(float64(validity)*0.2)) * time.Millisecond).After(c.Cert.NotAfter) {
return true
}
if !reflect.DeepEqual(c.Cert.PublicKey, c.PrivateKey.Public()) {
return true
}
return false
}
func (c *Cert) GetTlsCert() *tls.Certificate {
if c.tlsCert != nil && reflect.DeepEqual(c.tlsCert.PrivateKey, c.PrivateKey) {
return c.tlsCert
}
tlsCert, err := tls.X509KeyPair([]byte(c.CertPem), []byte(EncodePrivateKey(c.PrivateKey)))
if err != nil {
logger.Sugar().Warnf("Failed to load x509 cert. %v", err)
}
c.tlsCert = &tlsCert
return c.tlsCert
}
func (s *storageImpl) GetServerCert(serverName string) *tls.Certificate {
nameSigned := serverName == ""
for _, name := range s.serverNames {
if name == serverName {
nameSigned = true
break
}
}
if nameSigned && s.serverCerts != nil && s.serverCerts.IsValid() {
return s.serverCerts.GetTlsCert()
}
s.mutex.Lock()
defer s.mutex.Unlock()
if !nameSigned {
s.serverNames = append(s.serverNames, serverName)
}
s.serverCerts = SignServerCert(s.authorityCert, s.serverNames, s.certValidity)
return s.serverCerts.GetTlsCert()
}
func (s *storageImpl) RefreshServerCert() {
interval := math.Min(math.Floor(float64(s.certValidity)/100), 10_000)
for true {
select {
case <-s.stopChan:
return
default:
}
time.Sleep(time.Duration(interval) * time.Millisecond)
func() {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.authorityCert == nil || !s.authorityCert.IsValid() {
// ignore if authority cert is invalid
return
}
if s.serverCerts == nil || !s.serverCerts.IsValid() {
logger.Sugar().Infof("Server cert is invalid, refresh it.")
s.serverCerts = SignServerCert(s.authorityCert, s.serverNames, s.certValidity)
}
}()
}
}
func (s *storageImpl) SetAuthorityCert(cert *Cert) {
s.authorityCert = cert
}
func (s *storageImpl) GetAuthorityCert() *Cert {
return s.authorityCert
}
func (s *storageImpl) SetRootCert(cert *Cert) {
s.rootCert = cert
}
func (s *storageImpl) GetRootCert() *Cert {
return s.rootCert
}
func (s *storageImpl) AddTrustedCert(cert *Cert) {
s.trustedCerts = append(s.trustedCerts, cert)
}
func (s *storageImpl) GetTrustedCerts() []*Cert {
return s.trustedCerts
}
func (s *storageImpl) GetStopChan() chan os.Signal {
return s.stopChan
}