network/client.go (488 lines of code) (raw):

package network import ( "context" "crypto/tls" "crypto/x509" "encoding/json" "encoding/xml" "errors" "fmt" "io" "mime" "net" "net/http" "net/url" "os" "path/filepath" "strings" "sync" "time" "github.com/jpillora/backoff" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-runner/common" "gitlab.com/gitlab-org/gitlab-runner/helpers/tls/ca_chain" url_helpers "gitlab.com/gitlab-org/gitlab-runner/helpers/url" ) const ( jsonMimeType = "application/json" applicationXMLMimeType = "application/xml" textXMLMimeType = "text/xml" ) type requestCredentials interface { GetURL() string GetToken() string GetTLSCAFile() string GetTLSCertFile() string GetTLSKeyFile() string } var dialer = net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } const ( backOffMinDelay = 100 * time.Millisecond backOffMaxDelay = 60 * time.Second backOffDelayFactor = 2.0 backOffDelayJitter = true ) type Option = func(c *client) error type client struct { http.Client url *url.URL caFile string certFile string keyFile string caData []byte updateTime time.Time lastIdleRefresh time.Time lastUpdate string requestBackOffs map[string]*backoff.Backoff connectionMaxAge time.Duration lock sync.Mutex requester requester } type ResponseTLSData struct { CAChain string CertFile string KeyFile string } func (n *client) getLastUpdate() string { return n.lastUpdate } func (n *client) setLastUpdate(headers http.Header) { if lu := headers.Get("X-GitLab-Last-Update"); lu != "" { n.lastUpdate = lu } } func (n *client) ensureTLSConfig() { // certificate got modified if stat, err := os.Stat(n.caFile); err == nil && n.updateTime.Before(stat.ModTime()) { n.Transport = nil } // client certificate got modified if stat, err := os.Stat(n.certFile); err == nil && n.updateTime.Before(stat.ModTime()) { n.Transport = nil } // client private key got modified if stat, err := os.Stat(n.keyFile); err == nil && n.updateTime.Before(stat.ModTime()) { n.Transport = nil } // create or update transport if n.Transport == nil { n.updateTime = time.Now() n.lastIdleRefresh = time.Now() n.createTransport() } } // To ensure long-lived TLS connections pick up rotated certificates // and to ensure load balancers distribute connections evenly, limit // the age of a connection to 15 minutes. Go has an upstream proposal // to do this in https://github.com/golang/go/issues/54429, but this // feature is not yet available. func (n *client) ensureTransportMaxAge() { if n.connectionMaxAge == 0 { return } if n.Transport == nil { return } elapsed := time.Since(n.lastIdleRefresh) if elapsed <= n.connectionMaxAge { return } logrus.WithFields(logrus.Fields{ "elapsed_s": elapsed.Seconds(), "max_age_s": n.connectionMaxAge.Seconds(), }).Debug("Closing idle connections") n.CloseIdleConnections() n.lastIdleRefresh = time.Now() } func (n *client) addTLSCA(tlsConfig *tls.Config) { // load TLS CA certificate file := n.caFile if file == "" { return } logrus.Debugln("Trying to load", file, "...") data, err := os.ReadFile(file) if err != nil { if !os.IsNotExist(err) { logrus.Errorln("Failed to load", n.caFile, err) } return } pool, err := x509.SystemCertPool() if err != nil { logrus.Warningln("Failed to load system CertPool:", err) } if pool == nil { pool = x509.NewCertPool() } if !pool.AppendCertsFromPEM(data) { logrus.Errorln("Failed to parse PEM in", n.caFile) return } tlsConfig.RootCAs = pool n.caData = data } func (n *client) addTLSAuth(tlsConfig *tls.Config) { if n.certFile == "" || n.keyFile == "" { return } logrus.Debugln("Trying to load", n.certFile, "and", n.keyFile, "pair...") // load TLS client keypair certificate, err := tls.LoadX509KeyPair(n.certFile, n.keyFile) if err != nil { if !os.IsNotExist(err) { logrus.Errorln("Failed to load", n.certFile, n.keyFile, err) } return } tlsConfig.Certificates = []tls.Certificate{certificate} //nolint:staticcheck tlsConfig.BuildNameToCertificate() } func (n *client) createTransport() { // create reference TLS config tlsConfig := tls.Config{ MinVersion: tls.VersionTLS12, } n.addTLSCA(&tlsConfig) n.addTLSAuth(&tlsConfig) // create transport n.Transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, Dial: func(network, addr string) (net.Conn, error) { logrus.Debugln("Dialing:", network, addr, "...") return dialer.Dial(network, addr) }, TLSClientConfig: &tlsConfig, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Minute, } n.Timeout = common.DefaultNetworkClientTimeout } func (n *client) ensureBackoff(method, uri string) *backoff.Backoff { n.lock.Lock() defer n.lock.Unlock() key := fmt.Sprintf("%s_%s", method, uri) if n.requestBackOffs[key] == nil { n.requestBackOffs[key] = &backoff.Backoff{ Min: backOffMinDelay, Max: backOffMaxDelay, Factor: backOffDelayFactor, Jitter: backOffDelayJitter, } } return n.requestBackOffs[key] } func (n *client) backoffRequired(res *http.Response) bool { if res.StatusCode == http.StatusTooManyRequests { // StatusTooManyRequests is handled by the caller, to allow for early logging return false } return res.StatusCode >= 400 && res.StatusCode < 600 } func (n *client) checkBackoffRequest(req *http.Request, res *http.Response) { backoffDelay := n.ensureBackoff(req.Method, req.RequestURI) if n.backoffRequired(res) { time.Sleep(backoffDelay.Duration()) } else { backoffDelay.Reset() } } func (n *client) do( ctx context.Context, uri, method string, bodyProvider common.ContentProvider, requestType string, headers http.Header, ) (*http.Response, error) { url, err := n.url.Parse(uri) if err != nil { return nil, err } var body io.ReadCloser if bodyProvider != nil { body, err = bodyProvider.GetReader() if err != nil { return nil, fmt.Errorf("couldn't get request body: %w", err) } defer body.Close() } req, err := http.NewRequestWithContext(ctx, method, url.String(), body) if err != nil { err = fmt.Errorf("failed to create NewRequest: %w", err) return nil, err } if bodyProvider != nil { req.GetBody = func() (io.ReadCloser, error) { return bodyProvider.GetReader() } if length, known := bodyProvider.GetContentLength(); known { req.ContentLength = length } } if headers != nil { req.Header = headers } req.Header.Set("User-Agent", common.AppVersion.UserAgent()) if bodyProvider != nil { req.Header.Set(common.ContentType, requestType) } n.ensureTLSConfig() n.ensureTransportMaxAge() res, err := n.requester.Do(req) if err != nil { return nil, err } n.checkBackoffRequest(req, res) return res, nil } // ErrorResponse is an error type that is returned when there is an issue // calling the remote server. It contains the http.Response responsible for // the error and the error payload provided by the server. type ErrorResponse struct { Response *http.Response `json:"-"` Message ErrorResponseMessage `json:"message"` } // XMLErrorResponse is an error type that is returned when there is an issue // from an object storage provider that returns XML. It contains the // http.Response responsible for the error and the error payload provided by // the server. // // Google: https://cloud.google.com/storage/docs/xml-api/reference-status // Amazon: https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html // Azure: https://docs.microsoft.com/en-us/rest/api/storageservices/status-and-error-codes2 type XMLErrorResponse struct { Response *http.Response `xml:"-"` XMLName xml.Name `xml:"Error"` Code string `xml:"Code"` Message string `xml:"Message"` } type ErrorResponseMessage string func (r *ErrorResponse) Error() string { statusCodeMsg := fmt.Sprintf("%d %s", r.Response.StatusCode, http.StatusText(r.Response.StatusCode)) reqURL := url_helpers.CleanURL(r.Response.Request.URL.String()) errMessage := fmt.Sprintf("%v %s: %s", r.Response.Request.Method, reqURL, statusCodeMsg) if string(r.Message) == statusCodeMsg { // If the message returned by the server is the status text, then don't repeat it in the message return errMessage } return fmt.Sprintf("%s (%s)", errMessage, r.Message) } func (r *XMLErrorResponse) Error() string { statusCodeMsg := fmt.Sprintf("%d %s", r.Response.StatusCode, http.StatusText(r.Response.StatusCode)) if r.Code == "" { return statusCodeMsg } return fmt.Sprintf("%s (%s: %s)", statusCodeMsg, r.Code, r.Message) } func (e *ErrorResponseMessage) UnmarshalJSON(data []byte) error { type simple ErrorResponseMessage err := json.Unmarshal(data, (*simple)(e)) if err == nil { return nil } var complex map[string][]interface{} err = json.Unmarshal(data, &complex) if err != nil { // explicitly ignore error, we can't decode this type return nil } messages := make([]string, 0, len(complex)) for key, val := range complex { values := make([]string, 0, len(val)) for _, msg := range val { values = append(values, fmt.Sprintf("%v", msg)) } messages = append(messages, fmt.Sprintf("%s: %s", key, strings.Join(values, "; "))) } *e = ErrorResponseMessage(strings.Join(messages, ", ")) return nil } func (n *client) doJSON( ctx context.Context, uri, method string, statusCode int, headers http.Header, request interface{}, response interface{}, ) (int, string, *http.Response) { var bytesProvider common.ContentProvider if request != nil { requestBody, err := json.Marshal(request) if err != nil { return -1, fmt.Sprintf("failed to marshal project object: %v", err), nil } bytesProvider = common.BytesProvider{Data: requestBody} } if headers == nil { headers = http.Header{} } if response != nil { headers.Set(common.Accept, jsonMimeType) } res, err := n.do(ctx, uri, method, bytesProvider, jsonMimeType, headers) if err != nil { return -1, err.Error(), nil } defer func() { _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() }() status := getMessageFromJSONResponse(res) if res.StatusCode == statusCode && response != nil { isApplicationJSON, err := isResponseApplicationJSON(res) if !isApplicationJSON { return -1, err.Error(), nil } d := json.NewDecoder(res.Body) err = d.Decode(response) if err != nil { return -1, fmt.Sprintf("Error decoding json payload %v", err), nil } } n.setLastUpdate(res.Header) return res.StatusCode, status, res } func getMessageFromJSONResponse(res *http.Response) string { if res.StatusCode >= 200 && res.StatusCode <= 299 { return res.Status } if isApplicationJSON, _ := isResponseApplicationJSON(res); isApplicationJSON { errMsg, _ := decodeJSONResponse(res) if errMsg != "" { return errMsg } } return res.Status } func getMimeAndContentType(res *http.Response) (mimeType, contentType string, e error) { contentType = res.Header.Get(common.ContentType) mimeType, _, err := mime.ParseMediaType(contentType) if err != nil { return "", contentType, fmt.Errorf("parsing Content-Type: %w", err) } return mimeType, contentType, nil } func decodeJSONResponse(res *http.Response) (string, error) { errResp := ErrorResponse{Response: res} err := json.NewDecoder(res.Body).Decode(&errResp) if err == nil { return errResp.Error(), nil } return "", err } func decodeXMLResponse(res *http.Response) (string, error) { xmlResp := XMLErrorResponse{Response: res} err := xml.NewDecoder(res.Body).Decode(&xmlResp) if err == nil { return xmlResp.Error(), nil } return "", err } func getMessageFromJSONOrXMLResponse(res *http.Response) string { if res.StatusCode >= 200 && res.StatusCode <= 299 { return res.Status } mimeType, _, err := getMimeAndContentType(res) if err != nil { return res.Status } var decodeErr error var errMsg string switch mimeType { case jsonMimeType: errMsg, decodeErr = decodeJSONResponse(res) case applicationXMLMimeType, textXMLMimeType: errMsg, decodeErr = decodeXMLResponse(res) } if errMsg != "" { return errMsg } else if decodeErr != nil { return fmt.Sprintf("%s (%s decode error: %v)", res.Status, mimeType, decodeErr) } return res.Status } func (n *client) getResponseTLSData(tls *tls.ConnectionState, resolveFullChain bool) (ResponseTLSData, error) { TLSData := ResponseTLSData{ CertFile: n.certFile, KeyFile: n.keyFile, } caChain, err := n.buildCAChain(tls, resolveFullChain) if err != nil { return TLSData, fmt.Errorf("couldn't build CA Chain: %w", err) } TLSData.CAChain = caChain return TLSData, nil } func (n *client) buildCAChain(tls *tls.ConnectionState, resolveFullChain bool) (string, error) { if len(n.caData) != 0 { return string(n.caData), nil } if tls == nil { return "", nil } builder := ca_chain.NewBuilder(logrus.StandardLogger(), resolveFullChain) err := builder.BuildChainFromTLSConnectionState(tls) if err != nil { return "", fmt.Errorf("error while fetching certificates from TLS ConnectionState: %w", err) } return builder.String(), nil } func isResponseApplicationJSON(res *http.Response) (result bool, err error) { mimeType, contentType, err := getMimeAndContentType(res) if err != nil { return false, err } if mimeType != jsonMimeType { return false, fmt.Errorf("server should return application/json. Got: %v", contentType) } return true, nil } func fixCIURL(url string) string { url = strings.TrimRight(url, "/") url = strings.TrimSuffix(url, "/ci") return url } func (n *client) findCertificate(certificate *string, base string, name string) { if *certificate != "" { return } path := filepath.Join(base, name) if _, err := os.Stat(path); err == nil { *certificate = path } } func WithMaxAge(connectionMaxAge time.Duration) Option { return func(c *client) error { c.connectionMaxAge = connectionMaxAge return nil } } func newClient(requestCredentials requestCredentials, options ...Option) (*client, error) { url, err := url.Parse(fixCIURL(requestCredentials.GetURL()) + "/api/v4/") if err != nil { return nil, err } if url.Scheme != "http" && url.Scheme != "https" { return nil, errors.New("only http or https scheme supported") } c := &client{ url: url, caFile: requestCredentials.GetTLSCAFile(), certFile: requestCredentials.GetTLSCertFile(), keyFile: requestCredentials.GetTLSKeyFile(), requestBackOffs: make(map[string]*backoff.Backoff), } c.requester = newRateLimitRequester(&c.Client) host := strings.Split(url.Host, ":")[0] if CertificateDirectory != "" { c.findCertificate(&c.caFile, CertificateDirectory, host+".crt") c.findCertificate(&c.certFile, CertificateDirectory, host+".auth.crt") c.findCertificate(&c.keyFile, CertificateDirectory, host+".auth.key") } for _, o := range options { err := o(c) if err != nil { return nil, err } } return c, nil }