sharedlibraries/rest/rest.go (89 lines of code) (raw):

/*
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package rest provides common functions for consuming REST APIs.
package rest

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"google.golang.org/api/googleapi"
	"golang.org/x/oauth2/google"
	"golang.org/x/oauth2"
	"github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/log"
)

type (
	// defaultTokenGetter abstracts getting the default oauth2 token source.
	// It takes a context and a variadic list of strings (GetResponse passes a
	// single OAuth2 scope) and returns the token source to mint tokens from.
	defaultTokenGetter func(context.Context, ...string) (oauth2.TokenSource, error)

	// httpClient is the subset of *http.Client used by Rest; it exists so the
	// client can be replaced in tests.
	httpClient interface {
		Do(req *http.Request) (*http.Response, error)
	}

	// Rest is a struct for making REST API calls.
	Rest struct {
		// HTTPClient abstracts the http client for testing purposes.
		HTTPClient httpClient
		// TokenGetter abstracts getting default oauth2 access token for testing purposes.
		TokenGetter defaultTokenGetter
	}

	// errorResponse is the response for REST API calls to cover generic api errors.
	// The error details are decoded from the top-level "error" JSON field.
	errorResponse struct {
		Err googleapi.Error `json:"error"`
	}
)

var (
	// defaultNewClient builds an *http.Client with the given overall request
	// timeout and transport, returned behind the httpClient interface.
	defaultNewClient = func(timeout time.Duration, trans *http.Transport) httpClient {
		return &http.Client{Timeout: timeout, Transport: trans}
	}

	// defaultTransport builds a fresh *http.Transport on every call. Proxy
	// settings are taken from the environment, HTTP/2 is attempted, connections
	// are capped at 100 per host and 100 idle overall, idle connections and TLS
	// handshakes time out after 10 seconds, and the Expect: 100-continue wait
	// is 1 second.
	defaultTransport = func() *http.Transport {
		return &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			MaxConnsPerHost:       100,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       10 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
	}
)

// NewRest initializes the Rest with a new http client.
func (r *Rest) NewRest() { r.HTTPClient = defaultNewClient(10*time.Minute, defaultTransport()) r.TokenGetter = google.DefaultTokenSource } // token fetches a token with default or workload identity federation credentials. func token(ctx context.Context, tokenGetter defaultTokenGetter) (*oauth2.Token, error) { tokenScope := "https://www.googleapis.com/auth/cloud-platform" tokenSource, err := tokenGetter(ctx, tokenScope) if err != nil { return nil, err } return tokenSource.Token() } // GetResponse creates a new request with given method, url and data and returns the response. func (r *Rest) GetResponse(ctx context.Context, method string, baseURL string, data []byte) ([]byte, error) { log.CtxLogger(ctx).Debugw("GetResponse", "method", method, "baseURL", baseURL, "data", string(data)) req, err := http.NewRequest(method, baseURL, bytes.NewBuffer(data)) if err != nil { return nil, fmt.Errorf("failed to create request, err: %w", err) } token, err := token(ctx, r.TokenGetter) if err != nil { return nil, fmt.Errorf("failed to get token, err: %w", err) } req.Header.Add("Authorization", "Bearer "+token.AccessToken) req.Header.Add("Content-Type", "application/json") token.SetAuthHeader(req) resp, err := r.HTTPClient.Do(req) defer googleapi.CloseBody(resp) if err != nil { return nil, err } bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body, err: %w", err) } if resp.StatusCode != http.StatusOK { var googleapiErr errorResponse if err = json.Unmarshal([]byte(bodyBytes), &googleapiErr); err != nil { return nil, fmt.Errorf("failed to unmarshal googleapi error, err: %w", err) } log.CtxLogger(ctx).Errorw("getresponse error", "error", googleapiErr) if googleapiErr.Err.Code != http.StatusOK { return nil, fmt.Errorf("%s", googleapiErr.Err.Message) } } return bodyBytes, nil }