pkg/awsutils/imds.go (564 lines of code) (raw):
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
package awsutils
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"github.com/aws/smithy-go"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/pkg/errors"
)
// EC2MetadataIface is the narrow slice of the instance-metadata client API
// that this package relies on. FakeIMDS in this package implements it for tests.
type EC2MetadataIface interface {
	// GetMetadata fetches the raw metadata document stored at params.Path.
	GetMetadata(
		ctx context.Context,
		params *imds.GetMetadataInput,
		optFns ...func(*imds.Options),
	) (*imds.GetMetadataOutput, error)
}
// TypedIMDS layers typed accessors (string, int, net.IP, net.IPNet, slices of
// those) over an EC2MetadataIface, which returns untyped metadata bodies.
type TypedIMDS struct {
	// Embedded so that GetMetadata is promoted onto TypedIMDS itself.
	EC2MetadataIface
}
// imdsRequestError records which instance-metadata key a request targeted,
// together with the underlying failure. It implements error, Unwrap, and the
// smithy.APIError interface so SDK-style error inspection works on it.
type imdsRequestError struct {
	requestKey string            // metadata path that was being fetched
	err        error             // underlying cause
	code       string            // fallback for ErrorCode when err carries no smithy.APIError
	fault      smithy.ErrorFault // fallback for ErrorFault when err carries no smithy.APIError
}

// Compile-time checks that imdsRequestError satisfies the interfaces it is
// meant to implement.
var (
	_ error           = &imdsRequestError{}
	_ smithy.APIError = (*imdsRequestError)(nil)
)
// newIMDSRequestError wraps err with the metadata key that was being fetched.
// The result carries the default code "IMDSRequestError" and fault
// smithy.FaultUnknown; ErrorCode/ErrorFault prefer values from a wrapped
// smithy.APIError when one is present.
func newIMDSRequestError(requestKey string, err error) *imdsRequestError {
	wrapped := new(imdsRequestError)
	wrapped.requestKey = requestKey
	wrapped.err = err
	wrapped.code = "IMDSRequestError"    // default code
	wrapped.fault = smithy.FaultUnknown // default fault
	return wrapped
}
// Error names the metadata key that failed and appends the underlying cause.
func (e *imdsRequestError) Error() string {
	const layout = "failed to retrieve %s from instance metadata %v"
	return fmt.Sprintf(layout, e.requestKey, e.err)
}

// Unwrap exposes the underlying cause to errors.Is / errors.As.
func (e *imdsRequestError) Unwrap() error { return e.err }
// The three methods below implement smithy.APIError.

// ErrorCode returns the code of the first smithy.APIError found in the wrapped
// chain, or this error's own default code when the chain contains none.
func (e *imdsRequestError) ErrorCode() string {
	var inner smithy.APIError
	if !errors.As(e.err, &inner) {
		return e.code
	}
	return inner.ErrorCode()
}

// ErrorMessage returns the same text as Error.
func (e *imdsRequestError) ErrorMessage() string { return e.Error() }

// ErrorFault returns the fault of the first smithy.APIError found in the
// wrapped chain, or this error's own default fault when the chain contains none.
func (e *imdsRequestError) ErrorFault() smithy.ErrorFault {
	var inner smithy.APIError
	if !errors.As(e.err, &inner) {
		return e.fault
	}
	return inner.ErrorFault()
}
// HTTPStatusCode returns the HTTP status code reported by the first error in
// the wrapped chain that exposes one. errors.As is used instead of a type
// assertion on e.err alone, so a status nested under another wrapper (e.g. a
// smithy.OperationError) is still found. This matches how ErrorCode and
// ErrorFault search the chain.
// Falls back to 200 when nothing in the chain reports a status.
// NOTE(review): 200 is an unusual default for an error value; kept for
// backward compatibility with existing callers.
func (e *imdsRequestError) HTTPStatusCode() int {
	var withStatus interface{ HTTPStatusCode() int }
	if errors.As(e.err, &withStatus) {
		return withStatus.HTTPStatusCode()
	}
	return 200
}

