pkg/controller/common/license/verifier.go (232 lines of code) (raw):
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.
package license
import (
"bytes"
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"io"
"time"
errors2 "github.com/pkg/errors"
ulog "github.com/elastic/cloud-on-k8s/v3/pkg/utils/log"
)
// Verifier verifies Enterprise licenses.
type Verifier struct {
// PublicKey is the RSA public key that ValidSignature uses to check the
// RSA signature embedded in a license.
PublicKey *rsa.PublicKey
}
// Valid reports the status of the given Enterprise license at time now.
//
// It returns LicenseStatusExpired when l.IsValid(now) reports false,
// LicenseStatusInvalid when the signature check fails (the failure is logged
// through the logger carried by ctx), and LicenseStatusValid otherwise.
func (v *Verifier) Valid(ctx context.Context, l EnterpriseLicense, now time.Time) LicenseStatus {
	if !l.IsValid(now) {
		return LicenseStatusExpired
	}

	sigErr := v.ValidSignature(l)
	if sigErr == nil {
		return LicenseStatusValid
	}

	ulog.FromContext(ctx).Error(sigErr, "Failed signature check")
	return LicenseStatusInvalid
}
// ValidSignature checks the signature of the given Enterprise license. Returns nil if valid.
//
// The license signature is expected to be the base64 encoding of a big-endian
// envelope laid out as:
//
//	version (uint32) | magic length (uint32) | magic |
//	hash length (uint32) | hash |
//	signed content length (uint32) | signed content signature
//
// Only the trailing RSA PKCS #1 v1.5 signature is verified, against the SHA-512
// digest of l.SignableContentBytes(). The version, magic and hash fields are
// parsed but not otherwise checked.
//
// An error is returned if the verifier has no public key, if the envelope
// cannot be decoded or is truncated, or if the RSA verification fails.
func (v *Verifier) ValidSignature(l EnterpriseLicense) error {
	// Guard against a verifier without a key: rsa.VerifyPKCS1v15 would
	// otherwise dereference a nil key.
	if v == nil || v.PublicKey == nil {
		return errors.New("no public key available to verify signature")
	}

	allParts := make([]byte, base64.StdEncoding.DecodedLen(len(l.License.Signature)))
	n, err := base64.StdEncoding.Decode(allParts, []byte(l.License.Signature))
	if err != nil {
		return errors2.Wrap(err, "failed to base64 decode signature")
	}
	// DecodedLen is only an upper bound (it accounts for padding); keep just the
	// bytes that were actually decoded so no phantom zero bytes are parsed.
	allParts = allParts[:n]

	buf := bytes.NewBuffer(allParts)
	// No field can be longer than the whole envelope; checking this before
	// allocating prevents huge allocations driven by a bogus length prefix.
	maxLen := uint32(len(allParts))

	var version uint32
	if err := readInt(buf, &version); err != nil {
		return errors2.Wrap(err, "failed to read version")
	}

	var magicLen uint32
	if err := readInt(buf, &magicLen); err != nil {
		return errors2.Wrap(err, "failed to read magic length")
	}
	if magicLen > maxLen {
		return errors.New("magic exceeds max length")
	}
	// io.ReadFull (rather than buf.Read) fails on a short read, so a truncated
	// envelope is rejected instead of being silently zero-padded.
	magic := make([]byte, magicLen)
	if _, err := io.ReadFull(buf, magic); err != nil {
		return errors2.Wrap(err, "failed to read magic")
	}

	var hashLen uint32
	if err := readInt(buf, &hashLen); err != nil {
		return errors2.Wrap(err, "failed to read hash length")
	}
	if hashLen > maxLen {
		return errors.New("hash exceeds max len")
	}
	pubKeyFingerprint := make([]byte, hashLen)
	if _, err := io.ReadFull(buf, pubKeyFingerprint); err != nil {
		return errors2.Wrap(err, "failed to read hash")
	}

	var signedContentLen uint32
	if err := readInt(buf, &signedContentLen); err != nil {
		return errors2.Wrap(err, "failed to read signed content length")
	}
	if signedContentLen > maxLen {
		return errors.New("signed content exceeds max length")
	}
	signedContentSig := make([]byte, signedContentLen)
	if _, err := io.ReadFull(buf, signedContentSig); err != nil {
		return errors2.Wrap(err, "failed to read signed content")
	}

	contentBytes, err := l.SignableContentBytes()
	if err != nil {
		return err
	}
	// TODO optional pubkey fingerprint check
	hashed := sha512.Sum512(contentBytes)
	return rsa.VerifyPKCS1v15(v.PublicKey, crypto.SHA512, hashed[:], signedContentSig)
}
// NewVerifier creates a new license verifier from a DER encoded public key.
//
// It returns a nil Verifier together with the parse error when pubKeyBytes
// cannot be parsed as an RSA public key, so that a verifier without a usable
// key is never handed to callers.
func NewVerifier(pubKeyBytes []byte) (*Verifier, error) {
	key, err := ParsePubKey(pubKeyBytes)
	if err != nil {
		return nil, err
	}
	return &Verifier{PublicKey: key}, nil
}
// ParsePubKey parses a DER encoded PKIX public key and returns it as an RSA
// public key. It returns an error if the bytes cannot be parsed or if the
// parsed key is not an RSA key.
func ParsePubKey(pubKeyBytes []byte) (*rsa.PublicKey, error) {
	parsed, err := x509.ParsePKIXPublicKey(pubKeyBytes)
	if err != nil {
		return nil, err
	}

	switch key := parsed.(type) {
	case *rsa.PublicKey:
		return key, nil
	default:
		return nil, errors.New("public key is not an RSA key")
	}
}
// Signable represents data that can be signed by a Signer.
type Signable interface {
// SignableContentBytes returns the data to be signed. Signer.Sign signs the
// SHA-512 digest of these bytes.
SignableContentBytes() ([]byte, error)
// Version indicates the version of the license spec used when generating SignableContentBytes.
// Signer.Sign writes it as the first field of the signature envelope.
Version() int
}
// Signer signs Enterprise licenses.
type Signer struct {
// Verifier is embedded so that a Signer also exposes the public key and the
// verification methods matching its private key.
Verifier
// privateKey is the RSA private key used by Sign.
privateKey *rsa.PrivateKey
}
// NewSigner creates a new license signer from a private key. The embedded
// Verifier is initialised with the public half of privKey.
func NewSigner(privKey *rsa.PrivateKey) *Signer {
	signer := &Signer{privateKey: privKey}
	signer.PublicKey = &privKey.PublicKey
	return signer
}
// Sign signs the given Signable. Returns the signature or an error.
//
// The returned bytes are the base64 encoding of a big-endian envelope:
//
//	version (uint32) | magic length (uint32) | magic |
//	hash length (uint32) | hash |
//	sig length (uint32) | sig
//
// where magic is 13 random bytes, hash is the base64 encoding of the PKIX
// public key after it has been passed through encryptWithAESECB, and sig is
// the RSA PKCS #1 v1.5 signature over the SHA-512 digest of
// spec.SignableContentBytes().
func (s *Signer) Sign(spec Signable) ([]byte, error) {
	payload, err := spec.SignableContentBytes()
	if err != nil {
		return nil, err
	}

	digest := sha512.Sum512(payload)
	rsaSig, err := rsa.SignPKCS1v15(rand.Reader, s.privateKey, crypto.SHA512, digest[:])
	if err != nil {
		return nil, err
	}

	const magicLen = 13
	magic := make([]byte, magicLen)
	if _, err := rand.Read(magic); err != nil {
		return nil, err
	}

	derKey, err := x509.MarshalPKIXPublicKey(s.PublicKey)
	if err != nil {
		return nil, errors2.Wrap(err, "while marshalling public key")
	}
	encKey, err := encryptWithAESECB(derKey)
	if err != nil {
		return nil, errors2.Wrap(err, "while encrypting public key")
	}
	keyHash := make([]byte, base64.StdEncoding.EncodedLen(len(encKey)))
	base64.StdEncoding.Encode(keyHash, encKey)

	// version (uint32) + magic length (uint32) + magic + hash length (uint32) + hash + sig length (uint32) + sig
	envelope := bytes.NewBuffer(make([]byte, 0, 4+4+magicLen+4+len(keyHash)+4+len(rsaSig)))
	if err := writeInt(envelope, spec.Version()); err != nil {
		return nil, err
	}
	// Each remaining field is written as a length prefix followed by its bytes.
	for _, field := range [][]byte{magic, keyHash, rsaSig} {
		if err := writeInt(envelope, len(field)); err != nil {
			return nil, err
		}
		if _, err := envelope.Write(field); err != nil {
			return nil, err
		}
	}

	raw := envelope.Bytes()
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
	base64.StdEncoding.Encode(encoded, raw)
	return encoded, nil
}
// licenseSpec is the JSON shape of the license content that gets signed: its
// marshaled form is what SignableContentBytes returns.
//
// encoding/json emits struct fields in declaration order, so the field order
// and the json tags below (including omitempty) determine the exact bytes that
// are hashed during signing and verification.
// NOTE(review): presumably this must match the layout used by the license
// issuer byte-for-byte; confirm before reordering or retagging any field.
type licenseSpec struct {
UID string `json:"uid"`
LicenseType string `json:"type"`
IssueDate string `json:"issue_date,omitempty"`
StartDate string `json:"start_date,omitempty"`
ExpiryDate string `json:"expiry_date,omitempty"`
IssueDateInMillis int64 `json:"issue_date_in_millis,omitempty"`
StartDateInMillis int64 `json:"start_date_in_millis,omitempty"`
ExpiryDateInMillis int64 `json:"expiry_date_in_millis,omitempty"`
MaxInstances int `json:"max_instances,omitempty"`
MaxResourceUnits int `json:"max_resource_units,omitempty"`
IssuedTo string `json:"issued_to"`
Issuer string `json:"issuer"`
}
// SignableContentBytes returns the JSON document covered by the license
// signature. It is built from the license's UID, type, millisecond
// timestamps, instance/resource limits, issued-to and issuer fields, and is
// marshaled without HTML escaping. The string date fields of licenseSpec are
// left empty and are therefore omitted from the output.
func (l EnterpriseLicense) SignableContentBytes() ([]byte, error) {
	lic := l.License
	spec := licenseSpec{
		UID:                lic.UID,
		LicenseType:        string(lic.Type),
		IssueDateInMillis:  lic.IssueDateInMillis,
		StartDateInMillis:  lic.StartDateInMillis,
		ExpiryDateInMillis: lic.ExpiryDateInMillis,
		MaxInstances:       lic.MaxInstances,
		MaxResourceUnits:   lic.MaxResourceUnits,
		IssuedTo:           lic.IssuedTo,
		Issuer:             lic.Issuer,
	}
	return unescapedJSONMarshal(spec)
}
// unescapedJSONMarshal marshals t to JSON without Go's default escaping of
// >, < and &. That escaping would lead to failed signature checks, because
// license signing does not escape those characters.
func unescapedJSONMarshal(t interface{}) ([]byte, error) {
	var out bytes.Buffer
	enc := json.NewEncoder(&out)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(t); err != nil {
		return nil, err
	}
	// json.Encoder terminates every encoded value with a newline. It is not part
	// of the signed content, which is why it is trimmed here.
	return bytes.TrimRight(out.Bytes(), "\n"), nil
}
// Version returns the version recorded in the license (License.Version).
// Together with SignableContentBytes it provides the method set required by
// the Signable interface.
func (l EnterpriseLicense) Version() int {
return l.License.Version
}
// writeInt appends i to buffer as a 4-byte big-endian unsigned integer.
//
// It returns an error, and writes nothing, when i is negative or does not fit
// in 32 bits: converting such a value with uint32(i) would silently truncate
// it and corrupt the encoded envelope.
func writeInt(buffer *bytes.Buffer, i int) error {
	if i < 0 || uint64(i) > 0xFFFFFFFF {
		return errors.New("integer out of range for uint32")
	}
	var in [4]byte
	binary.BigEndian.PutUint32(in[:], uint32(i))
	_, err := buffer.Write(in[:])
	return err
}
// readInt reads exactly 4 bytes from r and stores them in *i, decoded as a
// big-endian unsigned integer. If fewer than 4 bytes are available the read
// error is returned and *i is left untouched.
func readInt(r io.Reader, i *uint32) error {
	var raw [4]byte
	_, err := io.ReadFull(r, raw[:])
	if err == nil {
		*i = binary.BigEndian.Uint32(raw[:])
	}
	return err
}