utils/httputil/httputil.go (337 lines of code) (raw):

// Copyright (c) 2016-2019 Uber Technologies, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package httputil import ( "context" "crypto/tls" "errors" "fmt" "io" "io/ioutil" "net/http" "net/url" "time" "github.com/cenkalti/backoff" "github.com/go-chi/chi" "github.com/uber/kraken/core" "github.com/uber/kraken/utils/handler" ) var retryableCodes = map[int]struct{}{ http.StatusTooManyRequests: {}, http.StatusBadGateway: {}, http.StatusServiceUnavailable: {}, http.StatusGatewayTimeout: {}, } // RoundTripper is an alias of the http.RoundTripper for mocking purposes. type RoundTripper = http.RoundTripper // StatusError occurs if an HTTP response has an unexpected status code. type StatusError struct { Method string URL string Status int Header http.Header ResponseDump string } // NewStatusError returns a new StatusError. func NewStatusError(resp *http.Response) StatusError { defer resp.Body.Close() respBytes, err := ioutil.ReadAll(resp.Body) respDump := string(respBytes) if err != nil { respDump = fmt.Sprintf("failed to dump response: %s", err) } return StatusError{ Method: resp.Request.Method, URL: resp.Request.URL.String(), Status: resp.StatusCode, Header: resp.Header, ResponseDump: respDump, } } func (e StatusError) Error() string { if e.ResponseDump == "" { return fmt.Sprintf("%s %s %d", e.Method, e.URL, e.Status) } return fmt.Sprintf("%s %s %d: %s", e.Method, e.URL, e.Status, e.ResponseDump) } // IsStatus returns true if err is a StatusError of the given status. 
func IsStatus(err error, status int) bool { statusErr, ok := err.(StatusError) return ok && statusErr.Status == status } // IsCreated returns true if err is a "created", 201 func IsCreated(err error) bool { return IsStatus(err, http.StatusCreated) } // IsNotFound returns true if err is a "not found" StatusError. func IsNotFound(err error) bool { return IsStatus(err, http.StatusNotFound) } // IsConflict returns true if err is a "status conflict" StatusError. func IsConflict(err error) bool { return IsStatus(err, http.StatusConflict) } // IsAccepted returns true if err is a "status accepted" StatusError. func IsAccepted(err error) bool { return IsStatus(err, http.StatusAccepted) } // IsForbidden returns true if statis code is 403 "forbidden" func IsForbidden(err error) bool { return IsStatus(err, http.StatusForbidden) } func isRetryable(code int) bool { _, ok := retryableCodes[code] return ok } // IsRetryable returns true if the statis code indicates that the request is // retryable. func IsRetryable(err error) bool { statusErr, ok := err.(StatusError) return ok && isRetryable(statusErr.Status) } // NetworkError occurs on any Send error which occurred while trying to send // the HTTP request, e.g. the given host is unresponsive. type NetworkError struct { err error } func (e NetworkError) Error() string { return fmt.Sprintf("network error: %s", e.err) } // IsNetworkError returns true if err is a NetworkError. func IsNetworkError(err error) bool { _, ok := err.(NetworkError) return ok } type sendOptions struct { body io.Reader timeout time.Duration acceptedCodes map[int]bool headers map[string]string redirect func(req *http.Request, via []*http.Request) error retry retryOptions transport http.RoundTripper ctx context.Context // This is not a valid http option. It provides a way to override // parts of the url. For example, url.Scheme can be changed from // http to https. url *url.URL // This is not a valid http option. 
HTTP fallback is added to allow // easier migration from http to https. // In go1.11 and go1.12, the responses returned when http request is // sent to https server are different in the fallback mode: // go1.11 returns a network error whereas go1.12 returns BadRequest. // This causes TestTLSClientBadAuth to fail because the test checks // retry error. // This flag is added to allow disabling http fallback in unit tests. // NOTE: it does not impact how it runs in production. httpFallbackDisabled bool } // SendOption allows overriding defaults for the Send function. type SendOption func(*sendOptions) // SendNoop returns a no-op option. func SendNoop() SendOption { return func(o *sendOptions) {} } // SendBody specifies a body for http request func SendBody(body io.Reader) SendOption { return func(o *sendOptions) { o.body = body } } // SendTimeout specifies timeout for http request func SendTimeout(timeout time.Duration) SendOption { return func(o *sendOptions) { o.timeout = timeout } } // SendHeaders specifies headers for http request func SendHeaders(headers map[string]string) SendOption { return func(o *sendOptions) { o.headers = headers } } // SendAcceptedCodes specifies accepted codes for http request func SendAcceptedCodes(codes ...int) SendOption { m := make(map[int]bool) for _, c := range codes { m[c] = true } return func(o *sendOptions) { o.acceptedCodes = m } } // SendRedirect specifies a redirect policy for http request func SendRedirect(redirect func(req *http.Request, via []*http.Request) error) SendOption { return func(o *sendOptions) { o.redirect = redirect } } type retryOptions struct { backoff backoff.BackOff extraCodes map[int]bool } // RetryOption allows overriding defaults for the SendRetry option. type RetryOption func(*retryOptions) // RetryBackoff adds exponential backoff between retries. 
func RetryBackoff(b backoff.BackOff) RetryOption {
	return func(o *retryOptions) { o.backoff = b }
}

