aws_signing_helper/serve.go (284 lines of code) (raw):
package aws_signing_helper
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws/arn"
)
// Defaults for the local credentials server.
const (
	// DefaultPort is the default TCP port for the credentials server.
	DefaultPort = 9911
	// DefaultHopLimit is a default hop-limit value (not referenced in this
	// file; presumably an IP TTL / hop limit for server sockets — confirm).
	DefaultHopLimit = 64
	// LocalHostAddress is the loopback address Serve binds its listener to.
	LocalHostAddress = "127.0.0.1"
)

// RefreshTime is the margin used by the credentials handler when deciding
// whether cached credentials must be regenerated before being served.
var RefreshTime = 5 * time.Minute
type (
	// RefreshableCred is the credentials document returned to clients by the
	// credentials handler. It is JSON-encoded field by field in declaration
	// order, so the field order below is part of the wire format.
	// Code and Type are filled with REFRESHABLE_CRED_CODE and
	// REFRESHABLE_CRED_TYPE whenever credentials are (re)generated.
	RefreshableCred struct {
		AccessKeyId, SecretAccessKey, Token string
		Code, Type                          string
		Expiration, LastUpdated             time.Time
	}

	// Endpoint bundles the state of one running credentials server.
	// Serve overwrites PortNum with the port actually bound by the listener,
	// and TmpCred holds the credentials handed out to clients.
	Endpoint struct {
		PortNum int
		Server  *http.Server
		TmpCred RefreshableCred
	}

	// SessionToken holds a session token's expiry. It is not referenced
	// elsewhere in this file.
	SessionToken struct {
		Expiration time.Time
	}
)
// Request paths served by the local credentials endpoint, mirroring the EC2
// instance metadata service layout.
const (
	TOKEN_RESOURCE_PATH                = "/latest/api/token"
	SECURITY_CREDENTIALS_RESOURCE_PATH = "/latest/meta-data/iam/security-credentials/"
)

// HTTP header names read or written by the handlers.
const (
	EC2_METADATA_TOKEN_HEADER     = "x-aws-ec2-metadata-token"
	EC2_METADATA_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds"
	X_FORWARDED_FOR_HEADER        = "X-Forwarded-For"
)

// Fixed values used when issuing tokens and credentials.
const (
	// DEFAULT_TOKEN_TTL_SECONDS is used when a token request omits the TTL header.
	DEFAULT_TOKEN_TTL_SECONDS = "21600"
	REFRESHABLE_CRED_TYPE     = "AWS-HMAC"
	REFRESHABLE_CRED_CODE     = "Success"
)
// MAX_TOKENS caps how many session tokens are tracked at once. InsertToken
// evicts the soonest-expiring token(s) to stay within this limit.
const MAX_TOKENS = 256

// mutex guards tokenMap, which is accessed from concurrent HTTP handlers and
// from the expired-token cleanup goroutine started by Serve.
var mutex sync.Mutex

// tokenMap maps each issued session token to its expiration time.
var tokenMap = make(map[string]time.Time)

// GenerateToken returns a random token of exactly length characters, taken
// from the standard base64 encoding of 128 bytes read from crypto/rand.
// It returns an error if length is outside [0, 128) or if the random source
// fails.
func GenerateToken(length int) (string, error) {
	if length < 0 || length >= 128 {
		return "", errors.New("invalid token length")
	}
	randomBytes := make([]byte, 128)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(randomBytes)[:length], nil
}

