pkg/awsutils/awssession/session.go (102 lines of code) (raw):
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
package awssession
import (
"context"
"fmt"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/smithy-go"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
smithymiddleware "github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"strconv"
"time"
"github.com/aws/amazon-vpc-cni-k8s/pkg/utils/logger"
"github.com/aws/amazon-vpc-cni-k8s/utils"
)
// Environment variable names and retry settings used when building the AWS
// configuration for service clients.
const (
	// envVpcCniVersion names the environment variable whose value is embedded
	// in the User-Agent component added to AWS API requests.
	envVpcCniVersion = "VPC_CNI_VERSION"

	// httpTimeoutEnv names the environment variable holding the HTTP client
	// timeout, expressed in whole seconds.
	httpTimeoutEnv = "HTTP_TIMEOUT"

	// maxRetries is the maximum number of attempts configured for AWS API calls.
	maxRetries = 10
)
var (
	// httpTimeoutValue starts out as a 10-second duration; getHTTPTimeout
	// returns it when HTTP_TIMEOUT does not supply a usable value.
	httpTimeoutValue time.Duration = 10 * time.Second

	// log is the package-wide logger.
	log = logger.Get()
)
// getHTTPTimeout returns the HTTP client timeout to use for AWS API calls.
//
// The timeout is read from the HTTP_TIMEOUT environment variable as a whole
// number of seconds. If the variable is unset, not an integer, or below 10,
// a warning is logged and httpTimeoutValue (the 10-second default) is
// returned instead.
//
// The package-level default is only read, never written. The result therefore
// depends solely on the current environment, not on earlier calls.
func getHTTPTimeout() time.Duration {
	if raw := os.Getenv(httpTimeoutEnv); raw != "" {
		seconds, err := strconv.Atoi(raw)
		if err == nil && seconds >= 10 {
			log.Debugf("Using HTTP_TIMEOUT %v", seconds)
			// Return the override directly. Storing it in httpTimeoutValue
			// would leak it into later calls that should fall back to the default.
			return time.Duration(seconds) * time.Second
		}
	}
	log.Warn("HTTP_TIMEOUT env is not set or set to less than 10 seconds, defaulting to httpTimeout to 10sec")
	return httpTimeoutValue
}
// New returns an aws.Config to be used when constructing AWS service clients.
//
// The returned config uses:
//   - an HTTP client whose timeout comes from getHTTPTimeout (HTTP_TIMEOUT env);
//   - the SDK standard retryer, limited to maxRetries attempts per operation;
//   - the application User-Agent option injectUserAgent;
//   - when AWS_EC2_ENDPOINT is set, that URL as the endpoint for the EC2
//     service only; every other service uses the SDK's default resolution.
//
// Any error from config.LoadDefaultConfig is wrapped and returned together
// with a zero aws.Config.
func New(ctx context.Context) (aws.Config, error) {
	httpClient := awshttp.NewBuildableClient().WithTimeout(getHTTPTimeout())

	optFns := []func(*config.LoadOptions) error{
		config.WithHTTPClient(httpClient),
		// The attempt limit must be set on the retryer itself. When a custom
		// retryer is supplied through config.WithRetryer, the loader does not
		// apply config.WithRetryMaxAttempts. That option alone would leave the
		// standard retryer at the SDK's default attempt count.
		config.WithRetryer(func() aws.Retryer {
			return retry.NewStandard(func(o *retry.StandardOptions) {
				o.MaxAttempts = maxRetries
			})
		}),
		injectUserAgent,
	}

	if endpoint := os.Getenv("AWS_EC2_ENDPOINT"); endpoint != "" {
		optFns = append(optFns, config.WithEndpointResolver(aws.EndpointResolverFunc(
			func(service, region string) (aws.Endpoint, error) {
				if service == ec2.ServiceID {
					return aws.Endpoint{URL: endpoint}, nil
				}
				// EndpointNotFoundError tells the SDK to fall back to its
				// default endpoint resolution for this service.
				return aws.Endpoint{}, &aws.EndpointNotFoundError{}
			})))
	}

	cfg, err := config.LoadDefaultConfig(ctx, optFns...)
	if err != nil {
		return aws.Config{}, fmt.Errorf("failed to load AWS config: %w", err)
	}
	return cfg, nil
}
// injectUserAgent is a config.LoadOptions option function. It registers
// addUserAgentMiddleware at the end of the Build step of every API operation
// stack.
//
// The middleware carries the string "amazon-vpc-cni-k8s/version/<version>".
// <version> is the value utils.GetEnv returns for VPC_CNI_VERSION with an empty
// fallback, presumably "" when the variable is unset.
//
// It always returns nil. Registration errors surface from stack.Build.Add when
// an operation stack is assembled.
func injectUserAgent(loadOptions *config.LoadOptions) error {
	cniVersion := utils.GetEnv(envVpcCniVersion, "")
	ua := fmt.Sprintf("amazon-vpc-cni-k8s/version/%s", cniVersion)

	// Each operation stack receives its own middleware value.
	addToStack := func(stack *smithymiddleware.Stack) error {
		mw := &addUserAgentMiddleware{userAgent: ua}
		return stack.Build.Add(mw, smithymiddleware.After)
	}

	loadOptions.APIOptions = append(loadOptions.APIOptions, addToStack)
	return nil
}
// addUserAgentMiddleware adds an application-specific component to the
// User-Agent header of outgoing HTTP requests.
//
// injectUserAgent registers it in the Build step of each operation stack. In
// that step only HandleBuild is invoked, so HandleBuild must carry the header
// logic. HandleFinalize applies the same logic if the middleware is registered
// in the Finalize step instead.
//
// NOTE(review): the component is appended to whatever User-Agent value exists
// when this middleware runs. Confirm for the SDK version in use that the SDK's
// own User-Agent middleware does not replace the header afterwards.
type addUserAgentMiddleware struct {
	// userAgent is the text appended to the User-Agent header.
	userAgent string
}