// RetryCodes adds more status codes to be retried, in addition to the default
// retryable codes (429, 502, 503 and 504). Unlike the defaults, these codes
// are retried even when they also appear in SendAcceptedCodes.
//
// WARNING: You better know what you're doing to retry anything non-5XX.
func RetryCodes(codes ...int) RetryOption {
	return func(o *retryOptions) {
		// SendRetry pre-allocates this map, but don't depend on every caller
		// of a RetryOption doing so: writing into a nil map panics.
		if o.extraCodes == nil {
			o.extraCodes = make(map[int]bool, len(codes))
		}
		for _, c := range codes {
			o.extraCodes[c] = true
		}
	}
}

// SendRetry makes Send retry the request on network errors and on retryable
// status codes. Unless overridden with RetryBackoff, it waits a constant
// 250ms between attempts and retries at most twice.
func SendRetry(options ...RetryOption) SendOption {
	retry := retryOptions{
		backoff: backoff.WithMaxRetries(
			backoff.NewConstantBackOff(250*time.Millisecond), 2),
		extraCodes: make(map[int]bool),
	}
	for _, o := range options {
		o(&retry)
	}
	return func(o *sendOptions) { o.retry = retry }
}

// DisableHTTPFallback disables http fallback when https request fails.
func DisableHTTPFallback() SendOption {
	return func(o *sendOptions) { o.httpFallbackDisabled = true }
}

// EnableHTTPFallback enables http fallback when https request fails.
func EnableHTTPFallback() SendOption {
	return func(o *sendOptions) { o.httpFallbackDisabled = false }
}

// SendTLS configures the client with a new http.Transport using config and
// switches the request URL scheme to https. A nil config is a no-op.
func SendTLS(config *tls.Config) SendOption {
	return func(o *sendOptions) {
		if config == nil {
			return
		}
		o.transport = &http.Transport{TLSClientConfig: config}
		o.url.Scheme = "https"
	}
}

// SendTLSTransport sets a caller-provided transport (presumably configured for
// TLS) for the HTTP client and switches the request URL scheme to https.
func SendTLSTransport(transport http.RoundTripper) SendOption {
	return func(o *sendOptions) {
		o.transport = transport
		o.url.Scheme = "https"
	}
}

// SendTransport sets the transport for the HTTP client.
func SendTransport(transport http.RoundTripper) SendOption {
	return func(o *sendOptions) { o.transport = transport }
}

// SendContext sets the context attached to the outgoing request.
func SendContext(ctx context.Context) SendOption {
	return func(o *sendOptions) { o.ctx = ctx }
}

