vulndb/vendor.go (446 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vulndb

import (
	"context"
	"database/sql"
	"encoding/csv"
	"io"
	"strconv"
	"time"

	"github.com/pkg/errors"

	"github.com/facebookincubator/flog"
	"github.com/facebookincubator/nvdtools/vulndb/debug"
	"github.com/facebookincubator/nvdtools/vulndb/sqlutil"
)

// VendorRecord represents a db record of the `vendor` table.
//
// One record exists per imported dataset version:
//
//   - Version is the id reported by the database when the record is inserted.
//   - TS is the creation time, taken in UTC when the version is created.
//   - Ready is false while an import is in progress and is set to true once
//     all files of the version have been imported.
//   - Owner and Provider are copied from the importer that created the version.
type VendorRecord struct {
	Version  int64     `sql:"version"`
	TS       time.Time `sql:"ts"`
	Ready    bool      `sql:"ready"`
	Owner    string    `sql:"owner"`
	Provider string    `sql:"provider"`
}

// VendorDataRecord represents a db record of the `vendor_data` table.
//
// Each record holds one CVE item of a dataset version. Version is the version
// of the VendorRecord the item was imported under; the remaining fields are
// extracted from the CVE item, with JSON holding the item as returned by the
// item's JSON accessor.
type VendorDataRecord struct {
	Version   int64     `sql:"version"`
	CVE       string    `sql:"cve_id"`
	Published time.Time `sql:"published"`
	Modified  time.Time `sql:"modified"`
	BaseScore float64   `sql:"base_score"`
	Summary   string    `sql:"summary"`
	JSON      []byte    `sql:"cve_json"`
}

// VendorDataImporter is a helper for importing an entire dataset
// from multiple files.
//
// DB is the database to import into. Owner and Provider are stored on the
// vendor record created for the import. OnFile is optional; when non-nil it is
// called with the name of each file right before that file is loaded.
type VendorDataImporter struct {
	DB       *sql.DB
	Owner    string
	Provider string
	OnFile   func(filename string)
}

// ImportFiles creates a new dataset version and imports all files into it.
// Files must be formatted as NVD CVE JSON 1.0 optionally gzipped.
//
// The new version is marked ready only after every file has been imported.
// The first error stops the import and is returned with a nil record, leaving
// the version not ready.
func (v VendorDataImporter) ImportFiles(ctx context.Context, files ...string) (*VendorRecord, error) { vendor, err := v.newVersion(ctx, v.Owner, v.Provider) if err != nil { return nil, err } for _, file := range files { if v.OnFile != nil { v.OnFile(file) } records, err := VendorDataFromFile(vendor, file) if err != nil { return nil, err } err = v.importData(ctx, records) if err != nil { return nil, err } } err = v.enableVersion(ctx, vendor) if err != nil { return nil, err } return vendor, nil } func (v VendorDataImporter) newVersion(ctx context.Context, owner, provider string) (*VendorRecord, error) { vendor := VendorRecord{ TS: time.Now().UTC(), Ready: false, Owner: owner, Provider: provider, } r := sqlutil.NewRecordType(vendor).Subset( "ts", "ready", "owner", "provider", ) q := sqlutil.Insert(). Into("vendor"). Fields(r.Fields()...). Values(r) query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } res, err := v.DB.ExecContext(ctx, query, args...) if err != nil { return nil, errors.Wrap(err, "cannot insert vendor record") } version, err := res.LastInsertId() if err != nil { return nil, errors.Wrap(err, "cannot get last id from vendor record") } vendor.Version = version return &vendor, nil } func (v VendorDataImporter) replaceVendorData(ctx context.Context, records sqlutil.Records) error { q := sqlutil.Replace(). Into("vendor_data"). Fields(records.Fields()...). Values(records...) query, args := q.String(), q.QueryArgs() if debug.V(2) { flog.Infof("running: %q", query) } _, err := v.DB.ExecContext(ctx, query, args...) 
if err != nil { return errors.Wrap(err, "cannot insert vendor data records") } return nil } func (v VendorDataImporter) replaceVendorDataBatch(ctx context.Context, records sqlutil.Records) error { // when sub-batch gets inserted, but some other fails, don't insert the succeeded one from := 0 OuterLoop: // start with full size and gradually double the size down for batchSize := len(records); batchSize > 0; batchSize /= 2 { for idx := from; idx < len(records); idx += batchSize { limit := idx + batchSize if limit > len(records) { limit = len(records) } if err := v.replaceVendorData(ctx, records[idx:limit]); err != nil { continue OuterLoop } // succeeded, move the from to the new location from = limit } // if it didn't continue before here, then all inserted return nil } // if it came to here, means it didn't insert return errors.New("can't insert batch") } func (v VendorDataImporter) importData(ctx context.Context, data []VendorDataRecord) error { records := sqlutil.NewRecords(data) const batchSize = 100 // the next few lines insert records into vendor_data in batches // if inserting a batch fails, then we subdivide the batch into half and try to insert that // repeat the process until the batch size comes to 0 for i := 0; i < len(records); i += batchSize { limit := i + batchSize if limit > len(records) { limit = len(records) } if err := v.replaceVendorDataBatch(ctx, records[i:limit]); err != nil { return errors.Wrap(err, "cannot insert vendor data records") } } return nil } func (v VendorDataImporter) enableVersion(ctx context.Context, vendor *VendorRecord) error { q := sqlutil.Update("vendor").Set( sqlutil.Assign().Equal("ready", true), ).Where( sqlutil.Cond().Equal("version", vendor.Version), ) query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } _, err := v.DB.ExecContext(ctx, query, args...) 
if err != nil { return errors.Wrap(err, "cannot update vendor record") } return nil } // VendorDataFromFile loads vendor data from NVD CVE JSON files. func VendorDataFromFile(vendor *VendorRecord, name string) ([]VendorDataRecord, error) { feed, err := readNVDCVEJSON(name) if err != nil { return nil, errors.Wrap(err, "cannot load vendor file") } records := make([]VendorDataRecord, len(feed.CVEItems)) for i, item := range feed.CVEItems { cve := cveItem{item} records[i] = VendorDataRecord{ Version: vendor.Version, CVE: cve.ID(), Published: cve.Published(), Modified: cve.Modified(), BaseScore: cve.BaseScore(), Summary: cve.Summary(), JSON: cve.JSON(), } } return records, nil } // VendorDataExporter is a helper for exporting vendor data. type VendorDataExporter struct { DB *sql.DB Provider string FilterCVEs []string } func (v VendorDataExporter) condition() *sqlutil.QueryConditionSet { cond := sqlutil.Cond().InSelect("vendor.version", sqlutil.Select("latest.version"). From(). SelectGroup("latest", latestVendorVersion()). Where( sqlutil.Cond().Equal("provider", v.Provider), ), ) if len(v.FilterCVEs) > 0 { cond = cond.And().In("vendor_data.cve_id", v.FilterCVEs) } return cond } // CSV writes vendor data records to w. func (v VendorDataExporter) CSV(ctx context.Context, w io.Writer, header bool) error { q := sqlutil.Select( "vendor.version AS version", "vendor.ts AS ts", "vendor.owner AS owner", "vendor.provider AS provider", "vendor_data.cve_id AS cve_id", "vendor_data.published AS published", "vendor_data.modified AS modified", "vendor_data.base_score AS base_score", "vendor_data.summary AS summary", ).From( "vendor_data", ).Literal( "LEFT JOIN vendor ON vendor.version = vendor_data.version", ).Where( v.condition(), ) query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } rows, err := v.DB.QueryContext(ctx, query, args...) 
if err != nil { return errors.Wrap(err, "cannot query vendor data") } defer rows.Close() record := struct { Version string `sql:"version"` TS time.Time `sql:"ts"` Owner string `sql:"owner"` Provider string `sql:"provider"` CVE string `sql:"cve_id"` Published time.Time `sql:"published"` Modified time.Time `sql:"modified"` BaseScore float64 `sql:"base_score"` Summary string `sql:"summary"` }{} cw := csv.NewWriter(w) defer cw.Flush() if header { fields := sqlutil.NewRecordType(record).Fields() cw.Write(fields) } for rows.Next() { v := record err = rows.Scan(sqlutil.NewRecordType(&v).Values()...) if err != nil { return errors.Wrap(err, "cannot scan vendor data") } cw.Write([]string{ v.Version, v.TS.Format(TimeLayout), v.Owner, v.Provider, v.CVE, v.Published.Format(TimeLayout), v.Modified.Format(TimeLayout), strconv.FormatFloat(v.BaseScore, 'f', 3, 64), v.Summary, }) } return nil } // JSON writes NVD CVE JSON to w. func (v VendorDataExporter) JSON(ctx context.Context, w io.Writer, indent string) error { q := sqlutil.Select( "cve_id", "cve_json", ).From( "vendor_data", ).Literal( "LEFT JOIN vendor ON vendor.version = vendor_data.version", ).Where( v.condition(), ) query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } rows, err := v.DB.QueryContext(ctx, query, args...) if err != nil { return errors.Wrap(err, "cannot query vendor data") } defer rows.Close() record := struct { CVE string JSON []byte }{} f := &cveFile{} for rows.Next() { v := record err = rows.Scan(sqlutil.NewRecordType(&v).Values()...) if err != nil { return errors.Wrap(err, "cannot scan vendor data") } err = f.Add(v.CVE, v.JSON) if err != nil { return err } } if indent == "" { return f.EncodeJSON(w) } const prefix = "" return f.EncodeIndentedJSON(w, prefix, indent) } // VendorDataTrimmer is a helper for trimming vendor data. // // It deletes all versions but the latest. 
// // Deleting would be easier in common scenarions, but we have some hard // constraints: // // * Vendor data is versioned // * No foreign key between vendor_data and vendor tables // * MySQL in safe mode forbids deleting from SELECT queries, wants values // * Must keep the binlog smaller than 500M, not enough for the NVD database // // Therefore, deletion works as follows: // // * Select versions from the vendor table based on the provided settings // * Operate on vendor records with ready=true or older versions // * By default, delete all versions but the latest, for each provider // * Delete from vendor table first, effectively making data records orphans // * Delete any orphan records from vendor_data, effectively crowd sourcing deletions // * Delete data in chunks, keeping binlog small // // Deletion operations are expensive. type VendorDataTrimmer struct { DB *sql.DB FilterProviders []string DeleteLatestVersion bool // TODO: support keeping up to N versions } // Trim deletes vendor data versions from the database. func (v VendorDataTrimmer) Trim(ctx context.Context) error { err := v.deleteVendors(ctx) if err != nil { return err } return v.deleteOrphanData(ctx) } func (v VendorDataTrimmer) deleteVendors(ctx context.Context) error { versions, err := v.selectVendorVersions(ctx) if err != nil { return err } if len(versions) == 0 { return nil } q := sqlutil.Delete().From( "vendor", ).Where( sqlutil.Cond().In("version", versions), ) query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } _, err = v.DB.ExecContext(ctx, query, args...) 
if err != nil { return errors.Wrap(err, "cannot delete vendor data") } return nil } func (v VendorDataTrimmer) deleteOrphanData(ctx context.Context) error { versions, err := v.selectOrphanDataVersions(ctx) if err != nil { return err } if len(versions) == 0 { return nil } q := sqlutil.Delete().From( "vendor_data", ).Where( sqlutil.Cond().In("version", versions), ) query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } q = q.Literal("LIMIT 100") for { res, err := v.DB.ExecContext(ctx, q.String(), q.QueryArgs()...) if err != nil { return errors.Wrap(err, "cannot delete vendor data") } n, err := res.RowsAffected() if err != nil { return errors.Wrap(err, "cannot get rows affected") } if n == 0 { break } } return nil } func (v VendorDataTrimmer) selectVendorVersions(ctx context.Context) ([]int64, error) { cond := sqlutil.Cond().Group( sqlutil.Cond(). Equal("ready", true). Or(). Group( // this is to delete stale data from failed imports sqlutil.Cond(). Equal("ready", false). And(). Literal("ts < DATE_SUB(NOW(), INTERVAL 1 DAY)"), ), ) if len(v.FilterProviders) > 0 { cond = cond.And().In("provider", v.FilterProviders) } if !v.DeleteLatestVersion { cond = cond.And().Not().InSelect( "version", sqlutil.Select("latest.version"). From(). 
SelectGroup("latest", latestVendorVersion()), ) } q := sqlutil.Select( "version", ).From( "vendor", ).Where( cond, ) return v.selectVersions(ctx, q) } func (v VendorDataTrimmer) selectOrphanDataVersions(ctx context.Context) ([]int64, error) { q := sqlutil.Select( "vendor_data.version", ).From( "vendor_data", ).Literal( "LEFT JOIN vendor ON vendor.version = vendor_data.version", ).Literal( "WHERE vendor.version IS NULL", ).Literal( "GROUP BY vendor_data.version", ) return v.selectVersions(ctx, q) } func (v VendorDataTrimmer) selectVersions(ctx context.Context, q *sqlutil.SelectStmt) ([]int64, error) { query, args := q.String(), q.QueryArgs() if debug.V(1) { flog.Infof("running: %q / %#v", query, args) } rows, err := v.DB.QueryContext(ctx, query, args...) if err != nil { return nil, errors.Wrap(err, "cannot query data") } defer rows.Close() var versions []int64 for rows.Next() { var version int64 err = rows.Scan(&version) if err != nil { return nil, errors.Wrap(err, "cannot scan data") } versions = append(versions, version) } return versions, nil } func latestVendorVersion() *sqlutil.SelectStmt { return sqlutil.Select( "MAX(version) AS version", "provider", ).From( "vendor", ).Where( sqlutil.Cond().Equal("ready", true), ).Literal( "GROUP BY provider", ) }