aws-http-auth/sigv4a/credentials.go (126 lines of code) (raw):

package sigv4a import ( "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/hmac" "crypto/sha256" "encoding/binary" "fmt" "hash" "math" "math/big" "sync" "github.com/aws/smithy-go/aws-http-auth/credentials" ) var ( p256 elliptic.Curve nMinusTwoP256 *big.Int one = new(big.Int).SetInt64(1) ) func init() { p256 = elliptic.P256() nMinusTwoP256 = new(big.Int).SetBytes(p256.Params().N.Bytes()) nMinusTwoP256 = nMinusTwoP256.Sub(nMinusTwoP256, new(big.Int).SetInt64(2)) } // ecdsaCache stores the result of deriving an ECDSA private key from a // shared-secret identity. type ecdsaCache struct { mu sync.Mutex akid string priv *ecdsa.PrivateKey } // Derive computes and caches the ECDSA key-pair for the identity, returning // the result. // // Future calls to Derive with the same set of credentials (identified by AKID) // will short-circuit. Future calls with a different set of credentials // (identified by AKID) will re-derive the value, overwriting the old result. func (c *ecdsaCache) Derive(creds credentials.Credentials) (*ecdsa.PrivateKey, error) { c.mu.Lock() defer c.mu.Unlock() if creds.AccessKeyID == c.akid { return c.priv, nil } priv, err := derivePrivateKey(creds) if err != nil { return nil, err } c.akid = creds.AccessKeyID c.priv = priv return priv, nil } // derivePrivateKey derives a NIST P-256 PrivateKey from the given IAM // AccessKey and SecretKey pair. // // Based on FIPS.186-4 Appendix B.4.2 func derivePrivateKey(creds credentials.Credentials) (*ecdsa.PrivateKey, error) { akid := creds.AccessKeyID secret := creds.SecretAccessKey params := p256.Params() bitLen := params.BitSize // Testing random candidates does not require an additional 64 bits counter := 0x01 buffer := make([]byte, 1+len(akid)) // 1 byte counter + len(accessKey) kdfContext := bytes.NewBuffer(buffer) inputKey := append([]byte("AWS4A"), []byte(secret)...) d := new(big.Int) for { kdfContext.Reset() kdfContext.WriteString(akid) kdfContext.WriteByte(byte(counter)) key, err := deriveHMACKey(sha256.New, bitLen, inputKey, []byte(algorithm), kdfContext.Bytes()) if err != nil { return nil, err } cmp, err := cmpConst(key, nMinusTwoP256.Bytes()) if err != nil { return nil, err } if cmp == -1 { d.SetBytes(key) break } counter++ if counter > 0xFF { return nil, fmt.Errorf("exhausted single byte external counter") } } d = d.Add(d, one) priv := new(ecdsa.PrivateKey) priv.PublicKey.Curve = p256 priv.D = d priv.PublicKey.X, priv.PublicKey.Y = p256.ScalarBaseMult(d.Bytes()) return priv, nil } // deriveHMACKey provides an implementation of a NIST-800-108 of a KDF (Key // Derivation Function) in Counter Mode. HMAC is used as the pseudorandom // function, where the value of `r` is defined as a 4-byte counter. func deriveHMACKey(hash func() hash.Hash, bitLen int, key []byte, label, context []byte) ([]byte, error) { // verify that we won't overflow the counter n := int64(math.Ceil((float64(bitLen) / 8) / float64(hash().Size()))) if n > 0x7FFFFFFF { return nil, fmt.Errorf("unable to derive key of size %d using 32-bit counter", bitLen) } // verify the requested bit length is not larger then the length encoding size if int64(bitLen) > 0x7FFFFFFF { return nil, fmt.Errorf("bitLen is greater than 32-bits") } fixedInput := bytes.NewBuffer(nil) fixedInput.Write(label) fixedInput.WriteByte(0x00) fixedInput.Write(context) if err := binary.Write(fixedInput, binary.BigEndian, int32(bitLen)); err != nil { return nil, fmt.Errorf("failed to write bit length to fixed input string: %v", err) } var output []byte h := hmac.New(hash, key) for i := int64(1); i <= n; i++ { h.Reset() if err := binary.Write(h, binary.BigEndian, int32(i)); err != nil { return nil, err } _, err := h.Write(fixedInput.Bytes()) if err != nil { return nil, err } output = append(output, h.Sum(nil)...) } return output[:bitLen/8], nil } // constant-time byte slice compare func cmpConst(x, y []byte) (int, error) { if len(x) != len(y) { return 0, fmt.Errorf("slice lengths do not match") } xLarger, yLarger := 0, 0 for i := 0; i < len(x); i++ { xByte, yByte := int(x[i]), int(y[i]) x := ((yByte - xByte) >> 8) & 1 y := ((xByte - yByte) >> 8) & 1 xLarger |= x &^ yLarger yLarger |= y &^ xLarger } return xLarger - yLarger, nil }