ci/internal/cli/mitresync.go (234 lines of code) (raw):
package cli
import (
"context"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/lmittmann/tint"
"gitlab.com/gitlab-org/cves/internal/cve"
"gitlab.com/gitlab-org/cves/internal/git"
"gitlab.com/gitlab-org/cves/internal/mitre"
)
// Process exit codes returned by MitreSync.
const (
	// exitNormal signals that the command completed without errors.
	exitNormal = 0
	// exitError signals that the command failed or finished with errors.
	exitError = 1
)
// MitreSync runs the mitresync command with the given command-line args and
// returns the process exit code (exitNormal or exitError).
//
// Logs are written to stdout. The command collects the CVE record files in the
// repository that were modified since -since-commit and, unless updating is
// disabled, pushes their CNA containers to MITRE (see processModified).
//
// Asking for help with -h or -help is not treated as a failure: the usage has
// already been printed and exitNormal is returned, matching the flag
// package's flag.ExitOnError convention.
func MitreSync(stdout io.Writer, args []string) int {
	start := time.Now()
	flags, err := parseFlags(stdout, args)
	if err != nil {
		if errors.Is(err, flag.ErrHelp) {
			return exitNormal
		}
		// The flag package has already reported the parse error and usage.
		return exitError
	}
	ctx := context.Background()
	initLogger(stdout, flags)
	if flags.sinceCommit == "" {
		return fatal("since commit is empty or not set", nil)
	}
	repo, err := git.NewRepository(flags.repo)
	if err != nil {
		return fatal("error initializing repository", err, "path", flags.repo)
	}
	slog.Info("starting mitresync", "time", time.Now().UTC().Format(time.RFC1123Z), "sinceCommit", flags.sinceCommit, "repo", repo.Dir())
	slog.Debug("debugging enabled")
	if flags.dryrun {
		slog.Warn("DRY RUN ENABLED: data changing actions will not be performed")
	}
	if err := processModified(ctx, repo, flags.sinceCommit, flags); err != nil {
		// Individual failures were already logged by processModified.
		slog.Warn("finished with errors", "dur", time.Since(start).Round(time.Millisecond))
		return exitError
	}
	slog.Info("finished with no errors", "dur", time.Since(start).Round(time.Millisecond))
	return exitNormal
}
// processModified pushes CNA container changes of records modified since
// commit to MITRE.
//
// The modified records are collected with modifiedRecords. If flags.update is
// false, the modified records are only logged and nil is returned. Otherwise
// each record is fetched from MITRE. If its CNA container differs from the
// published one (see equalContainers), the local CNA container is sent as an
// update.
//
// A failure for one record is logged and does not stop the remaining records.
// All per-record errors are joined, wrapped with the CVE ID they belong to, and
// returned.
func processModified(ctx context.Context, repo *git.Repository, commit string, flags *mitreSyncFlags) error {
	records, err := modifiedRecords(ctx, repo, commit)
	if err != nil {
		slog.Error("error getting modified records", "error", err)
		return err
	}
	if !flags.update {
		if len(records) > 0 {
			slog.Warn("skipping updating of modified records on MITRE", "records", records)
		}
		return nil
	}
	if len(records) == 0 {
		slog.Info("no modified records to update on MITRE")
		return nil
	}
	mitreClient, err := initMITREClient(flags)
	if err != nil {
		slog.Error("error initializing MITRE API client", "error", err)
		return fmt.Errorf("initializing MITRE API client: %w", err)
	}
	var errs error
	for i, record := range records {
		logger := slog.With("record", record, "progress", fmt.Sprintf("%d/%d", i+1, len(records)))
		id := record.CveMetadata.CveID
		published, err := mitreClient.GetRecord(ctx, id)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("fetching %v from MITRE: %w", id, err))
			logger.Error("error fetching record on MITRE", "error", err)
			continue
		}
		if equalContainers(record.Containers.Cna, published.Containers.Cna) {
			logger.Info("record has no changes to CNA container; skipping")
			continue
		}
		if err := mitreClient.UpdateRecord(ctx, id, &record.Containers.Cna); err != nil {
			errs = errors.Join(errs, fmt.Errorf("updating %v on MITRE: %w", id, err))
			logger.Error("error updating modified record on MITRE", "error", err)
			// Without this, the failed update would also be reported as successful below.
			continue
		}
		logger.Info("updated modified record on MITRE")
	}
	return errs
}
// modifiedRecords returns the CVE records parsed from files that repo reports
// as modified since commit. Non-record files are filtered out by
// recordsFromFiles.
func modifiedRecords(ctx context.Context, repo *git.Repository, commit string) ([]*cve.Record, error) {
	changed, err := repo.ModifiedFiles(ctx, commit)
	if err == nil {
		return recordsFromFiles(changed)
	}
	return nil, fmt.Errorf("getting modified files since commit: %w", err)
}
// recordsFromFiles keeps the files accepted by filterRecordFiles and parses
// each of them into a cve.Record. The first file that cannot be opened or
// parsed aborts the whole operation, and that error is returned.
func recordsFromFiles(files []string) ([]*cve.Record, error) {
	paths, err := filterRecordFiles(files)
	if err != nil {
		return nil, fmt.Errorf("filtering record files: %w", err)
	}

	// load opens and parses a single record file, closing it once parsed.
	load := func(path string) (*cve.Record, error) {
		fh, err := os.Open(path)
		if err != nil {
			return nil, fmt.Errorf("opening record file: %w", err)
		}
		defer fh.Close()
		rec, err := cve.RecordFromReader(fh)
		if err != nil {
			return nil, fmt.Errorf("creating record from %s: %w", path, err)
		}
		return rec, nil
	}

	out := make([]*cve.Record, 0, len(paths))
	for _, p := range paths {
		rec, err := load(p)
		if err != nil {
			return nil, err
		}
		out = append(out, rec)
	}
	return out, nil
}
// filterRecordFiles returns the subset of files that hold v5 CVE records, in
// their original order.
//
// A file is kept when:
//   - its basename starts with "CVE-",
//   - its extension is ".json",
//   - cve.RecordVersion recognizes it as a CVE record, and
//   - the reported record version starts with "5.".
//
// Files that fail a check are skipped with a debug log. An error opening a
// candidate file, or an unexpected error determining its version, aborts the
// filtering.
func filterRecordFiles(files []string) ([]string, error) {
	slog.Debug("filtering record files", "files", files)
	kept := make([]string, 0, len(files))
	for _, path := range files {
		lg := slog.With("file", path)

		// Cheap name-based checks first, so non-record files are never opened.
		switch {
		case !strings.HasPrefix(filepath.Base(path), "CVE-"):
			lg.Debug("excluding file: basename does not start with CVE-")
			continue
		case filepath.Ext(path) != ".json":
			lg.Debug("excluding file: extension is not .json")
			continue
		}

		fh, err := os.Open(path)
		if err != nil {
			return nil, fmt.Errorf("opening record file: %w", err)
		}
		version, err := cve.RecordVersion(fh)
		fh.Close()

		switch {
		case errors.Is(err, cve.ErrNotCVERecord):
			lg.Debug("excluding file: does not contain a CVE record")
		case err != nil:
			return nil, fmt.Errorf("determining record version in %s: %w", path, err)
		case !strings.HasPrefix(version, "5."):
			lg.Debug("excluding file: does not contain a v5 CVE record", "version", version)
		default:
			lg.Debug("including file")
			kept = append(kept, path)
		}
	}
	return kept, nil
}
// equalContainers reports whether two CNA containers are deeply equal once
// their ProviderMetadata is ignored. The metadata may carry timestamps that
// differ even when the content does not. a and b are received by value, so
// the callers' containers are not modified.
//
// NOTE(review): reflect.DeepEqual distinguishes nil from empty slices/maps;
// confirm that locally parsed and MITRE-fetched records decode such fields
// the same way, otherwise unchanged records may be reported as different.
func equalContainers(a, b cve.CnaEdContainer) bool {
	var blank cve.ProviderMetadata
	a.ProviderMetadata, b.ProviderMetadata = blank, blank
	return reflect.DeepEqual(a, b)
}
// mitreSyncFlags holds the parsed command-line configuration of mitresync.
type mitreSyncFlags struct {
	// sinceCommit is the commit SHA after which record file changes are considered.
	sinceCommit string
	// update enables pushing modified records to MITRE.
	update bool
	// mitreBaseURL is the base URL used for MITRE API requests.
	mitreBaseURL string
	// dryrun disables data changing actions.
	dryrun bool
	// debug enables debug-level logging.
	debug bool
	// repo is the path to the git repository to inspect.
	repo string
	// args are the positional arguments left after flag parsing.
	args []string
}
// parseFlags parses the mitresync command-line arguments.
//
// Usage text and flag parse errors are written to stdout instead of the flag
// package's default os.Stderr.
//
// When -h or -help is given, FlagSet.Parse prints the usage itself, and
// flag.ErrHelp is returned unwrapped so callers can detect it with errors.Is.
// Any other parse error is wrapped.
//
// If -since-commit is not given, the CI_COMMIT_BEFORE_SHA environment variable
// is used instead. The first positional argument, if any, overrides the
// default repository path "..".
func parseFlags(stdout io.Writer, args []string) (*mitreSyncFlags, error) {
	f := mitreSyncFlags{
		repo: "..",
	}
	fs := flag.NewFlagSet("mitresync", flag.ContinueOnError)
	// A nil writer makes the FlagSet fall back to os.Stderr.
	fs.SetOutput(stdout)
	fs.StringVar(&f.sinceCommit, "since-commit", "", "check for record file changes since commit SHA")
	fs.BoolVar(&f.update, "update", true, "update modified records on MITRE")
	fs.StringVar(&f.mitreBaseURL, "mitre-base-url", mitre.DefaultBaseURL, "base URL to use for MITRE API requests")
	fs.BoolVar(&f.dryrun, "dry-run", false, "don't perform data changing actions")
	fs.BoolVar(&f.debug, "debug", false, "log debugging information")
	if err := fs.Parse(args); err != nil {
		if errors.Is(err, flag.ErrHelp) {
			// Parse has already printed the usage; calling fs.Usage again
			// would print it twice.
			return nil, err
		}
		return nil, fmt.Errorf("parsing flags: %w", err)
	}
	if f.sinceCommit == "" {
		f.sinceCommit = os.Getenv("CI_COMMIT_BEFORE_SHA")
	}
	f.args = fs.Args()
	if len(f.args) > 0 {
		f.repo = f.args[0]
	}
	return &f, nil
}
// initLogger installs a tint slog handler writing to stdout as the default
// slog logger. Records are logged at info level, or at debug level when
// f.debug is set.
//
// When the TEST_MITRESYNC environment variable is present (with any value),
// colors are disabled and every time.Time and time.Duration attribute value
// is replaced by its zero value, so the log output is reproducible.
func initLogger(stdout io.Writer, f *mitreSyncFlags) {
	level := slog.LevelInfo
	if f.debug {
		level = slog.LevelDebug
	}
	opts := &tint.Options{
		Level:      level,
		TimeFormat: time.TimeOnly,
	}
	if _, inTest := os.LookupEnv("TEST_MITRESYNC"); inTest {
		opts.NoColor = true
		opts.ReplaceAttr = func(_ []string, a slog.Attr) slog.Attr {
			switch a.Value.Any().(type) {
			case time.Time:
				return slog.Time(a.Key, time.Time{})
			case time.Duration:
				return slog.Duration(a.Key, 0)
			}
			return a
		}
	}
	slog.SetDefault(slog.New(tint.NewHandler(stdout, opts)))
}
// initMITREClient creates a MITRE API client. Credentials come from the
// MITRE_CNA_UUID, MITRE_CNA_SHORTNAME, MITRE_API_USER and MITRE_API_KEY
// environment variables. The base URL and dry-run option come from f.
func initMITREClient(f *mitreSyncFlags) (*mitre.Client, error) {
	cnaUUID := os.Getenv("MITRE_CNA_UUID")
	cnaShortName := os.Getenv("MITRE_CNA_SHORTNAME")
	apiUser := os.Getenv("MITRE_API_USER")
	apiKey := os.Getenv("MITRE_API_KEY")
	withBaseURL := mitre.WithBaseURL(f.mitreBaseURL)
	withDryRun := mitre.WithDryRun(f.dryrun)
	return mitre.NewClient(cnaUUID, cnaShortName, apiUser, apiKey, withBaseURL, withDryRun)
}
// fatal logs message at error level with a "fatal: " prefix and returns
// exitError, so callers can write `return fatal(...)`.
//
// args are additional slog key/value pairs. When err is non-nil, it is
// appended under the "error" key.
func fatal(message string, err error, args ...any) int {
	attrs := args
	if err != nil {
		attrs = append(attrs, "error", err)
	}
	msg := "fatal: " + message
	slog.Error(msg, attrs...)
	return exitError
}