credentials_provider.go (183 lines of code) (raw):
package sls
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"go.uber.org/atomic"
"github.com/go-kit/kit/log/level"
)
type CredentialsProvider interface {
/**
* GetCredentials is called everytime credentials are needed, the CredentialsProvider
* should cache credentials to avoid fetch credentials too frequently.
*
* @note GetCredentials must be thread-safe to avoid data race.
*/
GetCredentials() (Credentials, error)
}
// Create a static credential provider with AccessKeyID/AccessKeySecret/SecurityToken.
//
// Param accessKeyID and accessKeySecret must not be an empty string.
func NewStaticCredentialsProvider(accessKeyID, accessKeySecret, securityToken string) *StaticCredentialsProvider {
return &StaticCredentialsProvider{
Cred: Credentials{
AccessKeyID: accessKeyID,
AccessKeySecret: accessKeySecret,
SecurityToken: securityToken,
},
}
}
/**
* Create a new credentials provider, which uses a UpdateTokenFunction to fetch credentials.
* @param updateFunc The function to fetch credentials. If the error returned is not nil, use last saved credentials.
* This updateFunc will mayed called concurrently, and whill be called before the expiration of the last saved credentials.
*/
func NewUpdateFuncProviderAdapter(updateFunc UpdateTokenFunction) *UpdateFuncProviderAdapter {
fetcher := fetcherWithRetry(updateFuncFetcher(updateFunc), UPDATE_FUNC_RETRY_TIMES)
return &UpdateFuncProviderAdapter{
fetcher: fetcher,
fetchAhead: defaultFetchAhead,
}
}
/**
* Create a credentials provider that uses ecs ram role, only works on ecs.
*
*/
func NewEcsRamRoleCredentialsProvider(roleName string) CredentialsProvider {
fetcherFunc := fetcherWithRetry(newEcsRamRoleFetcher(ECS_RAM_ROLE_URL_PREFIX, roleName, nil), ECS_RAM_ROLE_RETRY_TIMES)
updateFunc := func() (accessKeyID, accessKeySecret, securityToken string, expireTime time.Time, err error) {
cred, err := fetcherFunc()
if err != nil {
return "", "", "", time.Now(), err
}
return cred.AccessKeyID, cred.AccessKeySecret, cred.SecurityToken, cred.Expiration, nil
}
return NewUpdateFuncProviderAdapter(updateFunc)
}
/**
* A static credetials provider that always returns the same long-lived credentials.
* For back compatible.
*/
type StaticCredentialsProvider struct {
Cred Credentials
}
func (p *StaticCredentialsProvider) GetCredentials() (Credentials, error) {
return p.Cred, nil
}
type CredentialsFetcher = func() (*tempCredentials, error)
const UPDATE_FUNC_RETRY_TIMES = 3
// Adapter for porting UpdateTokenFunc to a CredentialsProvider.
type UpdateFuncProviderAdapter struct {
cred atomic.Value // type *tempCredentials
fetcher CredentialsFetcher
fetchAhead time.Duration
}
func updateFuncFetcher(updateFunc UpdateTokenFunction) CredentialsFetcher {
return func() (*tempCredentials, error) {
id, secret, token, expireTime, err := updateFunc()
if err != nil {
return nil, fmt.Errorf("updateTokenFunc fetch credentials failed: %w", err)
}
res := newTempCredentials(id, secret, token, expireTime, time.Now())
if !res.isValid() {
return nil, fmt.Errorf("updateTokenFunc result not valid, expirationTime:%s",
expireTime.Format(time.RFC3339))
}
return res, nil
}
}
// If credentials expires or will be exipred soon, fetch a new credentials and return it.
//
// Otherwise returns the credentials fetched last time.
//
// Retry at most maxRetryTimes if failed to fetch.
func (adp *UpdateFuncProviderAdapter) GetCredentials() (Credentials, error) {
if !adp.shouldRefresh() {
res := adp.cred.Load().(*tempCredentials)
return res.Credentials, nil
}
level.Debug(Logger).Log("reason", "updateTokenFunc start to fetch new credentials")
res, err := adp.fetcher() // res.lastUpdatedTime is not valid, do not use its
if err == nil {
copy := *res
adp.cred.Store(©)
if res.Expiration.Before(time.Now()) {
level.Warn(Logger).Log("reason", "updateTokenFunc got a new credentials with expiration time before now",
"nextExpiration", res.Expiration,
)
} else {
level.Debug(Logger).Log("reason", "updateTokenFunc fetch new credentials succeed",
"nextExpiration", res.Expiration,
)
}
return res.Credentials, nil
}
lastCred := adp.cred.Load()
// use last saved credentials when failed to fetch new credentials
if lastCred != nil {
return lastCred.(*tempCredentials).Credentials, nil
}
return Credentials{}, fmt.Errorf("updateTokenFunc fail to fetch credentials, err:%w", err)
}
var defaultFetchAhead = time.Minute * 2
func (adp *UpdateFuncProviderAdapter) shouldRefresh() bool {
v := adp.cred.Load()
if v == nil {
return true
}
t := v.(*tempCredentials)
if t.isExpired() {
return true
}
return time.Now().Add(adp.fetchAhead).After(t.Expiration)
}
const ECS_RAM_ROLE_URL_PREFIX = "http://100.100.100.200/latest/meta-data/ram/security-credentials/"
const ECS_RAM_ROLE_RETRY_TIMES = 3
func newEcsRamRoleFetcher(urlPrefix, ramRole string, customClient *http.Client) CredentialsFetcher {
return func() (*tempCredentials, error) {
url := urlPrefix + ramRole
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("fail to build http request: %w", err)
}
var client *http.Client
if customClient != nil {
client = customClient
} else {
client = &http.Client{}
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("fail to do http request: %w", err)
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("fail to read http resp body: %w", err)
}
fetchResp := ecsRamRoleHttpResp{}
// 2. unmarshal
err = json.Unmarshal(data, &fetchResp)
if err != nil {
return nil, fmt.Errorf("fail to unmarshal json: %w, body: %s", err, string(data))
}
// 3. check json param
if !fetchResp.isValid() {
return nil, fmt.Errorf("invalid fetch result, body: %s", string(data))
}
return newTempCredentials(
fetchResp.AccessKeyID,
fetchResp.AccessKeySecret,
fetchResp.SecurityToken,
fetchResp.Expiration,
fetchResp.LastUpdated), nil
}
}
// Response struct for http response of ecs ram role fetch request
type ecsRamRoleHttpResp struct {
Code string `json:"Code"`
AccessKeyID string `json:"AccessKeyId"`
AccessKeySecret string `json:"AccessKeySecret"`
SecurityToken string `json:"SecurityToken"`
Expiration time.Time `json:"Expiration"`
LastUpdated time.Time `json:"LastUpdated"`
}
func (r *ecsRamRoleHttpResp) isValid() bool {
return strings.ToLower(r.Code) == "success" && r.AccessKeyID != "" &&
r.AccessKeySecret != "" && !r.Expiration.IsZero() && !r.LastUpdated.IsZero()
}
// Wraps a CredentialsFetcher with retry.
//
// @param retryTimes If <= 0, no retry will be performed.
func fetcherWithRetry(fetcher CredentialsFetcher, retryTimes int) CredentialsFetcher {
return func() (*tempCredentials, error) {
var errs []error
for i := 0; i <= retryTimes; i++ {
cred, err := fetcher()
if err == nil {
return cred, nil
}
errs = append(errs, err)
}
return nil, fmt.Errorf("exceed max retry times, last error: %w",
joinErrors(errs...))
}
}
// Replace this with errors.Join when go version >= 1.20
func joinErrors(errs ...error) error {
if errs == nil {
return nil
}
errStrs := make([]string, 0, len(errs))
for _, e := range errs {
errStrs = append(errStrs, e.Error())
}
return fmt.Errorf("[%s]", strings.Join(errStrs, ", "))
}