// cmd/cpe2cve/cpe2cve.go

// Copyright (c) Facebook, Inc. and its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "encoding/csv" "flag" "fmt" "io" "os" "path" "runtime" "runtime/pprof" "strings" "sync" "sync/atomic" "time" "github.com/facebookincubator/flog" "github.com/facebookincubator/nvdtools/cvefeed" "github.com/facebookincubator/nvdtools/stats" "github.com/facebookincubator/nvdtools/wfn" ) func processAll(in <-chan []string, out chan<- []string, caches map[string]*cvefeed.Cache, cfg config, nlines *uint64) { cpesAt := cfg.CPEsAt - 1 for rec := range in { if cpesAt >= len(rec) { flog.Errorf("not enough fields in input (%d)", len(rec)) continue } if stats.AreLogged() { stats.IncrementCounter("line.total") } cpeList := strings.Split(rec[cpesAt], cfg.InRecordSeparator) cpes := make([]*wfn.Attributes, 0, len(cpeList)) for _, uri := range cpeList { if stats.AreLogged() { stats.IncrementCounter("cpe.total") } attr, err := wfn.Parse(uri) if err != nil { flog.Errorf("couldn't parse uri %q: %v", uri, err) continue } cpes = append(cpes, attr) } rec[cpesAt] = strings.Join(cpeList, cfg.OutRecordSeparator) // if performance seems to be the issue, we could try to make these cache.Get's concurrent: // // wg := sync.WaitGroup{} // for provider, cache := range caches { // provider, cache := provider, cache // wg.Add(1) // go func() { // defer wg.Done() // for _, matches := range cache.Get(cpes) { // ... 
for provider, cache := range caches { for _, matches := range cache.Get(cpes) { ml := len(matches.CPEs) if stats.AreLogged() { stats.IncrementCounterBy("cpe.match", int64(ml)) if ml != 0 { stats.IncrementCounter("line.match") } } matchingCPEs := make([]string, ml) for i, attr := range matches.CPEs { if attr == nil { flog.Errorf("%s matches nil CPE", matches.CVE.ID()) continue } matchingCPEs[i] = (*wfn.Attributes)(attr).BindToURI() } rec2 := make([]string, len(rec)) copy(rec2, rec) cvss := matches.CVE.CVSSv3BaseScore() if cvss == 0 { cvss = matches.CVE.CVSSv2BaseScore() } rec2 = cfg.EraseFields.appendAt( rec2, cfg.CVEsAt-1, matches.CVE.ID(), cfg.MatchesAt-1, strings.Join(matchingCPEs, cfg.OutRecordSeparator), cfg.CWEsAt-1, strings.Join(matches.CVE.CWEs(), cfg.OutRecordSeparator), cfg.CVSS2At-1, fmt.Sprintf("%.1f", matches.CVE.CVSSv2BaseScore()), cfg.CVSS3At-1, fmt.Sprintf("%.1f", matches.CVE.CVSSv3BaseScore()), cfg.CVSSAt-1, fmt.Sprintf("%.1f", cvss), cfg.ProviderAt-1, provider, ) out <- rec2 } } n := atomic.AddUint64(nlines, 1) if n > 0 { if n%10000 == 0 { flog.V(1).Infoln(n, "lines processed") } else if n%1000 == 0 { flog.V(2).Infoln(n, "lines processed") } else if n%100 == 0 { flog.V(3).Infoln(n, "lines processed") } } } } func processInput(in io.Reader, out io.Writer, caches map[string]*cvefeed.Cache, cfg config) chan struct{} { done := make(chan struct{}) procIn := make(chan []string) procOut := make(chan []string) r := csv.NewReader(in) r.Comma = rune(cfg.InFieldSeparator[0]) w := csv.NewWriter(out) w.Comma = rune(cfg.OutFieldSeparator[0]) // spawn processing goroutines var linesProcessed uint64 var procWG sync.WaitGroup procWG.Add(cfg.NumProcessors) for i := 0; i < cfg.NumProcessors; i++ { go func() { processAll(procIn, procOut, caches, cfg, &linesProcessed) procWG.Done() }() } // write processed results in background go func() { for rec := range procOut { if err := w.Write(rec); err != nil { flog.Errorf("write error: %v", err) } w.Flush() } if err := 
w.Error(); err != nil { flog.Errorf("write error: %v", err) } close(done) }() start := time.Now() // main goroutine reads input and sends it to processors for line := 1; ; line++ { rec, err := r.Read() if err != nil { if err == io.EOF { break } flog.Errorf("read error at line %d: %v", line, err) } procIn <- rec } close(procIn) procWG.Wait() close(procOut) flog.V(1).Infof("processed %d lines in %v", linesProcessed, time.Since(start)) return done } func init() { flog.AddFlags(flag.CommandLine, nil) stats.AddFlags() flag.Usage = func() { fmt.Fprintf(os.Stderr, "usage: %s [flags] nvd_feed.xml.gz...\n", path.Base(os.Args[0])) fmt.Fprintf(os.Stderr, "flags:\n") flag.PrintDefaults() if flog.V(1) { writeConfigFileDefinition(os.Stderr) } os.Exit(1) } flag.Set("logtostderr", "true") } func main() { // we do it like this because if we exit in Main, deferred functions don't get called os.Exit(Main()) } func Main() int { var cfg config cfg.addFlags() provider := flag.String("provider", "", "feed provider. used as a provider name for the feeds passed in through the command line") cfgFile := flag.String("config", "", "path to a config file (JSON or TOML); see usage to see how it's configured (pass -v=1 flag for verbose help). Mutually exclusive with command line flags => when used, other flags are ignored") flag.Parse() var err error if *cfgFile != "" { // override config from config file cfg, err = readConfigFile(*cfgFile) } if err == nil { // add all feeds from cmdline cfg.addFeedsFromArgs(*provider, flag.Args()...) err = cfg.validate() } if err != nil { flog.Error(err) flag.Usage() } start := time.Now() if stats.AreLogged() { defer func(start time.Time) { stats.TrackTime("run.time", start, time.Second) stats.WriteAndLogError() }(start) } flog.V(1).Info("loading NVD feeds...") var overrides cvefeed.Dictionary dicts := map[string]cvefeed.Dictionary{} // provider -> dictionary for provider, files := range cfg.Feeds { dict, err := cvefeed.LoadJSONDictionary(files...) 
if err != nil { flog.Errorf("failed to load dictionary for provider %s: %v", provider, err) } dicts[provider] = dict } allEmpty := true for _, dict := range dicts { if len(dict) != 0 { allEmpty = false break } } if allEmpty { flog.Error(fmt.Errorf("all dictionaries are empty")) return -1 } overrides, err = cvefeed.LoadJSONDictionary(cfg.FeedOverrides...) if err != nil { flog.Error(err) return -1 } flog.V(1).Infof("...done in %v", time.Since(start)) if len(overrides) != 0 { start = time.Now() flog.V(1).Info("applying overrides...") for _, dict := range dicts { dict.Override(overrides) } flog.V(1).Infof("...done in %v", time.Since(start)) } caches := map[string]*cvefeed.Cache{} for provider, dict := range dicts { caches[provider] = cvefeed.NewCache(dict).SetRequireVersion(cfg.RequireVersion).SetMaxSize(cfg.CacheSize) } if cfg.IndexDict { start = time.Now() flog.V(1).Info("indexing dictionaries...") for provider, cache := range caches { cache.Idx = cvefeed.NewIndex(dicts[provider]) if flog.V(2) { var named, total int for k, v := range cache.Idx { if k != wfn.Any { named += len(v) } total += len(v) } flog.Infof("%d out of %d records are named", named, total) } } flog.V(1).Infof("...done in %v", time.Since(start)) } if cfg.CPUProfile != "" { f, err := os.Create(cfg.CPUProfile) if err != nil { flog.Error(err) return 1 } pprof.StartCPUProfile(f) defer pprof.StopCPUProfile() } done := processInput(os.Stdin, os.Stdout, caches, cfg) if cfg.MemoryProfile != "" { f, err := os.Create(cfg.MemoryProfile) if err != nil { flog.Error(err) return 1 } runtime.GC() if err = pprof.WriteHeapProfile(f); err != nil { flog.Errorf("couldn't write heap profile: %v", err) } f.Close() } <-done return 0 }