// Send sends an HTTP request. May return NetworkError or StatusError (see above).
func Send(method, rawurl string, options ...SendOption) (*http.Response, error) { u, err := url.Parse(rawurl) if err != nil { return nil, fmt.Errorf("parse url: %s", err) } opts := &sendOptions{ body: nil, timeout: 60 * time.Second, acceptedCodes: map[int]bool{http.StatusOK: true}, headers: map[string]string{}, retry: retryOptions{backoff: &backoff.StopBackOff{}}, transport: nil, // Use HTTP default. ctx: context.Background(), url: u, httpFallbackDisabled: true, } for _, o := range options { o(opts) } req, err := newRequest(method, opts) if err != nil { return nil, err } client := &http.Client{ Timeout: opts.timeout, CheckRedirect: opts.redirect, Transport: opts.transport, } var resp *http.Response for { resp, err = client.Do(req) // Retry without tls. During migration there would be a time when the // component receiving the tls request does not serve https response. // TODO (@evelynl): disable retry after tls migration. if err != nil && req.URL.Scheme == "https" && !opts.httpFallbackDisabled { originalErr := err resp, err = fallbackToHTTP(client, method, opts) if err != nil { // Sometimes the request fails for a reason unrelated to https. // To keep this reason visible, we always include the original // error. err = fmt.Errorf( "failed to fallback https to http, original https error: %s,\n"+ "fallback http error: %s", originalErr, err) } } if err != nil || (isRetryable(resp.StatusCode) && !opts.acceptedCodes[resp.StatusCode]) || (opts.retry.extraCodes[resp.StatusCode]) { d := opts.retry.backoff.NextBackOff() if d == backoff.Stop { break // Backoff timed out. } time.Sleep(d) continue } break } if err != nil { return nil, NetworkError{err} } if !opts.acceptedCodes[resp.StatusCode] { return nil, NewStatusError(resp) } return resp, nil } // Get sends a GET http request. func Get(url string, options ...SendOption) (*http.Response, error) { return Send("GET", url, options...) } // Head sends a HEAD http request. 
func Head(url string, options ...SendOption) (*http.Response, error) {
	return Send(http.MethodHead, url, options...)
}

// Post sends a POST http request.
func Post(url string, options ...SendOption) (*http.Response, error) {
	return Send(http.MethodPost, url, options...)
}

// Put sends a PUT http request.
func Put(url string, options ...SendOption) (*http.Response, error) {
	return Send(http.MethodPut, url, options...)
}

// Patch sends a PATCH http request.
func Patch(url string, options ...SendOption) (*http.Response, error) {
	return Send(http.MethodPatch, url, options...)
}

// Delete sends a DELETE http request.
func Delete(url string, options ...SendOption) (*http.Response, error) {
	return Send(http.MethodDelete, url, options...)
}

// PollAccepted issues GET requests against url until it gets a response that
// is not a 202 "accepted" StatusError. b is reset first and decides how long
// to wait between polls. Any other error is returned immediately. An error is
// also returned once b gives up.
func PollAccepted(
	url string, b backoff.BackOff, options ...SendOption) (*http.Response, error) {

	b.Reset()
	for {
		resp, err := Get(url, options...)
		if err == nil {
			return resp, nil
		}
		if !IsAccepted(err) {
			return nil, err
		}
		wait := b.NextBackOff()
		if wait == backoff.Stop {
			return nil, errors.New("backoff timed out on 202 responses")
		}
		time.Sleep(wait)
	}
}

// GetQueryArg returns the query argument called name from r, or defaultVal
// when the argument is absent or empty.
func GetQueryArg(r *http.Request, name string, defaultVal string) string {
	if v := r.URL.Query().Get(name); v != "" {
		return v
	}
	return defaultVal
}

// ParseParam returns the path-unescaped value of the chi URL parameter called
// name. A missing parameter or an invalid escape sequence is reported as a
// 400 handler error.
func ParseParam(r *http.Request, name string) (string, error) {
	raw := chi.URLParam(r, name)
	if raw == "" {
		return "", handler.Errorf("param %s is required", name).Status(http.StatusBadRequest)
	}
	unescaped, err := url.PathUnescape(raw)
	if err != nil {
		return "", handler.Errorf("path unescape %s: %s", name, err).Status(http.StatusBadRequest)
	}
	return unescaped, nil
}

// ParseDigest parses a SHA256 digest from the named URL parameter.
func ParseDigest(r *http.Request, name string) (core.Digest, error) { raw, err := ParseParam(r, name) if err != nil { return core.Digest{}, err } d, err := core.ParseSHA256Digest(raw) if err != nil { return core.Digest{}, handler.Errorf("parse digest: %s", err).Status(http.StatusBadRequest) } return d, nil } func newRequest(method string, opts *sendOptions) (*http.Request, error) { req, err := http.NewRequest(method, opts.url.String(), opts.body) if err != nil { return nil, fmt.Errorf("new request: %s", err) } req = req.WithContext(opts.ctx) if opts.body == nil { req.ContentLength = 0 } for key, val := range opts.headers { req.Header.Set(key, val) } return req, nil } func fallbackToHTTP( client *http.Client, method string, opts *sendOptions) (*http.Response, error) { req, err := newRequest(method, opts) if err != nil { return nil, err } req.URL.Scheme = "http" return client.Do(req) } func min(a, b time.Duration) time.Duration { if a < b { return a } return b }