pkg/conn/conn.go (246 lines of code) (raw):
// Copyright 2018-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package conn
import (
"crypto/tls"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-xray-daemon/pkg/cfg"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
log "github.com/cihub/seelog"
"golang.org/x/net/http2"
)
// connAttr abstracts the AWS-facing operations GetAWSConfigSession relies on,
// presumably so that an alternative implementation can be supplied (e.g. in
// tests) — confirm against callers.
type connAttr interface {
	// getEC2Region returns the region reported by EC2 instance metadata,
	// reached through the given session.
	getEC2Region(s *session.Session) (string, error)

	// newAWSSession returns a session for the given region and role ARN.
	newAWSSession(roleArn string, region string) *session.Session
}
// Conn implements the connAttr interface using the AWS SDK.
type Conn struct{}

// getEC2Region builds an EC2 instance metadata client from s and returns the
// region it reports.
func (c *Conn) getEC2Region(s *session.Session) (string, error) {
	metadataClient := ec2metadata.New(s)
	return metadataClient.Region()
}
// Pieces used to assemble STS regional endpoint URLs as
// STSEndpointPrefix + region + <partition suffix>.
const (
	// STSEndpointPrefix is the scheme and service prefix shared by all STS endpoints.
	STSEndpointPrefix = "https://sts."
	// STSEndpointSuffix is the DNS suffix for the standard AWS and AWS GovCloud (US) partitions.
	STSEndpointSuffix = ".amazonaws.com"
	// STSAwsCnPartitionIDSuffix is the DNS suffix for the AWS China partition.
	STSAwsCnPartitionIDSuffix = ".amazonaws.com.cn"
)
// getNewHTTPClient returns a new HTTP client configured with the given
// per-host idle connection limit, overall request timeout in seconds, TLS
// verification setting and proxy address. When proxyAddress is empty, the
// proxy is taken from the HTTPS_PROXY environment variable (see
// getProxyAddress). HTTP/2 is enabled on the transport when possible.
func getNewHTTPClient(maxIdle int, requestTimeout int, noVerify bool, proxyAddress string) *http.Client {
	log.Debugf("Using proxy address: %v", proxyAddress)
	// Named tlsConfig so that the crypto/tls package is not shadowed.
	tlsConfig := &tls.Config{
		InsecureSkipVerify: noVerify,
	}
	proxyURL := getProxyURL(getProxyAddress(proxyAddress))
	transport := &http.Transport{
		MaxIdleConnsPerHost: maxIdle,
		TLSClientConfig:     tlsConfig,
		IdleConnTimeout:     90 * time.Second, // Should be longer than PutTelemetryRecords call frequency: 60 seconds
		Proxy:               http.ProxyURL(proxyURL),
	}
	// HTTP/2 is not enabled by default on a Transport that has a custom
	// TLSClientConfig, and we need one to support SSL to the data plane.
	// http2.ConfigureTransports enables HTTP/2 on transport and returns the
	// underlying http2.Transport so its timeouts can be tuned.
	h2transport, err := http2.ConfigureTransports(transport)
	if err != nil {
		log.Warnf("Failed to configure HTTP2 transport: %v", err)
	} else {
		// Adding timeout settings to the http2 transport to prevent bad tcp connection hanging the requests for too long
		// See: https://t.corp.amazon.com/P104567981
		// Doc: https://pkg.go.dev/golang.org/x/net/http2#Transport
		// - ReadIdleTimeout is the time before a ping is sent when no frame has been received from a connection
		// - PingTimeout is the time before the TCP connection being closed if a Ping response is not received
		// So in total, if a TCP connection goes bad, it would take the combined time before the TCP connection is closed
		h2transport.ReadIdleTimeout = 1 * time.Second
		h2transport.PingTimeout = 2 * time.Second
	}
	// Returned directly instead of through a local named "http", which
	// shadowed the net/http package.
	return &http.Client{
		Transport: transport,
		Timeout:   time.Second * time.Duration(requestTimeout),
	}
}
// getProxyAddress returns proxyAddress when it is non-empty and otherwise
// falls back to the HTTPS_PROXY environment variable. The result is "" when
// neither is set.
func getProxyAddress(proxyAddress string) string {
	if proxyAddress != "" {
		return proxyAddress
	}
	// os.Getenv already yields "" for an unset variable, so no extra branch
	// is needed for the "neither configured" case.
	return os.Getenv("HTTPS_PROXY")
}
// getProxyURL parses finalProxyAddress into a URL. An empty address yields nil;
// callers hand the result to http.ProxyURL, where a nil URL means no proxy.
// If the address cannot be parsed, the error is logged and the process exits
// with status 1.
func getProxyURL(finalProxyAddress string) *url.URL {
	if finalProxyAddress == "" {
		return nil
	}
	parsed, err := url.Parse(finalProxyAddress)
	if err != nil {
		log.Errorf("Bad proxy URL: %v", err)
		os.Exit(1)
	}
	return parsed
}
// getRegionFromECSMetadata returns the region embedded in the task ARN of the
// ECS container metadata file, or "" if the region cannot be determined.
//
// The file is only consulted when ECS_ENABLE_CONTAINER_METADATA is "true"
// (case-insensitive). Its path is read from ECS_CONTAINER_METADATA_FILE.
// Read, decode and parse failures are logged and result in "" rather than a
// panic.
func getRegionFromECSMetadata() string {
	if strings.ToLower(os.Getenv("ECS_ENABLE_CONTAINER_METADATA")) != "true" {
		return ""
	}
	metadataFilePath := os.Getenv("ECS_CONTAINER_METADATA_FILE")
	metadataFile, err := ioutil.ReadFile(metadataFilePath)
	if err != nil {
		log.Errorf("Unable to open ECS metadata file: %v\n", err)
		return ""
	}
	var dat map[string]interface{}
	if err := json.Unmarshal(metadataFile, &dat); err != nil {
		log.Errorf("Unable to read ECS metadata file contents: %v", err)
		return ""
	}
	// Checked assertion: a missing or non-string TaskARN (for example, in a
	// partially written metadata file) previously caused a panic here.
	taskARN, ok := dat["TaskARN"].(string)
	if !ok {
		log.Errorf("Unable to read TaskARN from ECS metadata file")
		return ""
	}
	// ARNs have the form arn:partition:service:region:account:resource, so the
	// region is the fourth colon-separated field. Guard the index so that a
	// malformed ARN cannot cause an out-of-range panic.
	arnParts := strings.Split(taskARN, ":")
	if len(arnParts) < 4 {
		log.Errorf("Unable to parse region from TaskARN %v in ECS metadata file", taskARN)
		return ""
	}
	region := arnParts[3]
	log.Debugf("Fetch region %v from ECS metadata file", region)
	return region
}
// GetAWSConfigSession returns AWS config and session instances.
//
// The region is resolved in this order of precedence:
//  1. the region argument (command line / config file);
//  2. the AWS_REGION environment variable;
//  3. unless noMetadata is set, the ECS container metadata file and then EC2
//     instance metadata;
//  4. when noMetadata is set, the region of the default shared-config session.
//
// The process exits with status 1 if no region can be determined. The returned
// config carries the daemon's HTTP client, endpoint override and retry settings.
// The session is created for roleArn via cn.newAWSSession.
func GetAWSConfigSession(cn connAttr, c *cfg.Config, roleArn string, region string, noMetadata bool) (*aws.Config, *session.Session) {
	// Named httpClient so that the net/http package is not shadowed.
	httpClient := getNewHTTPClient(cfg.ParameterConfigValue.Processor.MaxIdleConnPerHost, cfg.ParameterConfigValue.Processor.RequestTimeout, *c.NoVerifySSL, c.ProxyAddress)

	var awsRegion string
	regionEnv := os.Getenv("AWS_REGION")
	switch {
	case region != "":
		awsRegion = region
		log.Debugf("Fetch region %v from commandline/config file", awsRegion)
	case regionEnv != "":
		awsRegion = regionEnv
		log.Debugf("Fetch region %v from environment variables", awsRegion)
	case !noMetadata:
		awsRegion = getRegionFromECSMetadata()
		if awsRegion == "" {
			var err error
			awsRegion, err = cn.getEC2Region(getDefaultSession())
			if err != nil {
				log.Errorf("Unable to fetch region from EC2 metadata: %v\n", err)
			} else {
				log.Debugf("Fetch region %v from ec2 metadata", awsRegion)
			}
		}
	default:
		// aws.StringValue tolerates a nil Region pointer; the previous direct
		// dereference would panic in that case.
		awsRegion = aws.StringValue(getDefaultSession().Config.Region)
		log.Debugf("Fetched region %v from session config", awsRegion)
	}

	if awsRegion == "" {
		log.Errorf("Cannot fetch region variable from config file, environment variables, ecs metadata, or ec2 metadata. Use local-mode to use the local session region.")
		os.Exit(1)
	}

	s := cn.newAWSSession(roleArn, awsRegion)
	config := &aws.Config{
		Region:                 aws.String(awsRegion),
		DisableParamValidation: aws.Bool(true),
		MaxRetries:             aws.Int(2),
		Endpoint:               aws.String(c.Endpoint),
		HTTPClient:             httpClient,
	}
	return config, s
}
// ProxyServerTransport configures HTTP transport for TCP Proxy Server.
//
// TLS verification, the upstream proxy (config.ProxyAddress, falling back to
// HTTPS_PROXY) and idle-connection pool limits are taken from config.
// Compression is disabled so that requests are not modified after signing.
func ProxyServerTransport(config *cfg.Config) *http.Transport {
	// Named tlsConfig so that the crypto/tls package is not shadowed.
	tlsConfig := &tls.Config{
		InsecureSkipVerify: *config.NoVerifySSL,
	}
	proxyURL := getProxyURL(getProxyAddress(config.ProxyAddress))
	// config.ProxyServer.IdleConnTimeout is expressed in seconds.
	idleConnTimeout := time.Duration(config.ProxyServer.IdleConnTimeout) * time.Second
	return &http.Transport{
		MaxIdleConns:        config.ProxyServer.MaxIdleConns,
		MaxIdleConnsPerHost: config.ProxyServer.MaxIdleConnsPerHost,
		IdleConnTimeout:     idleConnTimeout,
		Proxy:               http.ProxyURL(proxyURL),
		TLSClientConfig:     tlsConfig,
		// If not disabled the transport will add a gzip encoding header
		// to requests with no `accept-encoding` header value. The header
		// is added after we sign the request which invalidates the
		// signature.
		DisableCompression: true,
	}
}
// newAWSSession returns the default shared-config session when roleArn is
// empty. Otherwise it returns a new session whose credentials are the STS
// credentials for roleArn obtained via getSTSCreds for region. If that session
// cannot be created, the error is logged and the process exits with status 1.
func (c *Conn) newAWSSession(roleArn string, region string) *session.Session {
	if roleArn == "" {
		return getDefaultSession()
	}
	sess, err := session.NewSession(&aws.Config{
		Credentials: getSTSCreds(region, roleArn),
	})
	if err != nil {
		log.Errorf("Error in creating session object : %v\n.", err)
		os.Exit(1)
	}
	return sess
}
// getSTSCreds gets STS credentials from regional endpoint. ErrCodeRegionDisabledException is received if the
// STS regional endpoint is disabled. In this case STS credentials are fetched from STS primary regional endpoint
// in the respective AWS partition.
//
// If no primary endpoint is known for the region's partition, the regional
// credentials are returned rather than nil. Nil credentials would make the
// session silently fall back to the default credential chain instead of
// assuming roleArn. Other fetch errors are logged and the regional credentials
// are returned, so the SDK can try again when the credentials are next used.
func getSTSCreds(region string, roleArn string) *credentials.Credentials {
	t := getDefaultSession()
	stsCred := getSTSCredsFromRegionEndpoint(t, region, roleArn)
	// Make explicit call to fetch credentials, so that a disabled regional
	// endpoint is detected up front.
	_, err := stsCred.Get()
	if err == nil {
		return stsCred
	}
	aerr, ok := err.(awserr.Error)
	if !ok || aerr.Code() != sts.ErrCodeRegionDisabledException {
		log.Warnf("Unable to fetch STS credentials for provided RoleARN: %v", err)
		return stsCred
	}
	log.Errorf("Region : %v - %v", region, aerr.Error())
	log.Info("Credentials for provided RoleARN will be fetched from STS primary region endpoint instead of regional endpoint.")
	if primaryCred := getSTSCredsFromPrimaryRegionEndpoint(t, roleArn, region); primaryCred != nil {
		return primaryCred
	}
	log.Errorf("Unable to determine STS primary region endpoint for region %v; keeping regional endpoint credentials.", region)
	return stsCred
}
// getSTSCredsFromRegionEndpoint fetches STS credentials for provided roleARN from regional endpoint.
// AWS STS recommends that you provide both the Region and endpoint when you make calls to a Regional endpoint.
// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
func getSTSCredsFromRegionEndpoint(sess *session.Session, region string, roleArn string) *credentials.Credentials {
regionalEndpoint := getSTSRegionalEndpoint(region)
// if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG)
// for other opt-in regions, region value will create STS regional endpoint.
// This will be only in the case, if provided region is not present in aws_regions.go
c := &aws.Config{Region: aws.String(region), Endpoint: ®ionalEndpoint}
st := sts.New(sess, c)
log.Infof("STS Endpoint : %v", st.Endpoint)
return stscreds.NewCredentialsWithClient(st, roleArn)
}
// getSTSCredsFromPrimaryRegionEndpoint fetches STS credentials for roleArn from
// the primary STS region of region's partition: us-east-1 for the standard AWS
// partition, cn-north-1 for AWS China, and us-gov-west-1 for AWS GovCloud (US).
// It returns nil when the partition is none of these.
func getSTSCredsFromPrimaryRegionEndpoint(t *session.Session, roleArn string, region string) *credentials.Credentials {
	var primaryRegion string
	switch getPartition(region) {
	case endpoints.AwsPartitionID:
		primaryRegion = endpoints.UsEast1RegionID
	case endpoints.AwsCnPartitionID:
		primaryRegion = endpoints.CnNorth1RegionID
	case endpoints.AwsUsGovPartitionID:
		primaryRegion = endpoints.UsGovWest1RegionID
	default:
		return nil
	}
	return getSTSCredsFromRegionEndpoint(t, primaryRegion, roleArn)
}
// getSTSRegionalEndpoint returns the STS regional endpoint URL for region r.
// It returns "" when r does not belong to the standard AWS, AWS GovCloud (US)
// or AWS China partition.
func getSTSRegionalEndpoint(r string) string {
	switch getPartition(r) {
	case endpoints.AwsPartitionID, endpoints.AwsUsGovPartitionID:
		return STSEndpointPrefix + r + STSEndpointSuffix
	case endpoints.AwsCnPartitionID:
		return STSEndpointPrefix + r + STSAwsCnPartitionIDSuffix
	default:
		return ""
	}
}
// getDefaultSession creates a session with shared config loading enabled
// (session.SharedConfigEnable). If the session cannot be created, the error
// is logged and the process exits with status 1.
func getDefaultSession() *session.Session {
	opts := session.Options{SharedConfigState: session.SharedConfigEnable}
	sess, err := session.NewSessionWithOptions(opts)
	if err != nil {
		log.Errorf("Error in creating session object : %v\n.", err)
		os.Exit(1)
	}
	return sess
}
// getPartition returns the ID of the AWS partition (from the SDK's default
// partitions) that region belongs to, or "" when none matches.
func getPartition(region string) string {
	partition, found := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region)
	if !found {
		// PartitionForRegion returns a zero Partition when nothing matches,
		// and its ID() is ""; make that result explicit here.
		return ""
	}
	return partition.ID()
}