internal/apiclient/token.go (267 lines of code) (raw):

// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package apiclient import ( "context" "crypto/x509" "encoding/json" "encoding/pem" "errors" "fmt" "internal/clilog" "io" "net/http" "net/url" "os" "reflect" "strconv" "strings" "time" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "golang.org/x/oauth2/google" ) type serviceAccount struct { Type string `json:"type,omitempty"` ProjectID string `json:"project_id,omitempty"` PrivateKeyID string `json:"private_key_id,omitempty"` PrivateKey string `json:"private_key,omitempty"` ClientEmail string `json:"client_email,omitempty"` ClientID string `json:"client_id,omitempty"` AuthURI string `json:"auth_uri,omitempty"` TokenURI string `json:"token_uri,omitempty"` AuthProviderCertURL string `json:"auth_provider_x509_cert_url,omitempty"` ClientCertURL string `json:"client_x509_cert_url,omitempty"` } var account = serviceAccount{} const tokenUri = "https://www.googleapis.com/oauth2/v4/token" func getPrivateKey(privateKey string) (interface{}, error) { pemPrivateKey := fmt.Sprintf("%v", privateKey) block, _ := pem.Decode([]byte(pemPrivateKey)) privKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { clilog.Error.Println("error parsing Private Key: ", err) return nil, err } return privKey, nil } func generateJWT(privateKey string) (string, error) { const scope = "https://www.googleapis.com/auth/cloud-platform" privKey, err := getPrivateKey(privateKey) if err != nil { return "", err } now := time.Now() // Google OAuth takes aud as a string, not array // ref: https://github.com/lestrrat-go/jwx/releases/tag/v2.0.7 jwt.Settings(jwt.WithFlattenAudience(true)) token := jwt.New() token.Options().IsEnabled(jwt.FlattenAudience) _ = token.Set("aud", tokenUri) _ = token.Set(jwt.IssuerKey, getServiceAccountProperty("ClientEmail")) _ = token.Set("scope", scope) _ = token.Set(jwt.IssuedAtKey, now.Unix()) _ = token.Set(jwt.ExpirationKey, now.Unix()) payload, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, privKey)) if err != nil { clilog.Error.Println("error parsing Private Key: ", err) return "", err } clilog.Debug.Println("jwt token : ", string(payload)) return string(payload), nil } // generateAccessToken generates a Google OAuth access token from a service account func generateAccessToken(privateKey string) (string, error) { const grantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" var respBody []byte // oAuthAccessToken is a structure to hold OAuth response type oAuthAccessToken struct { AccessToken string `json:"access_token,omitempty"` ExpiresIn int `json:"expires_in,omitempty"` TokenType string `json:"token_type,omitempty"` } token, err := generateJWT(privateKey) if err != nil { return "", nil } form := url.Values{} form.Add("grant_type", grantType) form.Add("assertion", token) client := &http.Client{} req, err := http.NewRequest("POST", tokenUri, strings.NewReader(form.Encode())) if err != nil { clilog.Error.Println("error in client: ", err) return "", err } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) resp, err := client.Do(req) if err != nil { clilog.Error.Println("failed to generate oauth token: ", err) return "", err } if resp != nil { defer resp.Body.Close() } if resp == nil { clilog.Error.Println("error in response: Response was null") return "", errors.New("error in response: Response was null") } respBody, err = io.ReadAll(resp.Body) clilog.Debug.Printf("Response: %s\n", string(respBody)) if err != nil { clilog.Error.Printf("error in response: %v\n", err) return "", err } else if resp.StatusCode > 399 { clilog.Error.Printf("status code %d, error in response: %s\n", resp.StatusCode, string(respBody)) return "", fmt.Errorf("status code %d, error in response: %s", resp.StatusCode, string(respBody)) } accessToken := oAuthAccessToken{} if err = json.Unmarshal(respBody, &accessToken); err != nil { return "", err } clilog.Debug.Println("access token : ", accessToken) SetIntegrationToken(accessToken.AccessToken) _ = writeToken(accessToken.AccessToken) return accessToken.AccessToken, nil } func readServiceAccount(serviceAccountPath string) error { content, err := os.ReadFile(serviceAccountPath) if err != nil { return err } err = json.Unmarshal(content, &account) if err != nil { return err } return nil } func getServiceAccountProperty(key string) (value string) { r := reflect.ValueOf(&account) field := reflect.Indirect(r).FieldByName(key) return field.String() } func checkAccessToken() bool { if TokenCheckEnabled() { clilog.Debug.Println("skipping token validity") return true } const tokenInfo = "https://oauth2.googleapis.com/tokeninfo" u, _ := url.Parse(tokenInfo) q := u.Query() q.Set("access_token", GetIntegrationToken()) u.RawQuery = q.Encode() client := &http.Client{} clilog.Debug.Println("Connecting to : ", u.String()) req, err := http.NewRequest("GET", u.String(), nil) if err != nil { clilog.Error.Println("error in client:", err) return false } resp, err := client.Do(req) if err != nil { clilog.Error.Println("error connecting to token endpoint: ", err) return false } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { clilog.Error.Println("token info error: ", err) return false } else if resp.StatusCode != 200 { clilog.Error.Println("token expired: ", string(body)) return false } clilog.Debug.Println("Response: ", string(body)) clilog.Debug.Println("Reusing the cached token: ", GetIntegrationToken()) return true } // SetAccessToken read from cache or if not found or expired will generate a new one func SetAccessToken() error { if GetIntegrationToken() == "" && GetServiceAccount() == "" { SetIntegrationToken(getToken()) // read from configuration if GetIntegrationToken() == "" { return fmt.Errorf("either token or service account must be provided") } if checkAccessToken() { // check if the token is still valid return nil } return fmt.Errorf("token expired: request a new access token or pass the service account") } if GetIntegrationToken() != "" { // a token was passed, cache it if checkAccessToken() { _ = writeToken(GetIntegrationToken()) return nil } } else { err := readServiceAccount(GetServiceAccount()) if err != nil { // Handle errors reading the config file return fmt.Errorf("error reading config file: %s", err) } privateKey := getServiceAccountProperty("PrivateKey") if privateKey == "" { return fmt.Errorf("private key missing in the service account") } if getServiceAccountProperty("ClientEmail") == "" { return fmt.Errorf("client email missing in the service account") } _, err = generateAccessToken(privateKey) if err != nil { return fmt.Errorf("fatal error generating access token: %s", err) } return nil } return fmt.Errorf("token expired: request a new access token or pass the service account") } // GetDefaultAccessToken func GetDefaultAccessToken() (err error) { ctx := context.Background() tokenSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/cloud-platform") if err != nil { return err } token, err := tokenSource.Token() if err != nil { return err } SetIntegrationToken(token.AccessToken) return nil } // GetMetadataAccessToken func GetMetadataAccessToken() (err error) { var req *http.Request var tokenResponse map[string]interface{} metadataURL := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" client, err := getHttpClient() if err != nil { return err } if DryRun() { return nil } clilog.Debug.Println("Connecting to: ", metadataURL) req, err = http.NewRequest(http.MethodGet, metadataURL, nil) if err != nil { clilog.Error.Println("error in client: ", err) return err } req.Header.Set("Metadata-Flavor", "Google") resp, err := client.Do(req) if err != nil { clilog.Error.Println("error connecting: ", err) return err } if resp != nil { defer resp.Body.Close() } if resp == nil { clilog.Error.Println("error in response: Response was null") return fmt.Errorf("error in response: Response was null") } respBody, err := io.ReadAll(resp.Body) if err != nil { clilog.Error.Println("error in response: ", err) return err } else if resp.StatusCode > 399 { clilog.Debug.Printf("status code %d, error in response: %s\n", resp.StatusCode, string(respBody)) clilog.HTTPError.Println(string(respBody)) return errors.New(getErrorMessage(resp.StatusCode)) } err = json.Unmarshal(respBody, &tokenResponse) if err != nil { return err } SetIntegrationToken(tokenResponse["access_token"].(string)) return nil }