// RequestID returns the request ID reported by the first error in the wrapped
// chain that exposes one, or "" when none does.
func (e *imdsRequestError) RequestID() string {
	var withID interface{ RequestID() string }
	if errors.As(e.err, &withID) {
		return withID.RequestID()
	}
	return ""
}
// getList fetches key from IMDS and splits the body on whitespace, returning
// one element per field. Client errors are returned unchanged; an absent body
// or a read failure is wrapped in an imdsRequestError for key.
func (typedimds TypedIMDS) getList(ctx context.Context, key string) ([]string, error) {
	resp, fetchErr := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	switch {
	case fetchErr != nil:
		return nil, fetchErr
	case resp == nil || resp.Content == nil:
		return nil, newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, readErr := io.ReadAll(resp.Content)
	if readErr != nil {
		return nil, newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", readErr))
	}
	return strings.Fields(string(body)), nil
}
// GetAZ returns the Availability Zone in which the instance launched, with
// surrounding whitespace removed.
func (typedimds TypedIMDS) GetAZ(ctx context.Context) (string, error) {
	const path = "placement/availability-zone"
	resp, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: path})
	if err != nil {
		return "", err
	}
	if resp == nil || resp.Content == nil {
		return "", newIMDSRequestError(path, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, err := io.ReadAll(resp.Content)
	if err != nil {
		return "", newIMDSRequestError(path, fmt.Errorf("failed to read content: %w", err))
	}
	zone := strings.TrimSpace(string(body))
	return zone, nil
}
// GetInstanceType returns this instance's type, with surrounding whitespace
// removed.
func (typedimds TypedIMDS) GetInstanceType(ctx context.Context) (string, error) {
	const path = "instance-type"
	resp, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: path})
	if err != nil {
		return "", err
	}
	if resp == nil || resp.Content == nil {
		return "", newIMDSRequestError(path, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, err := io.ReadAll(resp.Content)
	if err != nil {
		return "", newIMDSRequestError(path, fmt.Errorf("failed to read content: %w", err))
	}
	instanceType := strings.TrimSpace(string(body))
	return instanceType, nil
}
// GetLocalIPv4 returns the instance's private (primary) IPv4 address, parsed
// from the "local-ipv4" metadata key.
func (typedimds TypedIMDS) GetLocalIPv4(ctx context.Context) (net.IP, error) {
	const path = "local-ipv4"
	addr, err := typedimds.getIP(ctx, path)
	return addr, err
}
// GetInstanceID returns this instance's ID, with surrounding whitespace
// removed.
func (typedimds TypedIMDS) GetInstanceID(ctx context.Context) (string, error) {
	const path = "instance-id"
	resp, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: path})
	if err != nil {
		return "", err
	}
	if resp == nil || resp.Content == nil {
		return "", newIMDSRequestError(path, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, err := io.ReadAll(resp.Content)
	if err != nil {
		return "", newIMDSRequestError(path, fmt.Errorf("failed to read content: %w", err))
	}
	instanceID := strings.TrimSpace(string(body))
	return instanceID, nil
}
// GetMAC returns the MAC address of the first/primary network interface.
//
// The body is whitespace-trimmed, consistent with GetAZ, GetInstanceType and
// GetInstanceID. The MAC is later spliced into metadata paths such as
// "network/interfaces/macs/<mac>/...", so stray whitespace would corrupt
// those lookups.
func (typedimds TypedIMDS) GetMAC(ctx context.Context) (string, error) {
	const path = "mac"
	output, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: path})
	if err != nil {
		return "", err
	}
	if output == nil || output.Content == nil {
		return "", newIMDSRequestError(path, fmt.Errorf("empty response"))
	}
	defer output.Content.Close()

	body, err := io.ReadAll(output.Content)
	if err != nil {
		return "", newIMDSRequestError(path, fmt.Errorf("failed to read content: %w", err))
	}
	return strings.TrimSpace(string(body)), nil
}
// GetMACs returns the MAC addresses of the network interfaces attached to the
// instance, with any trailing "/" removed from each entry.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged and returned wrapped with the requested metadata key. Other
// errors are returned unchanged.
func (typedimds TypedIMDS) GetMACs(ctx context.Context) ([]string, error) {
	const key = "network/interfaces/macs"
	list, err := typedimds.getList(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			log.Warnf("%v", err)
			// The first argument is the request key. Passing err.Error() here
			// would repeat the error text twice in the resulting message.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	// Directory-style listings end each entry with "/".
	for i, item := range list {
		list[i] = strings.TrimSuffix(item, "/")
	}
	return list, nil
}
// GetMACImdsFields returns the metadata field names IMDS lists under the given
// MAC, with any trailing "/" removed from each entry.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged. When an imdsRequestError is present, its underlying cause is
// returned. Otherwise err is returned unchanged, so a failure is never
// reported as a nil error.
func (typedimds TypedIMDS) GetMACImdsFields(ctx context.Context, mac string) ([]string, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s", mac)
	list, err := typedimds.getList(ctx, key)
	if err != nil {
		var imdsErr *imdsRequestError
		oe := new(smithy.OperationError)
		isIMDSErr := errors.As(err, &imdsErr)
		if isIMDSErr || errors.As(err, &oe) {
			log.Warnf("%v", err)
		}
		// Unwrap only when our own wrapper was actually found. If only an
		// OperationError matched, imdsErr carries no cause, and unwrapping it
		// would turn the failure into (nil, nil).
		if isIMDSErr && imdsErr.err != nil {
			return nil, imdsErr.err
		}
		return nil, err
	}
	// Directory-style listings end each entry with "/".
	for i, item := range list {
		list[i] = strings.TrimSuffix(item, "/")
	}
	return list, nil
}
// GetInterfaceID returns the ID of the network interface identified by mac.
//
// The body is whitespace-trimmed, consistent with the other scalar string
// getters in this file (GetAZ, GetInstanceType, GetInstanceID).
func (typedimds TypedIMDS) GetInterfaceID(ctx context.Context, mac string) (string, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/interface-id", mac)
	output, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	if err != nil {
		return "", err
	}
	if output == nil || output.Content == nil {
		return "", newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer output.Content.Close()

	body, err := io.ReadAll(output.Content)
	if err != nil {
		return "", newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", err))
	}
	return strings.TrimSpace(string(body)), nil
}
// getInt fetches key from IMDS and parses the whitespace-trimmed body with
// strconv.Atoi. Client and parse errors are returned unwrapped. An absent
// body or a read failure is wrapped in an imdsRequestError. The int result is
// always 0 when an error is returned.
func (typedimds TypedIMDS) getInt(ctx context.Context, key string) (int, error) {
	resp, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	if err != nil {
		return 0, err
	}
	if resp == nil || resp.Content == nil {
		return 0, newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, err := io.ReadAll(resp.Content)
	if err != nil {
		return 0, newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", err))
	}
	value, parseErr := strconv.Atoi(strings.TrimSpace(string(body)))
	// Atoi can return a clamped nonzero value alongside a range error; force 0.
	if parseErr != nil {
		return 0, parseErr
	}
	return value, nil
}
// GetDeviceNumber returns the device number IMDS reports for the interface
// identified by mac. The primary interface is 0.
func (typedimds TypedIMDS) GetDeviceNumber(ctx context.Context, mac string) (int, error) {
	return typedimds.getInt(ctx, "network/interfaces/macs/"+mac+"/device-number")
}
// GetSubnetID returns the ID of the subnet in which the interface identified by
// mac resides, with surrounding whitespace trimmed.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged and returned wrapped with the requested metadata key. Other
// client errors are returned unchanged. An absent or unreadable body is
// reported as an error rather than as ("", nil).
func (typedimds TypedIMDS) GetSubnetID(ctx context.Context, mac string) (string, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/subnet-id", mac)
	output, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	if err != nil {
		// Release a body if the client returned one alongside the error.
		if output != nil && output.Content != nil {
			_ = output.Content.Close()
		}
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			log.Warnf("Warning: %v", err)
			return "", newIMDSRequestError(key, err)
		}
		return "", err
	}
	if output == nil || output.Content == nil {
		return "", newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer output.Content.Close()

	body, err := io.ReadAll(output.Content)
	if err != nil {
		return "", newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", err))
	}
	return strings.TrimSpace(string(body)), nil
}
// GetVpcID returns the VPC ID that IMDS reports for the interface identified by
// mac, with surrounding whitespace trimmed.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged and returned wrapped with the requested metadata key. Other
// client errors are returned unchanged. An absent or unreadable body is
// reported as an error rather than as ("", nil).
func (typedimds TypedIMDS) GetVpcID(ctx context.Context, mac string) (string, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/vpc-id", mac)
	output, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	if err != nil {
		// Release a body if the client returned one alongside the error.
		if output != nil && output.Content != nil {
			_ = output.Content.Close()
		}
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			log.Warnf("Warning: %v", err)
			return "", newIMDSRequestError(key, err)
		}
		return "", err
	}
	if output == nil || output.Content == nil {
		return "", newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer output.Content.Close()

	body, err := io.ReadAll(output.Content)
	if err != nil {
		return "", newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", err))
	}
	return strings.TrimSpace(string(body)), nil
}
// GetSecurityGroupIDs returns the IDs of the security groups to which the
// network interface identified by mac belongs.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged and returned wrapped with the requested metadata key. Other
// errors are returned unchanged.
func (typedimds TypedIMDS) GetSecurityGroupIDs(ctx context.Context, mac string) ([]string, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/security-group-ids", mac)
	sgs, err := typedimds.getList(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return sgs, nil
}
// getIP fetches key from IMDS and parses the whitespace-trimmed body as an IP
// address. An unparsable body yields a *net.ParseError whose Text is the
// untrimmed body.
func (typedimds TypedIMDS) getIP(ctx context.Context, key string) (net.IP, error) {
	resp, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	if err != nil {
		return nil, err
	}
	if resp == nil || resp.Content == nil {
		return nil, newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, err := io.ReadAll(resp.Content)
	if err != nil {
		return nil, newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", err))
	}
	raw := string(body)
	if addr := net.ParseIP(strings.TrimSpace(raw)); addr != nil {
		return addr, nil
	}
	return nil, &net.ParseError{Type: "IP address", Text: raw}
}
// getIPs fetches a whitespace-separated list from key and parses every entry as
// an IP address. It stops at the first unparsable entry and returns a
// *net.ParseError for it. An empty listing yields an empty, non-nil slice.
func (typedimds TypedIMDS) getIPs(ctx context.Context, key string) ([]net.IP, error) {
	entries, err := typedimds.getList(ctx, key)
	if err != nil {
		return nil, err
	}
	parsed := make([]net.IP, 0, len(entries))
	for _, entry := range entries {
		addr := net.ParseIP(entry)
		if addr == nil {
			return nil, &net.ParseError{Type: "IP address", Text: entry}
		}
		parsed = append(parsed, addr)
	}
	return parsed, nil
}
// getCIDR fetches key from IMDS and parses the whitespace-trimmed body as CIDR
// notation. The returned IPNet keeps the address exactly as written (not
// masked down to the network address) together with the prefix mask.
func (typedimds TypedIMDS) getCIDR(ctx context.Context, key string) (net.IPNet, error) {
	var none net.IPNet
	resp, err := typedimds.GetMetadata(ctx, &imds.GetMetadataInput{Path: key})
	if err != nil {
		return none, err
	}
	if resp == nil || resp.Content == nil {
		return none, newIMDSRequestError(key, fmt.Errorf("empty response"))
	}
	defer resp.Content.Close()

	body, err := io.ReadAll(resp.Content)
	if err != nil {
		return none, newIMDSRequestError(key, fmt.Errorf("failed to read content: %w", err))
	}
	addr, block, err := net.ParseCIDR(strings.TrimSpace(string(body)))
	if err != nil {
		return none, err
	}
	// net.ParseCIDR splits the address and the (masked) network; recombine the
	// original address with the mask.
	return net.IPNet{IP: addr, Mask: block.Mask}, nil
}
// getCIDRs fetches a whitespace-separated list from key and parses every entry
// as CIDR notation, keeping each address as written plus its prefix mask. It
// stops at the first parse error. An empty listing yields an empty, non-nil
// slice.
func (typedimds TypedIMDS) getCIDRs(ctx context.Context, key string) ([]net.IPNet, error) {
	entries, err := typedimds.getList(ctx, key)
	if err != nil {
		return nil, err
	}
	blocks := make([]net.IPNet, 0, len(entries))
	for _, entry := range entries {
		addr, network, parseErr := net.ParseCIDR(entry)
		if parseErr != nil {
			return nil, parseErr
		}
		// Recombine the unmasked address with the network's mask.
		blocks = append(blocks, net.IPNet{IP: addr, Mask: network.Mask})
	}
	return blocks, nil
}
// GetLocalIPv4s returns the private IPv4 addresses associated with the
// interface identified by mac. The first returned address is the primary
// address.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged and returned wrapped with the requested metadata key. Other
// errors (e.g. *net.ParseError from a malformed entry) are returned unchanged.
func (typedimds TypedIMDS) GetLocalIPv4s(ctx context.Context, mac string) ([]net.IP, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/local-ipv4s", mac)
	ips, err := typedimds.getIPs(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return ips, nil
}
// GetIPv4Prefixes returns the IPv4 prefixes delegated to the interface
// identified by mac.
//
// A 404 from IMDS (no prefixes delegated) yields (nil, nil). Other IMDS
// failures (an imdsRequestError or smithy.OperationError in the chain) are
// logged and returned wrapped with the requested metadata key. Remaining
// errors (e.g. CIDR parse errors) are returned unchanged.
func (typedimds TypedIMDS) GetIPv4Prefixes(ctx context.Context, mac string) ([]net.IPNet, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/ipv4-prefix", mac)
	prefixes, err := typedimds.getCIDRs(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			if IsNotFound(err) {
				return nil, nil
			}
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return prefixes, nil
}
// GetIPv6Prefixes returns the IPv6 prefixes delegated to the interface
// identified by mac.
//
// A 404 from IMDS (no prefixes delegated) yields (nil, nil). Other IMDS
// failures (an imdsRequestError or smithy.OperationError in the chain) are
// logged and returned wrapped with the requested metadata key. Remaining
// errors (e.g. CIDR parse errors) are returned unchanged.
func (typedimds TypedIMDS) GetIPv6Prefixes(ctx context.Context, mac string) ([]net.IPNet, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/ipv6-prefix", mac)
	prefixes, err := typedimds.getCIDRs(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			if IsNotFound(err) {
				return nil, nil
			}
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return prefixes, nil
}
// GetIPv6s returns the IPv6 addresses associated with the interface identified
// by mac.
//
// A 404 from IMDS means the interface has no IPv6 addresses and yields
// (nil, nil). Other IMDS failures (an imdsRequestError or
// smithy.OperationError in the chain) are logged and returned wrapped with the
// requested metadata key. Remaining errors (e.g. *net.ParseError) are returned
// unchanged.
func (typedimds TypedIMDS) GetIPv6s(ctx context.Context, mac string) ([]net.IP, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/ipv6s", mac)
	ips, err := typedimds.getIPs(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			if IsNotFound(err) {
				// No IPv6. Not an error, just a disappointment :(
				return nil, nil
			}
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return ips, nil
}
// GetSubnetIPv4CIDRBlock returns the IPv4 CIDR block of the subnet in which the
// interface identified by mac resides.
func (typedimds TypedIMDS) GetSubnetIPv4CIDRBlock(ctx context.Context, mac string) (net.IPNet, error) {
	return typedimds.getCIDR(ctx, "network/interfaces/macs/"+mac+"/subnet-ipv4-cidr-block")
}
// GetVPCIPv4CIDRBlocks returns the IPv4 CIDR blocks of the VPC, as reported
// for the interface identified by mac.
//
// IMDS failures (an imdsRequestError or smithy.OperationError in the chain)
// are logged and returned wrapped with the requested metadata key. Other
// errors (e.g. CIDR parse errors) are returned unchanged.
func (typedimds TypedIMDS) GetVPCIPv4CIDRBlocks(ctx context.Context, mac string) ([]net.IPNet, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/vpc-ipv4-cidr-blocks", mac)
	cidrs, err := typedimds.getCIDRs(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return cidrs, nil
}
// GetVPCIPv6CIDRBlocks returns the IPv6 CIDR blocks of the VPC, as reported for
// the interface identified by mac.
//
// A 404 from IMDS means there is no IPv6 and yields (nil, nil). Other IMDS
// failures (an imdsRequestError or smithy.OperationError in the chain) are
// logged and returned wrapped with the requested metadata key. Remaining
// errors (e.g. a malformed CIDR) are returned to the caller rather than being
// silently dropped, matching GetIPv6Prefixes and GetVPCIPv4CIDRBlocks.
func (typedimds TypedIMDS) GetVPCIPv6CIDRBlocks(ctx context.Context, mac string) ([]net.IPNet, error) {
	key := fmt.Sprintf("network/interfaces/macs/%s/vpc-ipv6-cidr-blocks", mac)
	ipnets, err := typedimds.getCIDRs(ctx, key)
	if err != nil {
		imdsErr := new(imdsRequestError)
		oe := new(smithy.OperationError)
		if errors.As(err, &imdsErr) || errors.As(err, &oe) {
			if IsNotFound(err) {
				// No IPv6. Not an error, just a disappointment :(
				return nil, nil
			}
			log.Warnf("%v", err)
			// Use the metadata key, not err.Error(), as the request key so the
			// message does not repeat the error text.
			return nil, newIMDSRequestError(key, err)
		}
		return nil, err
	}
	return ipnets, nil
}
// GetSubnetIPv6CIDRBlocks returns the IPv6 CIDR block of the subnet in which
// the interface identified by mac resides.
func (typedimds TypedIMDS) GetSubnetIPv6CIDRBlocks(ctx context.Context, mac string) (net.IPNet, error) {
	return typedimds.getCIDR(ctx, "network/interfaces/macs/"+mac+"/subnet-ipv6-cidr-blocks")
}
// IsNotFound reports whether err represents a "not found" (HTTP 404) result.
//
// Checks run in this order, and the first kind found in err's chain decides
// the result:
//  1. *awshttp.ResponseError: its HTTP status code is compared to 404.
//  2. *smithy.OperationError: its message is searched for "StatusCode: 404".
//  3. smithy.APIError: an *imdsRequestError is unwrapped and re-checked
//     recursively; any other APIError matches when its code is "NotFound".
//
// A nil err, or one matching none of the above, is not a NotFound error.
func IsNotFound(err error) bool {
	if err == nil {
		return false
	}
	var (
		respErr *awshttp.ResponseError
		opErr   *smithy.OperationError
		apiErr  smithy.APIError
	)
	switch {
	case errors.As(err, &respErr):
		return respErr.Response.StatusCode == http.StatusNotFound
	case errors.As(err, &opErr):
		// No structured status on the operation error itself; fall back to the
		// rendered message.
		return strings.Contains(opErr.Error(), "StatusCode: 404")
	case errors.As(err, &apiErr):
		if wrapped, isIMDS := apiErr.(*imdsRequestError); isIMDS {
			return IsNotFound(wrapped.err)
		}
		return apiErr.ErrorCode() == "NotFound"
	default:
		return false
	}
}
// FakeIMDS is an in-memory, map-backed EC2MetadataIface for tests.
// A string value is served as the metadata body. An error value is returned
// as-is. Any other value type makes GetMetadata panic.
type FakeIMDS map[string]interface{}

// GetMetadata serves params.Path from the map, also trying params.Path+"/".
// A missing path yields an imdsRequestError wrapping a 404 "NotFound"
// CustomRequestFailure.
func (f FakeIMDS) GetMetadata(ctx context.Context, params *imds.GetMetadataInput, optFns ...func(*imds.Options)) (*imds.GetMetadataOutput, error) {
	path := params.Path
	value, found := f[path]
	if !found {
		// Metadata API treats foo/ as foo
		value, found = f[path+"/"]
	}
	if !found {
		return nil, newIMDSRequestError(path, &CustomRequestFailure{
			code:       "NotFound",
			message:    "not found",
			fault:      smithy.FaultUnknown,
			statusCode: http.StatusNotFound,
			requestID:  "dummy-reqid",
		})
	}
	if body, isString := value.(string); isString {
		return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(body))}, nil
	}
	if failure, isErr := value.(error); isErr {
		return nil, failure
	}
	panic(fmt.Sprintf("unknown test metadata value type %T for %s", value, path))
}
// CustomRequestFailure is a configurable error value carrying an error code,
// message, fault, HTTP status code, and request ID. FakeIMDS uses it to
// simulate a 404 "NotFound" response.
type CustomRequestFailure struct {
	code       string
	message    string
	fault      smithy.ErrorFault
	statusCode int
	requestID  string
}

// Error renders the failure as "<code>: <message>".
func (c *CustomRequestFailure) Error() string { return c.code + ": " + c.message }

// ErrorCode returns the configured error code.
func (c *CustomRequestFailure) ErrorCode() string { return c.code }

// ErrorMessage returns the configured message.
func (c *CustomRequestFailure) ErrorMessage() string { return c.message }

// ErrorFault returns the configured fault classification.
func (c *CustomRequestFailure) ErrorFault() smithy.ErrorFault { return c.fault }

// HTTPStatusCode returns the configured HTTP status code.
func (c *CustomRequestFailure) HTTPStatusCode() int { return c.statusCode }

// RequestID returns the configured request ID.
func (c *CustomRequestFailure) RequestID() string { return c.requestID }
// GetMetadataWithContext returns the string stored at path p, also trying
// p+"/". It is a convenience lookup on FakeIMDS and is not a method of
// EC2MetadataIface as declared in this file. A missing path yields an
// imdsRequestError wrapping a 404 "NotFound" CustomRequestFailure. A stored
// error is returned as-is. Any other stored value type causes a panic.
func (f FakeIMDS) GetMetadataWithContext(ctx context.Context, p string) (string, error) {
	value, found := f[p]
	if !found {
		// Metadata API treats foo/ as foo
		value, found = f[p+"/"]
	}
	if !found {
		return "", newIMDSRequestError(p, &CustomRequestFailure{
			code:       "NotFound",
			message:    "not found",
			fault:      smithy.FaultUnknown,
			statusCode: http.StatusNotFound,
			requestID:  "dummy-reqid",
		})
	}
	if body, isString := value.(string); isString {
		return body, nil
	}
	if failure, isErr := value.(error); isErr {
		return "", failure
	}
	panic(fmt.Sprintf("unknown test metadata value type %T for %s", value, p))
}