ci/internal/cli/mitresync.go (234 lines of code) (raw):

package cli import ( "context" "errors" "flag" "fmt" "io" "log/slog" "os" "path/filepath" "reflect" "strings" "time" "github.com/lmittmann/tint" "gitlab.com/gitlab-org/cves/internal/cve" "gitlab.com/gitlab-org/cves/internal/git" "gitlab.com/gitlab-org/cves/internal/mitre" ) const ( exitNormal = iota exitError ) func MitreSync(stdout io.Writer, args []string) int { start := time.Now() flags, err := parseFlags(stdout, args) if err != nil { return exitError } ctx := context.Background() initLogger(stdout, flags) if flags.sinceCommit == "" { return fatal("since commit is empty or not set", nil) } repo, err := git.NewRepository(flags.repo) if err != nil { return fatal("error initializing repository", err, "path", flags.repo) } slog.Info("starting mitresync", "time", time.Now().UTC().Format(time.RFC1123Z), "sinceCommit", flags.sinceCommit, "repo", repo.Dir()) slog.Debug("debugging enabled") if flags.dryrun { slog.Warn("DRY RUN ENABLED: data changing actions will not be performed") } if err := processModified(ctx, repo, flags.sinceCommit, flags); err != nil { slog.Warn("finished with errors", "dur", time.Since(start).Round(time.Millisecond)) return exitError } slog.Info("finished with no errors", "dur", time.Since(start).Round(time.Millisecond)) return exitNormal } func processModified(ctx context.Context, repo *git.Repository, commit string, flags *mitreSyncFlags) error { records, err := modifiedRecords(ctx, repo, commit) if err != nil { slog.Error("error getting modified records", "error", err) return err } if !flags.update { if len(records) > 0 { slog.Warn("skipping updating of modified records on MITRE", "records", records) } return nil } if len(records) == 0 { slog.Info("no modified records to update on MITRE") return nil } mitreClient, err := initMITREClient(flags) if err != nil { slog.Error("error initializing MITRE API client", "error", err) return fmt.Errorf("initializing MITRE API client: %w", err) } var errs error for i, record := range records { logger := slog.With("record", record, "progress", fmt.Sprintf("%d/%d", i+1, len(records))) published, err := mitreClient.GetRecord(ctx, record.CveMetadata.CveID) if err != nil { errs = errors.Join(errs, err) logger.Error("error fetching record on MITRE", "error", err) continue } if equalContainers(record.Containers.Cna, published.Containers.Cna) { logger.Info("record has no changes to CNA container; skipping") continue } if err := mitreClient.UpdateRecord(ctx, record.CveMetadata.CveID, &record.Containers.Cna); err != nil { errs = errors.Join(errs, err) logger.Error("error updating modified record on MITRE", "error", err) } logger.Info("updated modified record on MITRE") } return errs } func modifiedRecords(ctx context.Context, repo *git.Repository, commit string) ([]*cve.Record, error) { files, err := repo.ModifiedFiles(ctx, commit) if err != nil { return nil, fmt.Errorf("getting modified files since commit: %w", err) } return recordsFromFiles(files) } func recordsFromFiles(files []string) ([]*cve.Record, error) { filtered, err := filterRecordFiles(files) if err != nil { return nil, fmt.Errorf("filtering record files: %w", err) } records := make([]*cve.Record, 0, len(filtered)) for _, name := range filtered { f, err := os.Open(name) if err != nil { return nil, fmt.Errorf("opening record file: %w", err) } record, err := cve.RecordFromReader(f) f.Close() if err != nil { return nil, fmt.Errorf("creating record from %s: %w", name, err) } records = append(records, record) } return records, nil } func filterRecordFiles(files []string) ([]string, error) { slog.Debug("filtering record files", "files", files) records := make([]string, 0, len(files)) for _, file := range files { logger := slog.With("file", file) if !strings.HasPrefix(filepath.Base(file), "CVE-") { logger.Debug("excluding file: basename does not start with CVE-") continue } if filepath.Ext(file) != ".json" { logger.Debug("excluding file: extension is not .json") continue } f, err := os.Open(file) if err != nil { return nil, fmt.Errorf("opening record file: %w", err) } version, err := cve.RecordVersion(f) f.Close() if err != nil { if errors.Is(err, cve.ErrNotCVERecord) { logger.Debug("excluding file: does not contain a CVE record") continue } return nil, fmt.Errorf("determining record version in %s: %w", file, err) } if !strings.HasPrefix(version, "5.") { logger.Debug("excluding file: does not contain a v5 CVE record", "version", version) continue } logger.Debug("including file") records = append(records, file) } return records, nil } func equalContainers(a, b cve.CnaEdContainer) bool { // Provider metadata is cleared before comparison as it can contain timestamps // that are not relevant for the comparison. a.ProviderMetadata = cve.ProviderMetadata{} b.ProviderMetadata = cve.ProviderMetadata{} return reflect.DeepEqual(a, b) } type mitreSyncFlags struct { sinceCommit string update bool mitreBaseURL string dryrun bool debug bool repo string args []string } func parseFlags(stdout io.Writer, args []string) (*mitreSyncFlags, error) { f := mitreSyncFlags{ repo: "..", } fs := flag.NewFlagSet("mitresync", flag.ContinueOnError) fs.StringVar(&f.sinceCommit, "since-commit", "", "check for record file changes since commit SHA") fs.BoolVar(&f.update, "update", true, "update modified records on MITRE") fs.StringVar(&f.mitreBaseURL, "mitre-base-url", mitre.DefaultBaseURL, "base URL to use for MITRE API requests") fs.BoolVar(&f.dryrun, "dry-run", false, "don't perform data changing actions") fs.BoolVar(&f.debug, "debug", false, "log debugging information") if err := fs.Parse(args); err != nil { if errors.Is(err, flag.ErrHelp) { fs.Usage() return nil, err } return nil, fmt.Errorf("parsing flags: %w", err) } if f.sinceCommit == "" { f.sinceCommit = os.Getenv("CI_COMMIT_BEFORE_SHA") } f.args = fs.Args() if len(f.args) > 0 { f.repo = f.args[0] } return &f, nil } func initLogger(stdout io.Writer, f *mitreSyncFlags) { opts := &tint.Options{ Level: slog.LevelInfo, TimeFormat: time.TimeOnly, } if f.debug { opts.Level = slog.LevelDebug } // Disable colors and stabilize time values with zero values if testing. if _, ok := os.LookupEnv("TEST_MITRESYNC"); ok { opts.NoColor = true opts.ReplaceAttr = func(groups []string, attr slog.Attr) slog.Attr { switch attr.Value.Any().(type) { case time.Time: return slog.Time(attr.Key, time.Time{}) case time.Duration: return slog.Duration(attr.Key, time.Duration(0)) default: return attr } } } logger := slog.New(tint.NewHandler(stdout, opts)) slog.SetDefault(logger) } func initMITREClient(f *mitreSyncFlags) (*mitre.Client, error) { return mitre.NewClient( os.Getenv("MITRE_CNA_UUID"), os.Getenv("MITRE_CNA_SHORTNAME"), os.Getenv("MITRE_API_USER"), os.Getenv("MITRE_API_KEY"), mitre.WithBaseURL(f.mitreBaseURL), mitre.WithDryRun(f.dryrun), ) } func fatal(message string, err error, args ...any) int { if err != nil { args = append(args, "error", err) } slog.Error("fatal: "+message, args...) return exitError }