// HandleBuild appends m.userAgent to the request's User-Agent header and then
// calls the next Build handler. If the request is not a *smithyhttp.Request,
// it returns a *smithy.SerializationError without calling next.
func (m *addUserAgentMiddleware) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) (out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) {
	// This is the hook that runs when the middleware is added to the Build
	// step. A plain pass-through here would mean the header is never set.
	if err = m.appendUserAgent(in.Request); err != nil {
		return out, metadata, err
	}
	return next.HandleBuild(ctx, in)
}

// ID returns the identifier of this middleware within a smithy middleware stack.
func (m *addUserAgentMiddleware) ID() string {
	return "AddUserAgent"
}

// HandleFinalize appends m.userAgent to the request's User-Agent header and
// then calls the next Finalize handler. If the request is not a
// *smithyhttp.Request, it returns a *smithy.SerializationError without
// calling next.
func (m *addUserAgentMiddleware) HandleFinalize(ctx context.Context, in smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler) (
	out smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, err error) {
	if err = m.appendUserAgent(in.Request); err != nil {
		return out, metadata, err
	}
	return next.HandleFinalize(ctx, in)
}

// appendUserAgent adds m.userAgent to the User-Agent header of request.
//
// If the header already has a value, m.userAgent is appended after a single
// space; otherwise m.userAgent becomes the whole value. request must be a
// *smithyhttp.Request. For any other type a *smithy.SerializationError is
// returned and nothing is modified.
func (m *addUserAgentMiddleware) appendUserAgent(request interface{}) error {
	req, ok := request.(*smithyhttp.Request)
	if !ok {
		return &smithy.SerializationError{Err: fmt.Errorf("unknown request type %T", request)}
	}
	userAgent := m.userAgent
	if existing := req.Header.Get("User-Agent"); existing != "" {
		userAgent = existing + " " + m.userAgent
	}
	req.Header.Set("User-Agent", userAgent)
	return nil
}