providers/rbs/api/client.go (103 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package api import ( "context" "encoding/json" "fmt" "net/url" "sync" "time" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" "github.com/facebookincubator/flog" "github.com/facebookincubator/nvdtools/providers/lib/client" "github.com/facebookincubator/nvdtools/providers/lib/runner" "github.com/facebookincubator/nvdtools/providers/rbs/schema" ) const ( pageSize = 100 ) type Client struct { client.Client baseURL string } func NewClient(c client.Client, baseURL, tokenURL, clientID, clientSecret string) *Client { // TODO this might not work anymore, since c should be a *http.Client // Need to add all other functions to the client.Client interface and eventually maybe drop the interface altogether ctx := context.WithValue(context.Background(), oauth2.HTTPClient, c) conf := &clientcredentials.Config{ ClientID: clientID, ClientSecret: clientSecret, TokenURL: tokenURL, } return &Client{ Client: conf.Client(ctx), baseURL: baseURL, } } func (c *Client) FetchAllVulnerabilitiesAfterVulndbID(vulndbID int) (<-chan runner.Convertible, error) { u := fmt.Sprintf("%d/find_next_to_vulndb_id_full", vulndbID) return c.fetchAllVulnerabilities(func() string { return u }) } func (c *Client) FetchAllVulnerabilities(since int64) (<-chan runner.Convertible, error) { from := time.Unix(since, 0) return c.fetchAllVulnerabilities(func() string { // we need to recalculate hours ago on each request, if the fetching takes more than an hour return fmt.Sprintf("find_by_time_full?hours_ago=%d", int(time.Since(from).Hours())) }) } func (c *Client) fetchAllVulnerabilities(getEndpoint func() string) (<-chan runner.Convertible, error) { fetch := func(page, size int) (*schema.VulnerabilityResult, error) { u, err := url.Parse(fmt.Sprintf("%s/api/v1/vulnerabilities/%s", c.baseURL, getEndpoint())) if err != nil { return nil, fmt.Errorf("can't parse url: %v", err) } values := u.Query() values.Set("page", fmt.Sprintf("%d", page)) values.Set("size", fmt.Sprintf("%d", size)) u.RawQuery = values.Encode() return c.getResult(u.String()) } result, err := fetch(1, 1) if err != nil { return nil, err } totalVulns := result.TotalEntries if totalVulns == 0 { return nil, fmt.Errorf("no vulnerabilities found") } output := make(chan runner.Convertible) numPages := (totalVulns-1)/pageSize + 1 // fetch pages concurrently flog.Infof("starting sync for %d vulnerabilities over %d pages\n", totalVulns, numPages) wg := sync.WaitGroup{} for page := 1; page <= numPages; page++ { page := page wg.Add(1) go func() { defer wg.Done() result, err := fetch(page, pageSize) if err != nil { flog.Errorf("failed to get page %d: %v", page, err) return } for _, vuln := range result.Vulnerabilities { if vuln != nil { output <- vuln } } }() } go func() { wg.Wait() close(output) }() return output, nil } func (c *Client) getResult(u string) (*schema.VulnerabilityResult, error) { resp, err := c.Get(u) if err != nil { return nil, fmt.Errorf("can't get response: %v", err) } defer resp.Body.Close() var result schema.VulnerabilityResult if err = json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("can't decode result: %v", err) } return &result, nil }