internal/awstesting/util.go (143 lines of code) (raw):
package awstesting
import (
"context"
"io"
"net/http"
"os"
"runtime"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
)
// ZeroReader is an io.Reader that produces an unbounded stream of zero bytes.
type ZeroReader struct{}

// Read overwrites every element of b with zero. It always reports len(b)
// bytes read and never returns an error.
func (r *ZeroReader) Read(b []byte) (int, error) {
	for i := range b {
		b[i] = 0
	}
	return len(b), nil
}
// ReadCloser is an io.ReadCloser for unit testing.
// Designed to test for leaks and whether a handle has
// been closed.
type ReadCloser struct {
	// Size is the number of bytes that remain to be produced. Read
	// decrements it by the number of bytes it hands out.
	Size int

	// Closed is set to true by Close. Once set, Read returns io.EOF without
	// producing data.
	Closed bool

	// set is false until the first Read that gets past the Closed check has
	// finished; it is passed to FillData as its first argument.
	set bool

	// FillData, when non-nil, is called on every Read (that is not
	// short-circuited by Closed) after the buffer has been filled with 'a'
	// bytes. Its arguments are, in order: whether a previous Read has already
	// happened, the caller's full buffer, the Size remaining after this Read,
	// and the number of bytes produced by this Read.
	FillData func(bool, []byte, int, int)
}

// Read writes up to Size bytes of 'a' into b, then lets FillData (if set)
// customise the buffer. Size is decremented by the number of bytes produced;
// once it reaches zero the call that drained it, and every later call, returns
// io.EOF. A closed ReadCloser returns (0, io.EOF) immediately.
func (r *ReadCloser) Read(b []byte) (int, error) {
	if r.Closed {
		return 0, io.EOF
	}

	// A negative Size would otherwise yield a negative byte count, which the
	// io.Reader contract forbids (0 <= n <= len(b)). Treat it as exhausted.
	if r.Size < 0 {
		r.Size = 0
	}

	delta := len(b)
	if delta > r.Size {
		delta = r.Size
	}
	r.Size -= delta

	for i := range b[:delta] {
		b[i] = 'a'
	}

	if r.FillData != nil {
		r.FillData(r.set, b, r.Size, delta)
	}
	r.set = true

	if r.Size > 0 {
		return delta, nil
	}
	return delta, io.EOF
}

// Close sets Closed to true and returns no error.
func (r *ReadCloser) Close() error {
	r.Closed = true
	return nil
}
// FakeContext is a stub Context whose error and done channel are supplied
// directly by the test through its exported fields.
type FakeContext struct {
	Error  error
	DoneCh chan struct{}
}

// Deadline reports that no deadline is set: it returns the zero time and
// false.
func (fc *FakeContext) Deadline() (deadline time.Time, ok bool) {
	var unset time.Time
	return unset, false
}

// Done exposes DoneCh as a receive-only channel.
func (fc *FakeContext) Done() <-chan struct{} {
	return fc.DoneCh
}

// Err returns the Error field, which is nil unless the test assigned it.
func (fc *FakeContext) Err() error {
	return fc.Error
}

// Value disregards key and always returns nil.
func (fc *FakeContext) Value(key any) any {
	return nil
}
// StashEnv clears the process environment, carrying over only the variables
// named in envToKeep plus a fixed set that is always kept: PATH, HOME and
// USERPROFILE, and additionally ComSpec, SYSTEM32 and SYSTEMROOT on Windows.
// A kept variable that is unset or empty before the call is not carried over.
//
// It returns the environment as it was before clearing, in the "key=value"
// form produced by os.Environ. Pass that slice to PopEnv to restore it.
func StashEnv(envToKeep ...string) []string {
	// Build the keep-list in a fresh slice. Appending straight onto envToKeep
	// could write into the backing array of a slice the caller spread into
	// this call (StashEnv(names...)).
	keep := make([]string, 0, len(envToKeep)+6)
	keep = append(keep, envToKeep...)
	if runtime.GOOS == "windows" {
		keep = append(keep, "ComSpec", "SYSTEM32", "SYSTEMROOT")
	}
	keep = append(keep, "PATH", "HOME", "USERPROFILE")

	// Both snapshots must be taken before the environment is wiped.
	extraEnv := getEnvs(keep)
	originalEnv := os.Environ()

	os.Clearenv()
	for key, val := range extraEnv {
		// Best effort: the signature has no error return, and these values
		// were just read from the environment, so a failure is not expected.
		_ = os.Setenv(key, val)
	}

	return originalEnv
}
// PopEnv replaces the process environment with the entries in env. Every
// existing variable is removed first, then each "key=value" entry is set in
// order.
func PopEnv(env []string) {
	os.Clearenv()
	for _, entry := range env {
		// Split on the first "=" only, so values may themselves contain "=".
		// An entry with no "=" is set as a name with an empty value.
		name, value, _ := strings.Cut(entry, "=")
		os.Setenv(name, value)
	}
}
// MockCredentialsProvider is a credentials provider stand-in whose behaviour
// is supplied by the test through function fields. Neither method guards
// against its function field being nil; calling a method whose field is unset
// panics.
type MockCredentialsProvider struct {
	RetrieveFn   func(ctx context.Context) (aws.Credentials, error)
	InvalidateFn func()
}

