ci/internal/mitre/client.go (187 lines of code) (raw):

package mitre import ( "context" "encoding/json" "errors" "fmt" "log/slog" "net/http" "net/url" "strings" "time" "github.com/hashicorp/go-retryablehttp" "gitlab.com/gitlab-org/cves/internal/cve" ) const ( // DefaultBaseURL is the default base URL used for MITRE API requests. DefaultBaseURL = "https://cveawg.mitre.org/api" // DefaultRetries is the default number of retries of failing API requests. DefaultRetries = 3 // DefaultTimeout is the default timeout for API requests. DefaultTimeout = 30 * time.Second ) // State represents a state for a CVE record. type State int const ( StateUnknown State = iota // CVE is in unknown state. StateReserved // CVE is reserved. StatePublished // CVE is published. StateRejected // CVE is rejected. ) // String returns the string representation of the state. func (s State) String() string { return [...]string{ "UNKNOWN", "RESERVED", "PUBLISHED", "REJECTED", }[s] } // LogValue satisfies the [slog.LogValuer] interface. func (s State) LogValue() slog.Value { return slog.StringValue(s.String()) } // Client for interacting with the MITRE API. type Client struct { cnaUUID string cnaShortname string apiUser string apiKey string baseURL string retries int dryrun bool client *retryablehttp.Client } // NewClient returns a new client for interacting with the MITRE API configured // with the given options. func NewClient(cnaUUID, cnaShortname, apiUser, apiKey string, opts ...ClientOption) (*Client, error) { client := &Client{ cnaUUID: cnaUUID, cnaShortname: cnaShortname, apiUser: apiUser, apiKey: apiKey, baseURL: DefaultBaseURL, retries: DefaultRetries, } for i, opt := range opts { if err := opt(client); err != nil { return nil, fmt.Errorf("applying option #%d: %w", i+1, err) } } if client.cnaUUID == "" { return nil, errors.New("cnaUUID is empty") } if client.cnaShortname == "" { return nil, errors.New("cnaShortname is empty") } if client.apiUser == "" { return nil, errors.New("apiUser is empty") } if client.apiKey == "" { return nil, errors.New("apiKey is empty") } rhc := retryablehttp.NewClient() rhc.RetryMax = client.retries rhc.HTTPClient.Timeout = DefaultTimeout rhc.Logger = slog.Default() client.client = rhc return client, nil } // GetRecord fetches the record for the given CVE ID. // See: https://cveawg.mitre.org/api-docs/#/CVE%20Record/cveGetSingle func (c *Client) GetRecord(ctx context.Context, cveID string) (*cve.Record, error) { resp, err := c.do(ctx, http.MethodGet, fmt.Sprintf("cve/%s", url.PathEscape(cveID)), nil) if err != nil { return nil, err } defer resp.Body.Close() record, err := cve.RecordFromReader(resp.Body) if err != nil { return nil, fmt.Errorf("getting record from response body: %w", err) } return record, nil } // GetState fetches the state for the given CVE ID. // See: https://cveawg.mitre.org/api-docs/#/CVE%20ID/cveIdGetSingle func (c *Client) GetState(ctx context.Context, cveID string) (State, error) { resp, err := c.do(ctx, http.MethodGet, fmt.Sprintf("cve-id/%s", url.PathEscape(cveID)), nil) if err != nil { return StateUnknown, err } defer resp.Body.Close() s := struct { State string `json:"state"` }{} if err := json.NewDecoder(resp.Body).Decode(&s); err != nil { return StateUnknown, fmt.Errorf("decoding response body: %w", err) } switch s.State { case StateReserved.String(): return StateReserved, nil case StatePublished.String(): return StatePublished, nil case StateRejected.String(): return StateRejected, nil default: return StateUnknown, fmt.Errorf("unexpected cve state in response: %q", s.State) } } // UpdateRecord updates the CNA container data for the record with the given CVE ID. // See: https://cveawg.mitre.org/api-docs/#/CVE%20Record/cveCnaUpdateSingle func (c *Client) UpdateRecord(ctx context.Context, cveID string, container *cve.CnaEdContainer) error { if container == nil { return errors.New("container is nil") } body := struct { CNAContainer *cve.CnaEdContainer `json:"cnaContainer"` }{ CNAContainer: container, } b, err := json.Marshal(body) if err != nil { return fmt.Errorf("encoding request body: %w", err) } resp, err := c.do(ctx, http.MethodPut, fmt.Sprintf("cve/%s/cna", url.PathEscape(cveID)), b) if err != nil { return err } if resp != nil { defer resp.Body.Close() } return nil } func (r *Client) do(ctx context.Context, method, path string, body []byte) (*http.Response, error) { if r.dryrun && method != http.MethodGet { slog.Warn("DRY RUN: skipping MITRE API request", "method", method, "path", path) return nil, nil } u := fmt.Sprintf("%s/%s", r.baseURL, path) req, err := retryablehttp.NewRequestWithContext(ctx, method, u, body) if err != nil { return nil, fmt.Errorf("creating %s %s request: %w", method, path, err) } req.Header.Add("CVE-API-ORG", r.cnaShortname) req.Header.Add("CVE-API-USER", r.apiUser) req.Header.Add("CVE-API-KEY", r.apiKey) req.Header.Add("Content-Type", "application/json") resp, err := r.client.Do(req) if err != nil { return nil, fmt.Errorf("performing %s %s request: %w", method, path, err) } if resp.StatusCode < 200 || resp.StatusCode > 299 { resp.Body.Close() return nil, fmt.Errorf("unexpected response status for %s %s request: %s", method, path, resp.Status) } return resp, nil } // ClientOption configures a [Client]. type ClientOption func(*Client) error // WithBaseURL configures a [Client] to use the given base URL for API requests. func WithBaseURL(u string) ClientOption { u = strings.TrimSuffix(strings.TrimSpace(u), "/") if u == "" { u = DefaultBaseURL } return func(c *Client) error { c.baseURL = u return nil } } // WithRetries configures a [Client] to retry failing API requests given number // of times. func WithRetries(n int) ClientOption { return func(c *Client) error { if n < 1 { return errors.New("retry number cannot be less than one") } c.retries = n return nil } } // WithDryRun configures a [Client] to be in dry run mode where data changing // API requests are not performed. func WithDryRun(enabled bool) ClientOption { return func(c *Client) error { c.dryrun = enabled return nil } }