// InsertToken records token as valid until expirationTime.
//
// When a new token would push the store past MAX_TOKENS entries, the token
// that expires soonest is evicted first. Re-inserting a token that is
// already stored just updates its expiration and evicts nothing.
// It always returns nil; the error result is kept for API compatibility.
func InsertToken(token string, expirationTime time.Time) error {
	mutex.Lock()
	defer mutex.Unlock()

	if _, exists := tokenMap[token]; !exists {
		// Loop rather than evict once, so the cap is restored even if the map
		// ever grew beyond it.
		for len(tokenMap) >= MAX_TOKENS {
			// Track the minimum with a flag instead of a "far future" sentinel:
			// time.Unix(1<<63-1, 0) overflows internally and compares as an
			// extremely early instant, which made eviction a silent no-op.
			var earliestToken string
			var earliestExpiration time.Time
			found := false
			for key, value := range tokenMap {
				if !found || value.Before(earliestExpiration) {
					earliestToken, earliestExpiration, found = key, value, true
				}
			}
			if !found {
				break
			}
			delete(tokenMap, earliestToken)
			// NOTE(review): this logs the token value itself, which is a bearer
			// secret; consider logging a truncated or hashed form instead.
			log.Printf("evicting earliest expiring token: %s", earliestToken)
		}
	}
	tokenMap[token] = expirationTime
	return nil
}
// CheckValidToken verifies that r carries a session token (in the
// EC2_METADATA_TOKEN_HEADER header) that is present in tokenMap and not yet
// expired. On failure it writes a 401 response with a short explanation to w
// and returns an error with the same text; on success it writes nothing and
// returns nil.
func CheckValidToken(w http.ResponseWriter, r *http.Request) error {
	// reject writes the 401 response and builds the matching error.
	reject := func(msg string) error {
		w.WriteHeader(http.StatusUnauthorized)
		io.WriteString(w, msg)
		return errors.New(msg)
	}

	token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
	if token == "" {
		return reject("no token provided")
	}

	mutex.Lock()
	expiration, known := tokenMap[token]
	mutex.Unlock()

	// Unknown and expired tokens are reported identically.
	if !known || time.Now().After(expiration) {
		return reject("invalid token provided")
	}
	return nil
}
// FindTokenTTLSeconds returns, as a base-10 string, the whole number of
// seconds (truncated toward zero) until the request's session token expires.
// It returns an error if the request has no token header or the token is not
// in tokenMap. The token's expiry is not checked here, so a stored but
// already-expired token yields a zero or negative value.
func FindTokenTTLSeconds(r *http.Request) (string, error) {
	token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
	if token == "" {
		return "", errors.New("no token provided")
	}

	mutex.Lock()
	expiration, found := tokenMap[token]
	mutex.Unlock()

	if !found {
		return "", errors.New("invalid token provided")
	}
	remainingSeconds := int64(time.Until(expiration).Seconds())
	return strconv.FormatInt(remainingSeconds, 10), nil
}
// AllIssuesHandlers builds the HTTP handlers of the local IMDSv2-style
// credentials endpoint and returns them in this order:
//
//   - putTokenHandler: PUT requests for a new session token (TTL taken from
//     EC2_METADATA_TOKEN_TTL_HEADER, default DEFAULT_TOKEN_TTL_SECONDS,
//     accepted range 1..21600 seconds).
//   - getRoleNameHandler: GET requests that return roleName.
//   - getCredentialsHandler: GET requests that return *cred as JSON,
//     regenerating it via GenerateCredentials when fewer than 2*RefreshTime
//     remain before it expires.
//
// All GET handlers require a valid session token (see CheckValidToken) and
// echo the token's remaining TTL in EC2_METADATA_TOKEN_TTL_HEADER.
// The handlers serialise every access to *cred with a mutex of their own, so
// callers must not modify *cred while the handlers are serving requests.
func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) {
	// Handles PUT requests to /latest/api/token/
	putTokenHandler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "PUT" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		// Check for the presence of the X-Forwarded-For header
		xForwardedForHeader := r.Header.Get(X_FORWARDED_FOR_HEADER) // canonicalized headers are used (casing doesn't matter)
		if xForwardedForHeader != "" {
			w.WriteHeader(http.StatusBadRequest)
			io.WriteString(w, "unable to process requests with X-Forwarded-For header")
			return
		}
		// Obtain the token TTL
		tokenTTLStr := r.Header.Get(EC2_METADATA_TOKEN_TTL_HEADER)
		if tokenTTLStr == "" {
			tokenTTLStr = DEFAULT_TOKEN_TTL_SECONDS
		}
		tokenTTL, err := strconv.Atoi(tokenTTLStr)
		if err != nil || tokenTTL < 1 || tokenTTL > 21600 {
			w.WriteHeader(http.StatusBadRequest)
			io.WriteString(w, "invalid token TTL")
			return
		}
		// Generate token and insert it into map
		token, err := GenerateToken(100)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			io.WriteString(w, "unable to generate token")
			return
		}
		expirationTime := time.Now().Add(time.Second * time.Duration(tokenTTL))
		InsertToken(token, expirationTime)
		w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTLStr)
		io.WriteString(w, token) // nosemgrep
	}

	// Handles requests to /latest/meta-data/iam/security-credentials/
	getRoleNameHandler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		err := CheckValidToken(w, r)
		if err != nil {
			return
		}
		tokenTTL, err := FindTokenTTLSeconds(r)
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
		io.WriteString(w, roleName) // nosemgrep
	}

	// net/http runs handlers concurrently, so every read and refresh of *cred
	// goes through credMutex to avoid data races.
	var credMutex sync.Mutex

	// credentialsSnapshot returns a copy of *cred, regenerating it first when
	// fewer than 2*RefreshTime remain before it expires. If regeneration
	// fails, cached credentials that have not yet expired are still returned;
	// otherwise the generation error is returned and *cred is left untouched.
	credentialsSnapshot := func() (RefreshableCred, error) {
		credMutex.Lock()
		defer credMutex.Unlock()

		nextRefreshTime := cred.Expiration.Add(-RefreshTime)
		if time.Until(nextRefreshTime) >= RefreshTime {
			if Debug {
				log.Println("Using previously obtained credentials")
			}
			return *cred, nil
		}

		if Debug {
			log.Println("Generating credentials")
		}
		credentialProcessOutput, gcErr := GenerateCredentials(opts, signer, signatureAlgorithm)
		if gcErr != nil {
			log.Printf("Error generating credentials: %s\n", gcErr)
			// Never overwrite the cache with the output of a failed call.
			if time.Now().Before(cred.Expiration) {
				return *cred, nil
			}
			return RefreshableCred{}, gcErr
		}
		cred.AccessKeyId = credentialProcessOutput.AccessKeyId
		cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
		cred.Token = credentialProcessOutput.SessionToken
		// A malformed Expiration leaves the zero time, which forces
		// regeneration on the next request.
		cred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
		cred.Code = REFRESHABLE_CRED_CODE
		cred.LastUpdated = time.Now()
		cred.Type = REFRESHABLE_CRED_TYPE
		return *cred, nil
	}

	// Handles GET requests to /latest/meta-data/iam/security-credentials/<ROLE_NAME>
	getCredentialsHandler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		err := CheckValidToken(w, r)
		if err != nil {
			log.Printf("Token validation received error: %s\n", err)
			return
		}
		// Resolve the TTL before writing anything: net/http silently drops
		// headers set after the body has started.
		tokenTTL, err := FindTokenTTLSeconds(r)
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		current, err := credentialsSnapshot()
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			io.WriteString(w, "failed to generate credentials")
			return
		}
		// Encode fully before writing so an encoding failure can still be
		// reported with a 500 status.
		body, err := json.Marshal(current)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			io.WriteString(w, "failed to encode credentials")
			return
		}
		w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
		// Terminate with a newline, matching json.Encoder's output format.
		w.Write(append(body, '\n'))
	}

	return putTokenHandler, getRoleNameHandler, getCredentialsHandler
}
// Serve runs the local IMDSv2-style credentials endpoint for the role in
// credentialsOptions.RoleArn, listening on LocalHostAddress:port.
//
// It fetches an initial set of credentials, registers the token, role-name
// and credentials handlers, and starts a goroutine that purges expired
// session tokens every 5 seconds. It then blocks serving requests. The port
// actually bound is read back from the listener and logged together with the
// AWS_EC2_METADATA_SERVICE_ENDPOINT value to export. On any fatal error it
// logs the cause and terminates the process with exit status 1.
func Serve(port int, credentialsOptions CredentialsOpts) {
	var refreshableCred = RefreshableCred{}

	roleArn, err := arn.Parse(credentialsOptions.RoleArn)
	if err != nil {
		log.Println("invalid role ARN")
		os.Exit(1)
	}

	signer, signatureAlgorithm, err := GetSigner(&credentialsOptions)
	if err != nil {
		log.Println(err)
		os.Exit(1)
	}
	defer signer.Close()
	// os.Exit skips deferred calls, so fatal paths from here on must close
	// the signer explicitly before exiting.
	fatal := func(v ...interface{}) {
		log.Println(v...)
		signer.Close()
		os.Exit(1)
	}

	// Best effort: if the initial fetch fails, start anyway with empty
	// credentials. Their zero Expiration makes the first credentials request
	// regenerate them.
	credentialProcessOutput, err := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm)
	if err != nil {
		log.Printf("Error generating initial credentials: %s\n", err)
	} else {
		refreshableCred.AccessKeyId = credentialProcessOutput.AccessKeyId
		refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
		refreshableCred.Token = credentialProcessOutput.SessionToken
		refreshableCred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
		refreshableCred.Code = REFRESHABLE_CRED_CODE
		refreshableCred.LastUpdated = time.Now()
		refreshableCred.Type = REFRESHABLE_CRED_TYPE
	}

	endpoint := &Endpoint{PortNum: port, TmpCred: refreshableCred}
	endpoint.Server = &http.Server{}
	roleResourceParts := strings.Split(roleArn.Resource, "/")
	roleName := roleResourceParts[len(roleResourceParts)-1] // Find role name without path
	putTokenHandler, getRoleNameHandler, getCredentialsHandler := AllIssuesHandlers(&endpoint.TmpCred, roleName, &credentialsOptions, signer, signatureAlgorithm)
	http.HandleFunc(TOKEN_RESOURCE_PATH, putTokenHandler)
	http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH, getRoleNameHandler)
	http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH+roleName, getCredentialsHandler)

	// Background goroutine that removes expired tokens every 5 seconds; it
	// runs for the lifetime of the process.
	ticker := time.NewTicker(5 * time.Second)
	go func() {
		for range ticker.C {
			curTime := time.Now()
			mutex.Lock()
			for key, value := range tokenMap {
				if curTime.After(value) {
					delete(tokenMap, key)
					// NOTE(review): this logs the token value itself, which is
					// a bearer secret; consider logging a redacted form.
					log.Printf("removed expired token: %s", key)
				}
			}
			mutex.Unlock()
		}
	}()

	// Start the credentials endpoint
	listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", LocalHostAddress, endpoint.PortNum))
	if err != nil {
		fatal("failed to create listener:", err)
	}
	listener = NewListenerWithTTL(listener, credentialsOptions.ServerTTL)
	endpoint.PortNum = listener.Addr().(*net.TCPAddr).Port
	log.Println("Local server started on port:", endpoint.PortNum)
	log.Println("Make it available to the sdk by running:")
	log.Printf("export AWS_EC2_METADATA_SERVICE_ENDPOINT=http://%s:%d/", LocalHostAddress, endpoint.PortNum)
	if err := endpoint.Server.Serve(listener); err != nil {
		fatal("Httpserver: Serve() error:", err)
	}
}