ci/internal/mitre/client.go (187 lines of code) (raw):
package mitre
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/hashicorp/go-retryablehttp"
	"gitlab.com/gitlab-org/cves/internal/cve"
)
// Defaults used by NewClient when building a Client.
const (
	// DefaultBaseURL is the default base URL used for MITRE API requests.
	DefaultBaseURL = "https://cveawg.mitre.org/api"

	// DefaultRetries is the default number of retries of failing API requests.
	DefaultRetries = 3

	// DefaultTimeout is the timeout set on the underlying HTTP client used for
	// API requests.
	DefaultTimeout = time.Second * 30
)
// State represents a state for a CVE record.
type State int

const (
	StateUnknown   State = iota // CVE is in unknown state.
	StateReserved               // CVE is reserved.
	StatePublished              // CVE is published.
	StateRejected               // CVE is rejected.
)

// String returns the string representation of the state, e.g. "PUBLISHED".
//
// Values outside the defined constants are rendered as "State(N)" instead of
// panicking: String is invoked implicitly by fmt and (via LogValue) by slog,
// so it must be safe for any integer value.
func (s State) String() string {
	names := [...]string{
		StateUnknown:   "UNKNOWN",
		StateReserved:  "RESERVED",
		StatePublished: "PUBLISHED",
		StateRejected:  "REJECTED",
	}
	if s < 0 || int(s) >= len(names) {
		return fmt.Sprintf("State(%d)", int(s))
	}
	return names[s]
}

