pkg/common/httpclient/retry_client.go (100 lines of code) (raw):

// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package httpclient import ( "bytes" "context" "fmt" "io" "log/slog" "net/http" "time" "github.com/GoogleCloudPlatform/khi/pkg/common/token" ) type RetryHttpClient struct { Client HTTPClient[*http.Response] MinWaitSeconds int MaxWaitSeconds int MaxRetryCount int RetriableHttpCodes []int RetriableWithRefreshTokenHttpCodes []int currentWaitSeconds int timeUnit time.Duration // For testing purpose to make test faster tokenRefresher token.TokenRefresher } func NewRetryHttpClient(baseClient HTTPClient[*http.Response], minWaitSeconds int, maxWaitSeconds int, maxRetryCount int, retriableHttpCodes []int, retriableWithRefreshTokenHttpCodes []int, tokenRefresher token.TokenRefresher) *RetryHttpClient { return &RetryHttpClient{ Client: baseClient, MinWaitSeconds: minWaitSeconds, MaxWaitSeconds: maxWaitSeconds, MaxRetryCount: maxRetryCount, RetriableHttpCodes: retriableHttpCodes, RetriableWithRefreshTokenHttpCodes: retriableWithRefreshTokenHttpCodes, currentWaitSeconds: minWaitSeconds, timeUnit: time.Second, tokenRefresher: tokenRefresher, } } // DoWithContext implements HttpClient. func (r *RetryHttpClient) DoWithContext(ctx context.Context, originalRequest *http.Request) (*http.Response, error) { // Clone request body into array to create another reader of Body on retry. var clonedRequest []byte if originalRequest.Body != nil { var err error clonedRequest, err = io.ReadAll(originalRequest.Body) if err != nil { return nil, err } } statusCodes := []int{} for i := 0; i < r.MaxRetryCount; i++ { request, err := http.NewRequestWithContext(ctx, originalRequest.Method, originalRequest.URL.String(), bytes.NewBuffer(clonedRequest)) if err != nil { return nil, err } request.Header = originalRequest.Header.Clone() response, err := r.Client.DoWithContext(ctx, request) if err != nil { return nil, err } if response.StatusCode < 400 { r.currentWaitSeconds = r.MinWaitSeconds // Treat this response is ok not to retry return response, nil } if !r.isRetriable(response.StatusCode) { body := []byte{} if response.Body != nil { body, _ = io.ReadAll(response.Body) } return response, fmt.Errorf("unretriable error returned(%d):%s\nBODY:%s", response.StatusCode, response.Status, string(body)) } else { statusCodes = append(statusCodes, response.StatusCode) if r.isRetriableWithRefreshingToken(response.StatusCode) { slog.DebugContext(ctx, fmt.Sprintf("Previous request to %s got %d response. Attempting retrying with refreshing the token.", request.RequestURI, response.StatusCode)) r.tokenRefresher.Refresh(ctx) r.currentWaitSeconds = r.MinWaitSeconds } else { r.currentWaitSeconds *= 2 if r.currentWaitSeconds > r.MaxWaitSeconds { r.currentWaitSeconds = r.MaxWaitSeconds } slog.DebugContext(ctx, fmt.Sprintf("Previous request to %s got %d response. Next retry after %d seconds", request.RequestURI, response.StatusCode, r.currentWaitSeconds)) time.Sleep(r.timeUnit * time.Duration(r.currentWaitSeconds)) } } } return nil, fmt.Errorf("maximum retry count exceeded %d\nStatus codes:%v", r.MaxRetryCount, statusCodes) } func (r *RetryHttpClient) isRetriable(code int) bool { for _, retryCode := range r.RetriableHttpCodes { if code == retryCode { return true } } return r.isRetriableWithRefreshingToken(code) } func (r *RetryHttpClient) isRetriableWithRefreshingToken(code int) bool { for _, retryCode := range r.RetriableWithRefreshTokenHttpCodes { if code == retryCode { return true } } return false } var _ (HTTPClient[*http.Response]) = (*RetryHttpClient)(nil)