// Retrieve forwards ctx to RetrieveFn and returns its results unchanged.
func (m MockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
	retrieve := m.RetrieveFn
	return retrieve(ctx)
}

// Invalidate invokes InvalidateFn.
func (m MockCredentialsProvider) Invalidate() {
	invalidate := m.InvalidateFn
	invalidate()
}
// getEnvs looks up each name in envs and returns a map of name to current
// value. Names that are unset, or set to the empty string, are omitted. The
// returned map is never nil.
func getEnvs(envs []string) map[string]string {
	kept := make(map[string]string)
	for _, name := range envs {
		value, found := os.LookupEnv(name)
		if !found || value == "" {
			continue
		}
		kept[name] = value
	}
	return kept
}
// Algorithm identifiers that open the Authorization header value for sigv4
// and sigv4a respectively.
const (
	signaturePreambleSigV4  = "AWS4-HMAC-SHA256"
	signaturePreambleSigV4A = "AWS4-ECDSA-P256-SHA256"
)

// SigV4Signature represents a parsed sigv4 or sigv4a signature.
type SigV4Signature struct {
	Preamble      string   // e.g. AWS4-HMAC-SHA256, AWS4-ECDSA-P256-SHA256
	SigningName   string   // generally the service name e.g. "s3"
	SigningRegion string   // for sigv4a this is the region-set header as-is
	SignedHeaders []string // signed header names, without the "SignedHeaders=" label; nil if absent
	Signature     string   // calculated signature, without the "Signature=" label
}

// ParseSigV4Signature deconstructs a sigv4 or sigv4a signature from a set of
// request headers.
//
// The Authorization header is expected to have the form
//
//	<preamble> Credential=<scope>, SignedHeaders=<a;b;c>, Signature=<sig>
//
// where the components after the preamble are comma separated, with or
// without a following space. The "SignedHeaders=" and "Signature=" labels are
// stripped from the returned values.
//
// The function never returns nil and does not panic on malformed input: an
// unrecognised preamble, or a credential scope with too few "/"-separated
// parts, leaves SigningName (and for sigv4, SigningRegion) empty.
func ParseSigV4Signature(header http.Header) *SigV4Signature {
	auth := header.Get("Authorization")

	preamble, rest, _ := strings.Cut(auth, " ")
	credential, rest, _ := strings.Cut(rest, ",")
	signedHeaders, signature, _ := strings.Cut(rest, ",")

	// Drop the optional space after each comma and the "Name=" label so the
	// fields hold only the values.
	credential = strings.TrimPrefix(strings.TrimSpace(credential), "Credential=")
	signedHeaders = strings.TrimPrefix(strings.TrimSpace(signedHeaders), "SignedHeaders=")
	signature = strings.TrimPrefix(strings.TrimSpace(signature), "Signature=")

	scope := strings.Split(credential, "/")

	// sigv4  : AccessKeyID/DateString/SigningRegion/SigningName/SignatureID
	// sigv4a : AccessKeyID/DateString/SigningName/SignatureID, region set on
	//          header
	var signingName, signingRegion string
	switch preamble {
	case signaturePreambleSigV4:
		// Guard the indexes: a truncated scope must not panic.
		if len(scope) > 3 {
			signingRegion, signingName = scope[2], scope[3]
		}
	case signaturePreambleSigV4A:
		if len(scope) > 2 {
			signingName = scope[2]
		}
		signingRegion = header.Get("X-Amz-Region-Set")
	}

	// strings.Split on "" would yield a single empty element; report no
	// signed headers as a nil slice instead.
	var headers []string
	if signedHeaders != "" {
		headers = strings.Split(signedHeaders, ";")
	}

	return &SigV4Signature{
		Preamble:      preamble,
		SigningName:   signingName,
		SigningRegion: signingRegion,
		SignedHeaders: headers,
		Signature:     signature,
	}
}