// LogValue satisfies the [slog.LogValuer] interface, logging the state by its
// string representation.
func (s State) LogValue() slog.Value {
	return slog.StringValue(s.String())
}
// Client talks to the MITRE CVE API on behalf of one CNA. Build it with
// NewClient; the zero value has no HTTP client and is not usable.
type Client struct {
	// CNA identity and API credentials.
	cnaUUID, cnaShortname string
	apiUser, apiKey       string

	baseURL string // API root that request paths are appended to
	retries int    // RetryMax handed to the retrying HTTP client
	dryrun  bool   // when true, non-GET requests are skipped

	client *retryablehttp.Client
}
// NewClient builds a Client for the MITRE API.
//
// The CNA identifiers and API credentials are mandatory. opts are applied in
// order before the mandatory values are checked, and the first option that
// fails aborts construction. The returned client retries failing requests
// (DefaultRetries unless overridden), uses DefaultTimeout on its HTTP client
// and logs through slog.Default().
func NewClient(cnaUUID, cnaShortname, apiUser, apiKey string, opts ...ClientOption) (*Client, error) {
	c := &Client{
		cnaUUID:      cnaUUID,
		cnaShortname: cnaShortname,
		apiUser:      apiUser,
		apiKey:       apiKey,
		baseURL:      DefaultBaseURL,
		retries:      DefaultRetries,
	}

	for idx := range opts {
		if err := opts[idx](c); err != nil {
			return nil, fmt.Errorf("applying option #%d: %w", idx+1, err)
		}
	}

	// Checked in this order so the first missing value is the one reported.
	required := []struct {
		name, value string
	}{
		{"cnaUUID", c.cnaUUID},
		{"cnaShortname", c.cnaShortname},
		{"apiUser", c.apiUser},
		{"apiKey", c.apiKey},
	}
	for _, field := range required {
		if field.value == "" {
			return nil, errors.New(field.name + " is empty")
		}
	}

	httpClient := retryablehttp.NewClient()
	httpClient.RetryMax = c.retries
	httpClient.HTTPClient.Timeout = DefaultTimeout
	httpClient.Logger = slog.Default()
	c.client = httpClient

	return c, nil
}
// GetRecord retrieves the full CVE record identified by cveID.
// See: https://cveawg.mitre.org/api-docs/#/CVE%20Record/cveGetSingle
//
// Errors from the request itself are returned unwrapped; a body that cannot be
// turned into a record yields a wrapped decoding error.
func (c *Client) GetRecord(ctx context.Context, cveID string) (*cve.Record, error) {
	endpoint := "cve/" + url.PathEscape(cveID)

	// GET requests are never skipped in dry-run mode, so a nil error always
	// comes with a non-nil response here.
	resp, err := c.do(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	record, decodeErr := cve.RecordFromReader(resp.Body)
	if decodeErr != nil {
		return nil, fmt.Errorf("getting record from response body: %w", decodeErr)
	}
	return record, nil
}
// GetState looks up the current state of the CVE ID cveID.
// See: https://cveawg.mitre.org/api-docs/#/CVE%20ID/cveIdGetSingle
//
// Only RESERVED, PUBLISHED and REJECTED are accepted; any other value in the
// response's "state" field produces StateUnknown together with an error.
func (c *Client) GetState(ctx context.Context, cveID string) (State, error) {
	resp, err := c.do(ctx, http.MethodGet, "cve-id/"+url.PathEscape(cveID), nil)
	if err != nil {
		return StateUnknown, err
	}
	defer resp.Body.Close()

	var payload struct {
		State string `json:"state"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return StateUnknown, fmt.Errorf("decoding response body: %w", err)
	}

	for _, candidate := range []State{StateReserved, StatePublished, StateRejected} {
		if candidate.String() == payload.State {
			return candidate, nil
		}
	}
	return StateUnknown, fmt.Errorf("unexpected cve state in response: %q", payload.State)
}
// UpdateRecord replaces the CNA container of the CVE record identified by
// cveID with container, sent as {"cnaContainer": ...}.
// See: https://cveawg.mitre.org/api-docs/#/CVE%20Record/cveCnaUpdateSingle
//
// A nil container is rejected before any request is made. In dry-run mode the
// PUT is skipped and nil is returned.
func (c *Client) UpdateRecord(ctx context.Context, cveID string, container *cve.CnaEdContainer) error {
	if container == nil {
		return errors.New("container is nil")
	}

	payload, err := json.Marshal(struct {
		CNAContainer *cve.CnaEdContainer `json:"cnaContainer"`
	}{CNAContainer: container})
	if err != nil {
		return fmt.Errorf("encoding request body: %w", err)
	}

	resp, err := c.do(ctx, http.MethodPut, "cve/"+url.PathEscape(cveID)+"/cna", payload)
	if err != nil {
		return err
	}
	// do hands back no response when the request was skipped in dry-run mode.
	if resp == nil {
		return nil
	}
	// Nothing in the response body is used; close it to release the connection.
	resp.Body.Close()
	return nil
}
// do sends an authenticated request to the MITRE API.
//
// path is appended to the client's base URL verbatim, so callers must escape
// any dynamic segments. body is sent as the JSON request payload.
//
// On a 2xx status the response is returned and the caller owns (and must
// close) its body. Any other status closes the body and returns an error that
// includes a bounded excerpt of the response body for diagnosis.
//
// In dry-run mode every non-GET request is logged and skipped, and do returns
// (nil, nil); callers issuing data-changing requests must tolerate a nil
// response.
func (c *Client) do(ctx context.Context, method, path string, body []byte) (*http.Response, error) {
	if c.dryrun && method != http.MethodGet {
		slog.Warn("DRY RUN: skipping MITRE API request", "method", method, "path", path)
		return nil, nil
	}

	u := fmt.Sprintf("%s/%s", c.baseURL, path)
	req, err := retryablehttp.NewRequestWithContext(ctx, method, u, body)
	if err != nil {
		return nil, fmt.Errorf("creating %s %s request: %w", method, path, err)
	}
	// Every header is single-valued, so Set (not Add) expresses the intent.
	req.Header.Set("CVE-API-ORG", c.cnaShortname)
	req.Header.Set("CVE-API-USER", c.apiUser)
	req.Header.Set("CVE-API-KEY", c.apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("performing %s %s request: %w", method, path, err)
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		defer resp.Body.Close()
		// The error body may say why the request was rejected. Read it best
		// effort and bounded so a huge body cannot bloat the error; a read
		// failure still leaves the status in the message.
		const maxErrorDetail = 1 << 10
		raw, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorDetail))
		if detail := strings.TrimSpace(string(raw)); detail != "" {
			return nil, fmt.Errorf("unexpected response status for %s %s request: %s: %q", method, path, resp.Status, detail)
		}
		return nil, fmt.Errorf("unexpected response status for %s %s request: %s", method, path, resp.Status)
	}
	return resp, nil
}
// ClientOption configures a [Client]. Options are applied in order by
// NewClient; returning a non-nil error aborts client construction, and
// NewClient wraps that error with the option's 1-based position.
type ClientOption func(*Client) error
// WithBaseURL configures a [Client] to use the given base URL for API requests.
//
// Surrounding whitespace and all trailing slashes are removed, since request
// paths are appended as "<base>/<path>". An empty value selects
// [DefaultBaseURL]. The option fails unless the result is an absolute http or
// https URL with a host, so a misconfigured endpoint is reported by NewClient
// instead of on the first API request.
func WithBaseURL(u string) ClientOption {
	u = strings.TrimRight(strings.TrimSpace(u), "/")
	if u == "" {
		u = DefaultBaseURL
	}
	return func(c *Client) error {
		parsed, err := url.Parse(u)
		if err != nil {
			return fmt.Errorf("parsing base URL %q: %w", u, err)
		}
		// url.Parse lowercases the scheme, so this comparison is case-insensitive.
		if parsed.Scheme != "http" && parsed.Scheme != "https" {
			return fmt.Errorf("base URL %q must use the http or https scheme", u)
		}
		if parsed.Host == "" {
			return fmt.Errorf("base URL %q has no host", u)
		}
		c.baseURL = u
		return nil
	}
}
// WithRetries configures a [Client] to retry failing API requests up to n
// times. The value becomes the retrying HTTP client's RetryMax, so n == 0
// disables retries (each request is attempted once). Negative values are
// rejected.
func WithRetries(n int) ClientOption {
	return func(c *Client) error {
		if n < 0 {
			return errors.New("retry number cannot be negative")
		}
		c.retries = n
		return nil
	}
}
// WithDryRun toggles dry-run mode on a [Client]. While enabled, every request
// other than GET is logged and skipped instead of being sent, so read-only
// calls still reach the API but nothing is modified.
func WithDryRun(enabled bool) ClientOption {
	setDryRun := func(c *Client) error {
		c.dryrun = enabled
		return nil
	}
	return setDryRun
}