go/storage/certdatabase.go (542 lines of code) (raw):
package storage
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/bluele/gcache"
"github.com/golang/glog"
"github.com/google/certificate-transparency-go/x509"
"github.com/google/renameio"
"github.com/mozilla/crlite/go"
)
const (
permModeDir = 0755
kMoveSerialsBatchSize = 1000
)
// serialListExpiryLine renders the header line of a serial-list file: the
// bin's expiry time as a Unix timestamp, encoded as a zero-padded 16-digit
// hex string and prefixed with "@" so it cannot be mistaken for a serial.
func serialListExpiryLine(aExpDate types.ExpDate) string {
	timestamp := aExpDate.Unix()
	return fmt.Sprintf("@%016x", timestamp)
}
// WriteSerialList writes aSerials to w in the on-disk serial-list format:
// a header line carrying the bin's expiry date (see serialListExpiryLine),
// followed by one hex-encoded serial number per line. aIssuer is accepted
// for signature symmetry with the readers but is not written to the file.
// Returns the first write or flush error encountered.
func WriteSerialList(w io.Writer, aExpDate types.ExpDate, aIssuer types.Issuer, aSerials []types.Serial) error {
	writer := bufio.NewWriter(w)
	// Write the expiry date for this collection of serial numbers as a unix
	// timestamp encoded as a zero-padded 16 digit hex string. The expiry
	// date is prefixed by "@" to distinguish it from a serial number.
	_, err := writer.WriteString(serialListExpiryLine(aExpDate))
	if err != nil {
		return err
	}
	err = writer.WriteByte('\n')
	if err != nil {
		return err
	}
	for _, s := range aSerials {
		_, err := writer.WriteString(s.HexString())
		if err != nil {
			return err
		}
		err = writer.WriteByte('\n')
		if err != nil {
			return err
		}
	}
	// Bug fix: this was previously `defer writer.Flush()`, which discarded
	// any error surfaced at flush time and could report success for a file
	// that was never fully written. Flush explicitly and propagate.
	return writer.Flush()
}
// CertDatabase pairs a remote cache of recently fetched serial numbers and
// CT log state with an on-disk store of per-issuer, per-expiry-date serial
// list files rooted at storageDir.
type CertDatabase struct {
	cache           RemoteCache  // shared remote cache (serials, log states, epochs, locks)
	cacheAccessors  gcache.Cache // memoized *SerialCacheWriter per (issuer, expiry) bin
	storageDir      string       // root of local storage: "serials/", "epoch", "ct-logs.json"
	readOnlyStorage bool         // when true, Commit is refused and directories are not created
}
// NewCertDatabase builds a CertDatabase over the given remote cache and
// local storage directory. Unless aReadOnlyStorage is set, the serials
// directory is created if it does not yet exist.
func NewCertDatabase(aCache RemoteCache, aStorageDir string, aReadOnlyStorage bool) (CertDatabase, error) {
	db := CertDatabase{
		cache:           aCache,
		cacheAccessors:  gcache.New(8 * 1024).ARC().Build(),
		storageDir:      aStorageDir,
		readOnlyStorage: aReadOnlyStorage,
	}
	_, err := os.Stat(db.serialsDir())
	if err != nil {
		// Bug fix: previously any Stat error other than NotExist (e.g. a
		// permission failure) was silently ignored; report it instead.
		if !os.IsNotExist(err) {
			return db, err
		}
		if !aReadOnlyStorage {
			if err := os.MkdirAll(db.serialsDir(), permModeDir); err != nil {
				return db, err
			}
		}
	}
	return db, nil
}
// EnsureCacheIsConsistent verifies that the cache epoch is exactly one
// ahead of the storage epoch (or that both are zero on a fresh install).
// When the epochs disagree, the cached CT log states are reset from the
// coverage metadata in storage so that ct-fetch resumes contiguously with
// what is already on disk.
func (db *CertDatabase) EnsureCacheIsConsistent() error {
	storageEpoch, err := db.getStorageEpoch()
	if err != nil {
		return err
	}
	cacheEpoch, err := db.cache.GetEpoch()
	if err != nil {
		return err
	}
	freshInstall := cacheEpoch == 0 && storageEpoch == 0
	if freshInstall || cacheEpoch == storageEpoch+1 {
		return nil
	}
	// The epochs are inconsistent, so we'll reset the cached log states
	// based on what's in storage. This ensures that the ct-fetch process
	// downloads a portion of each log that is contiguous with what's
	// already in storage.
	logStates, err := db.GetCTLogsFromStorage()
	if err != nil {
		return err
	}
	return db.cache.Restore(storageEpoch+1, logStates)
}
// GetIssuerAndDatesFromChannel consumes entries of the form
// "serials::<date>::<issuer id>" from reader and groups them by issuer,
// yielding [(issuer 1, [date 1, date 2, ...]), (issuer 2, [...]), ...].
// Unparseable dates are skipped with a warning; a malformed key aborts with
// an error.
func (db *CertDatabase) GetIssuerAndDatesFromChannel(reader <-chan string) ([]types.IssuerDate, error) {
	issuerMap := make(map[string]types.IssuerDate)
	for entry := range reader {
		parts := strings.Split(entry, "::")
		if len(parts) != 3 {
			// Bug fix: returning without draining the channel left the
			// producing goroutine blocked forever on its next send (the
			// channel is unbuffered in both callers). Drain it so the
			// producer can terminate before we bail out.
			for range reader {
			}
			return []types.IssuerDate{}, fmt.Errorf("Unexpected key format: %s", entry)
		}
		issuer := types.NewIssuerFromString(parts[2])
		expDate, err := types.NewExpDate(parts[1])
		if err != nil {
			glog.Warningf("Couldn't parse expiration date %s: %s", entry, err)
			continue
		}
		// Lazily create the per-issuer accumulator on first sighting.
		_, ok := issuerMap[issuer.ID()]
		if !ok {
			issuerMap[issuer.ID()] = types.IssuerDate{
				Issuer:   issuer,
				ExpDates: make([]types.ExpDate, 0),
			}
		}
		// Map values are not addressable, so read-modify-write the entry.
		tmp := issuerMap[issuer.ID()]
		tmp.ExpDates = append(tmp.ExpDates, expDate)
		issuerMap[issuer.ID()] = tmp
	}
	issuerList := make([]types.IssuerDate, 0, len(issuerMap))
	for _, v := range issuerMap {
		issuerList = append(issuerList, v)
	}
	return issuerList, nil
}
// GetIssuerAndDatesFromCache lists the (issuer, expiry date) serial bins
// currently present in the remote cache. The cache stores sets of serial
// numbers in bins keyed by strings of the form "serials::<date>::<issuer id>".
func (db *CertDatabase) GetIssuerAndDatesFromCache() ([]types.IssuerDate, error) {
	// NOTE(review): unlike the storage variant, this goroutine does not
	// close allChan itself — presumably KeysToChan closes it when the scan
	// completes; confirm against the RemoteCache implementation.
	allChan := make(chan string)
	go func() {
		err := db.cache.KeysToChan("serials::*", allChan)
		if err != nil {
			// Bug fix: the underlying error was dropped from the fatal log;
			// include it so the failure is diagnosable.
			glog.Fatalf("Couldn't list from cache: %s", err)
		}
	}()
	return db.GetIssuerAndDatesFromChannel(allChan)
}
// GetIssuerAndDatesFromStorage lists the (issuer, expiry date) serial bins
// present on disk by walking the serials directory tree.
//
// The storage directory has the following structure:
//   storageDir
//   ├─ serials
//     ├─ issuer::<issuer id 1>
//       ├─ serials::<date 1>::<issuer id 1>
//       ├─ serials::<date 2>::<issuer id 1>
//       ...
//     ├─ issuer::<issuer id 2>
//       ├─ serials::<date 1>::<issuer id 2>
//       ├─ serials::<date 2>::<issuer id 2>
//       ...
//     ...
func (db *CertDatabase) GetIssuerAndDatesFromStorage() ([]types.IssuerDate, error) {
	keyChan := make(chan string)
	go func() {
		defer close(keyChan)
		entries, err := os.ReadDir(db.serialsDir())
		if err != nil {
			glog.Fatal(err)
		}
		for _, entry := range entries {
			dirName := entry.Name()
			if !entry.IsDir() || !strings.HasPrefix(dirName, "issuer::") {
				continue
			}
			files, err := os.ReadDir(filepath.Join(db.serialsDir(), dirName))
			if err != nil {
				glog.Fatal(err)
			}
			for _, f := range files {
				if fileName := f.Name(); strings.HasPrefix(fileName, "serials::") {
					keyChan <- fileName
				}
			}
		}
	}()
	return db.GetIssuerAndDatesFromChannel(keyChan)
}
// removeExpiredSerialsFromStorage deletes serial-list files whose expiry
// date has passed as of t, then removes any issuer directories that are
// left empty. Malformed file names and unparseable dates are skipped with
// a warning.
func (db *CertDatabase) removeExpiredSerialsFromStorage(t time.Time) error {
	issuerDirs, err := os.ReadDir(db.serialsDir())
	if err != nil {
		return err
	}
	for _, issuerDir := range issuerDirs {
		issuerName := issuerDir.Name()
		if !issuerDir.IsDir() || !strings.HasPrefix(issuerName, "issuer::") {
			continue
		}
		issuerDirFull := filepath.Join(db.serialsDir(), issuerName)
		serialFiles, err := os.ReadDir(issuerDirFull)
		if err != nil {
			return err
		}
		for _, serialFile := range serialFiles {
			name := serialFile.Name()
			// File names look like "serials::<date>::<issuer id>".
			parts := strings.Split(name, "::")
			if len(parts) != 3 {
				glog.Warningf("Unexpected serial file name: %s", name)
				continue
			}
			expDate, err := types.NewExpDate(parts[1])
			if err != nil {
				glog.Warningf("Couldn't parse expiration date %s: %s", name, err)
				continue
			}
			if expDate.IsExpiredAt(t) {
				// Bug fix: removal errors were previously discarded, hiding
				// stuck files; log them (still best-effort).
				if err := os.Remove(filepath.Join(issuerDirFull, name)); err != nil {
					glog.Warningf("Couldn't remove expired serial file %s: %s", name, err)
				}
			}
		}
		// If the issuerDir is now empty, remove it.
		serialFiles, err = os.ReadDir(issuerDirFull)
		if err != nil {
			return err
		}
		if len(serialFiles) == 0 {
			if err := os.Remove(issuerDirFull); err != nil {
				glog.Warningf("Couldn't remove empty issuer directory %s: %s", issuerDirFull, err)
			}
		}
	}
	return nil
}
// Migrate delegates CT log metadata migration to the remote cache.
func (db *CertDatabase) Migrate(aLogData *types.CTLogMetadata) error {
	return db.cache.Migrate(aLogData)
}
// SaveLogState persists the given CT log state into the remote cache.
func (db *CertDatabase) SaveLogState(aLogObj *types.CTLogState) error {
	return db.cache.StoreLogState(aLogObj)
}
// GetLogState fetches the cached state for the CT log at aUrl, which is
// keyed by host+path with trailing slashes stripped. If the cache has no
// usable entry, a fresh zero-valued state carrying only the short URL is
// returned (and the cache error is logged, not propagated).
func (db *CertDatabase) GetLogState(aUrl *url.URL) (*types.CTLogState, error) {
	shortUrl := aUrl.Host + strings.TrimRight(aUrl.Path, "/")
	state, cacheErr := db.cache.LoadLogState(shortUrl)
	if state != nil {
		return state, cacheErr
	}
	glog.Warningf("Allocating brand new log for %+v, cache err=%v", shortUrl, cacheErr)
	fresh := &types.CTLogState{
		ShortURL: shortUrl,
	}
	return fresh, nil
}
// Store records aCert's serial number in the cache bin determined by the
// certificate's NotAfter date and its issuer. aLogURL and aEntryId are part
// of the call signature but are not used by this implementation.
func (db *CertDatabase) Store(aCert *x509.Certificate, aIssuer *x509.Certificate,
	aLogURL string, aEntryId int64) error {
	bin := types.NewExpDateFromTime(aCert.NotAfter)
	issuer := types.NewIssuer(aIssuer)
	writer := db.GetSerialCacheAccessor(bin, issuer)
	_, err := writer.Insert(types.NewSerial(aCert))
	return err
}
// serialsDir is the root directory holding all per-issuer serial files.
func (db *CertDatabase) serialsDir() string {
	return filepath.Join(db.storageDir, "serials")
}
// issuerDir is the storage directory for a single issuer's serial files.
func (db *CertDatabase) issuerDir(aIssuer types.Issuer) string {
	dirName := "issuer::" + aIssuer.ID()
	return filepath.Join(db.serialsDir(), dirName)
}
// serialFile is the path of the serial list for one (expiry date, issuer)
// bin, named "serials::<date>::<issuer id>" inside the issuer's directory.
func (db *CertDatabase) serialFile(aExpDate types.ExpDate, aIssuer types.Issuer) string {
	fileName := "serials::" + aExpDate.ID() + "::" + aIssuer.ID()
	return filepath.Join(db.issuerDir(aIssuer), fileName)
}
// epochFile is the path of the file recording the storage commit epoch.
func (db *CertDatabase) epochFile() string {
	return filepath.Join(db.storageDir, "epoch")
}
// coverageFile is the path of the JSON file describing CT log coverage.
func (db *CertDatabase) coverageFile() string {
	return filepath.Join(db.storageDir, "ct-logs.json")
}
// GetCTLogsFromStorage decodes the CT log coverage metadata (ct-logs.json)
// from the storage directory into a list of log states.
func (db *CertDatabase) GetCTLogsFromStorage() ([]types.CTLogState, error) {
	fd, err := os.Open(db.coverageFile())
	if err != nil {
		return nil, err
	}
	defer fd.Close()
	// Decode the JSON data straight off the file stream.
	logList := make([]types.CTLogState, 0)
	if err := json.NewDecoder(fd).Decode(&logList); err != nil {
		return nil, err
	}
	return logList, nil
}
// GetSerialCacheAccessor returns the memoized SerialCacheWriter for the
// (expiry date, issuer) serial bin, creating and caching one on first use.
// Any cache failure other than a simple miss is fatal.
func (db *CertDatabase) GetSerialCacheAccessor(aExpDate types.ExpDate, aIssuer types.Issuer) *SerialCacheWriter {
	id := aIssuer.ID() + aExpDate.ID()
	cached, err := db.cacheAccessors.GetIFPresent(id)
	if err == nil {
		return cached.(*SerialCacheWriter)
	}
	if err != gcache.KeyNotFoundError {
		glog.Fatalf("Couldn't load expDate=%s issuer=%s from cache: %s",
			aExpDate, aIssuer.ID(), err)
	}
	// Cache miss: build a fresh writer and memoize it for later calls.
	writer := NewSerialCacheWriter(aExpDate, aIssuer, db.cache)
	if err := db.cacheAccessors.Set(id, writer); err != nil {
		glog.Fatalf("Couldn't set into the cache expDate=%s issuer=%s from cache: %s",
			aExpDate, aIssuer.ID(), err)
	}
	return writer
}
// ReadSerialsFromCache lists the serial numbers currently cached for the
// (expiry date, issuer) bin.
func (db *CertDatabase) ReadSerialsFromCache(aExpDate types.ExpDate, aIssuer types.Issuer) []types.Serial {
	return db.GetSerialCacheAccessor(aExpDate, aIssuer).List()
}
// ReadSerialsFromStorage loads the serials stored on disk for the given
// (expiry date, issuer) bin. A missing file yields (nil, nil). The file's
// header line, when present, must match the expected expiry marker for
// aExpDate.
func (db *CertDatabase) ReadSerialsFromStorage(aExpDate types.ExpDate, aIssuer types.Issuer) ([]types.Serial, error) {
	fd, err := os.Open(db.serialFile(aExpDate, aIssuer))
	if errors.Is(err, os.ErrNotExist) {
		// No serials with this issuer and expiry.
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	defer fd.Close()
	scanner := bufio.NewScanner(fd)
	// The first line encodes the expiry date of the serials in the file.
	if scanner.Scan() {
		want := serialListExpiryLine(aExpDate)
		if got := scanner.Text(); got != want {
			return nil, fmt.Errorf("Unexpected expiry line. Found '%s', expected '%s'", got, want)
		}
	}
	// Each remaining line is one hex-encoded serial number.
	var serials []types.Serial
	for scanner.Scan() {
		serials = append(serials, types.NewSerialFromHex(scanner.Text()))
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return serials, nil
}
// moveOneBinOfCachedSerialsToStorage merges the cached serials for one
// (expiry date, issuer) bin into that bin's on-disk serial file, then drops
// the merged serials from the cache. The on-disk file is replaced
// atomically via a temp file in aTmpDir, so a crash mid-way leaves the old
// file intact and the serials still cached.
func (db *CertDatabase) moveOneBinOfCachedSerialsToStorage(aTmpDir string, aExpDate types.ExpDate, aIssuer types.Issuer) error {
	cachedSerials := db.ReadSerialsFromCache(aExpDate, aIssuer)
	if len(cachedSerials) == 0 {
		return nil
	}
	storedSerials, err := db.ReadSerialsFromStorage(aExpDate, aIssuer)
	if err != nil {
		return err
	}
	// Concatenate the serial lists and remove any duplicates
	serials := append(storedSerials, cachedSerials...)
	serials = types.SerialList(serials).Dedup()
	// Write the merged serial list to a temporary file, and atomically
	// overwrite the storage file if all goes well.
	path := db.serialFile(aExpDate, aIssuer)
	t, err := renameio.TempFile(aTmpDir, path)
	if err != nil {
		return err
	}
	defer t.Cleanup()
	err = WriteSerialList(t, aExpDate, aIssuer, serials)
	if err != nil {
		return err
	}
	err = t.CloseAtomicallyReplace()
	if err != nil {
		return err
	}
	// It's now safe to remove cachedSerials from the cache.
	// Ordering matters: cache entries are removed only after the replacement
	// file is in place, so at worst a serial briefly exists in both cache
	// and storage (handled by Dedup above), never in neither.
	cacheWriter := db.GetSerialCacheAccessor(aExpDate, aIssuer)
	err = cacheWriter.RemoveMany(cachedSerials)
	if err != nil {
		// Best-effort: failure here only means duplicates stay cached.
		glog.Warningf("Failed to remove serial from cache: %s", err)
	}
	return nil
}
// moveCachedSerialsToStorage drains every (issuer, expiry date) serial bin
// from the cache into persistent storage. Bins are processed per issuer, in
// parallel batches of up to kMoveSerialsBatchSize goroutines; the first
// error reported by any batch aborts the whole operation.
func (db *CertDatabase) moveCachedSerialsToStorage() error {
	issuerList, err := db.GetIssuerAndDatesFromCache()
	if err != nil {
		return err
	}
	for _, issuerDate := range issuerList {
		issuer := issuerDate.Issuer
		// Stage temp files inside the issuer's own directory so the final
		// rename stays on the same filesystem.
		tmpDir := renameio.TempDir(db.issuerDir(issuer))
		err = os.MkdirAll(tmpDir, permModeDir)
		if err != nil {
			return err
		}
		batchSize := kMoveSerialsBatchSize
		for start := 0; start < len(issuerDate.ExpDates); start += batchSize {
			// Shrink the final batch to the number of remaining bins.
			if start+batchSize > len(issuerDate.ExpDates) {
				batchSize = len(issuerDate.ExpDates) - start
			}
			glog.Infof("[%s] Moving %d expiry bins to storage.", issuer.ID(), batchSize)
			// Buffered to batchSize so every worker can report its result
			// without blocking, even though we only read after wg.Wait().
			errChan := make(chan error, batchSize)
			var wg sync.WaitGroup
			wg.Add(batchSize)
			for i := start; i < start+batchSize; i++ {
				// expDate is passed as an argument to avoid sharing the
				// loop variable across goroutines.
				go func(expDate types.ExpDate) {
					errChan <- db.moveOneBinOfCachedSerialsToStorage(tmpDir, expDate, issuer)
					wg.Done()
				}(issuerDate.ExpDates[i])
			}
			wg.Wait()
			close(errChan)
			// Surface the first failure from this batch, if any.
			for err := range errChan {
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
// moveOneBinOfAliasedSerials merges the serials stored under a pre-issuer
// alias into the corresponding issuer's serial file for a single expiry
// bin, atomically replacing the issuer's file. The pre-issuer's own file is
// left untouched.
func (db *CertDatabase) moveOneBinOfAliasedSerials(aTmpDir string, aExpDate types.ExpDate, aPreIssuer types.Issuer, aIssuer types.Issuer) error {
	aliasedSerials, err := db.ReadSerialsFromStorage(aExpDate, aPreIssuer)
	if err != nil {
		return err
	}
	if len(aliasedSerials) == 0 {
		return nil
	}
	glog.Infof("[%s] Moving %d aliased serials from %s", aIssuer.ID(), len(aliasedSerials), aPreIssuer.ID())
	storedSerials, err := db.ReadSerialsFromStorage(aExpDate, aIssuer)
	if err != nil {
		return err
	}
	// Concatenate the serial lists and remove any duplicates.
	merged := types.SerialList(append(storedSerials, aliasedSerials...)).Dedup()
	// Write the merged serial list to a temporary file, and atomically
	// overwrite the issuer's file if all goes well.
	t, err := renameio.TempFile(aTmpDir, db.serialFile(aExpDate, aIssuer))
	if err != nil {
		return err
	}
	defer t.Cleanup()
	if err := WriteSerialList(t, aExpDate, aIssuer, merged); err != nil {
		return err
	}
	return t.CloseAtomicallyReplace()
}
// moveAliasedSerials merges serials recorded under pre-issuer certificates
// into the serial files of every issuer aliased to them, across all expiry
// bins found in storage.
func (db *CertDatabase) moveAliasedSerials() error {
	issuerAndDatesList, err := db.GetIssuerAndDatesFromStorage()
	if err != nil {
		return err
	}
	for _, entry := range issuerAndDatesList {
		preissuer := entry.Issuer
		aliases, err := db.cache.GetPreIssuerAliases(preissuer)
		if err != nil {
			return err
		}
		for _, issuer := range aliases {
			// Stage temp files inside the target issuer's directory so the
			// atomic rename stays on one filesystem.
			tmpDir := renameio.TempDir(db.issuerDir(issuer))
			if err := os.MkdirAll(tmpDir, permModeDir); err != nil {
				return err
			}
			for _, expDate := range entry.ExpDates {
				if err := db.moveOneBinOfAliasedSerials(tmpDir, expDate, preissuer, issuer); err != nil {
					return err
				}
			}
		}
	}
	return nil
}
// getStorageEpoch reads the commit epoch from the storage directory's
// epoch file. A missing or empty file counts as epoch 0.
func (db *CertDatabase) getStorageEpoch() (uint64, error) {
	fd, err := os.Open(db.epochFile())
	if errors.Is(err, os.ErrNotExist) {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	defer fd.Close()
	scanner := bufio.NewScanner(fd)
	if !scanner.Scan() {
		// Scan returns false on EOF (Err() == nil) or on a read error.
		return 0, scanner.Err()
	}
	return strconv.ParseUint(scanner.Text(), 10, 64)
}
// Commit moves serials from cache to storage, removes expired serial
// numbers from storage, and updates the coverage metadata file. This is
// done in four steps:
//  1) coverage metadata is retrieved from cache and written to a
//     temporary file,
//  2) cached serials are moved to persistent storage,
//  3) the coverage metadata file is atomically overwritten with the
//     temporary file from step 1,
//  4) expired serial numbers are removed from storage.
// This sequence of operations ensures that the coverage metadata file
// describes a subset of the stored serials at the end of step 3. (It
// will typically be a strict subset, as the commit process is intended
// to run in parallel with ct-fetch).
//
// The caller must hold the commit lock (i.e. the caller must store a
// random value under the key `lock::commit` in the cache and then
// provide that value here as `aProofOfLock`).
//
// The epoch value in storage must be one less than the epoch value in
// cache (unless this is the first time that Commit() has been called,
// in which case both epochs will be equal to 0).
func (db *CertDatabase) Commit(aProofOfLock string) error {
	if db.readOnlyStorage {
		return fmt.Errorf("Cannot commit serials to read-only storage")
	}
	hasLock, err := db.cache.HasCommitLock(aProofOfLock)
	if err != nil {
		return err
	}
	if !hasLock {
		return errors.New("Caller must hold commit lock")
	}
	storageEpoch, err := db.getStorageEpoch()
	if err != nil {
		return err
	}
	cacheEpoch, err := db.cache.GetEpoch()
	if err != nil {
		return err
	}
	if (cacheEpoch != storageEpoch+1) && !(cacheEpoch == 0 && storageEpoch == 0) {
		return errors.New("Inconsistent cache and storage epochs. Restart ct-fetch.")
	}
	// Step 1: snapshot the coverage metadata into a temp file.
	logList, err := db.cache.LoadAllLogStates()
	if err != nil {
		return err
	}
	ctLogFD, err := renameio.TempFile("", db.coverageFile())
	if err != nil {
		return err
	}
	defer ctLogFD.Cleanup()
	enc := json.NewEncoder(ctLogFD)
	if err = enc.Encode(logList); err != nil {
		return err
	}
	// Step 2: move cached serials into storage.
	err = db.moveCachedSerialsToStorage()
	if err != nil {
		return err
	}
	// Step 3: atomically publish the coverage metadata snapshot.
	err = ctLogFD.CloseAtomicallyReplace()
	if err != nil {
		return err
	}
	// Step 4: prune expired serials and fold in pre-issuer aliases.
	err = db.removeExpiredSerialsFromStorage(time.Now())
	if err != nil {
		return err
	}
	err = db.moveAliasedSerials()
	if err != nil {
		return err
	}
	// The data on disk is in a good state and we just have to increment
	// the cache and storage epochs. We can ignore some errors here as long
	// as the end result is that the cache is one epoch ahead of storage.
	epochFD, err := renameio.TempFile("", db.epochFile())
	if err != nil {
		glog.Warningf("Failed to increment epochs: %s", err)
		return nil
	}
	defer epochFD.Cleanup()
	writer := bufio.NewWriter(epochFD)
	_, err = writer.WriteString(fmt.Sprintf("%v\n", cacheEpoch))
	if err != nil {
		glog.Warningf("Failed to increment epochs: %s", err)
		return nil
	}
	// Bug fix: the Flush error was previously ignored. A failed flush would
	// leave the temp epoch file empty, yet we would still advance the cache
	// epoch and then atomically replace the real epoch file with the empty
	// one — producing inconsistent epochs. Bail out (leaving both epochs
	// untouched) if the flush fails.
	err = writer.Flush()
	if err != nil {
		glog.Warningf("Failed to increment epochs: %s", err)
		return nil
	}
	err = db.cache.NextEpoch()
	if err != nil {
		glog.Warningf("Failed to increment epochs: %s", err)
		return nil
	}
	err = epochFD.CloseAtomicallyReplace()
	if err != nil {
		// This is the one case where we get inconsistent epochs.
		return err
	}
	return nil
}
// AddPreIssuerAlias records in the cache that aPreIssuer is an alias of
// aIssuer, so moveAliasedSerials can later fold its serials in.
func (db *CertDatabase) AddPreIssuerAlias(aPreIssuer types.Issuer, aIssuer types.Issuer) error {
	return db.cache.AddPreIssuerAlias(aPreIssuer, aIssuer)
}