assertion/pkg/jwtvault/helpers.go (93 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package jwtvault
import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"path"
	"sort"

	"github.com/hashicorp/vault/api"
	"github.com/mitchellh/mapstructure"
	"gopkg.in/square/go-jose.v2"
)
// -----------------------------------------------------------------------------
type keyResponse struct {
Type string `json:"type" mapstructure:"type"`
LatestVersion float64 `json:"latest_version" mapstructure:"latest_version"`
Keys map[string]keyVersion `json:"keys" mapstructure:"keys"`
}
type keyVersion struct {
PublicKey string `json:"public_key" mapstructure:"public_key"`
}
// -----------------------------------------------------------------------------
// GetPublicKey reads the transit key keyName under the transit mount
// transitPath ("<transitPath>/keys/<keyName>") from Vault. It returns the
// parsed public key of the latest key version together with that version
// number.
//
// The latest version's public key must be a PEM block whose bytes are a PKIX
// public key; the returned value is whatever x509.ParsePKIXPublicKey yields
// for it.
//
// An error is returned when an argument is missing, the Vault read fails or
// returns nothing, the response cannot be decoded, the latest version number
// is not a positive integer or is absent from the key set, or its public key
// cannot be decoded.
func GetPublicKey(vaultClient *api.Client, transitPath, keyName string) (publicKey interface{}, version uint, err error) {
	// Check arguments
	if vaultClient == nil {
		return nil, 0, fmt.Errorf("vault client must not be nil")
	}
	if transitPath == "" {
		return nil, 0, fmt.Errorf("transit path must not be blank")
	}
	if keyName == "" {
		return nil, 0, fmt.Errorf("key name must not be blank")
	}

	// Retrieve transit key
	d, err := vaultClient.Logical().Read(path.Join(transitPath, "keys", keyName))
	if err != nil {
		return nil, 0, fmt.Errorf("unable to retrieve key details: %w", err)
	}
	if d == nil {
		return nil, 0, fmt.Errorf("returned key details are nil")
	}

	// Decode data
	var transitKey keyResponse
	if err = mapstructure.Decode(d.Data, &transitKey); err != nil {
		return nil, 0, fmt.Errorf("unable to decode key response: %w", err)
	}

	// latest_version is decoded as float64. Reject values that are not a
	// positive integer before relying on the uint conversion. Converting a
	// negative float to uint is implementation-defined, and a fractional one
	// would silently truncate to a different version. The `< 1` test
	// short-circuits before any negative conversion happens.
	lv := transitKey.LatestVersion
	if lv < 1 || lv != float64(uint(lv)) {
		return nil, 0, fmt.Errorf("invalid transit key latest version '%v'", lv)
	}
	version = uint(lv)

	// Get latest version; the key set is indexed by the decimal version number.
	latest, ok := transitKey.Keys[fmt.Sprintf("%d", version)]
	if !ok {
		return nil, 0, fmt.Errorf("unable to retrieve transit key version '%d'", version)
	}

	// Decode PEM; only the first block is used, anything after it is ignored.
	block, _ := pem.Decode([]byte(latest.PublicKey))
	if block == nil {
		return nil, 0, fmt.Errorf("unable to decode public key PEM block")
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, 0, fmt.Errorf("unable to decode public key: %w", err)
	}

	// No error
	return pub, version, nil
}
// JWKS reads the transit key keyName under the transit mount transitPath
// ("<transitPath>/keys/<keyName>") from Vault. It returns a JSON Web Key Set
// holding the public key of every key version present in the response.
//
// Each key's KeyID has the form "vault:<transitPath>:<keyName>:v<version>".
// Keys are emitted in ascending version order, so repeated calls over the
// same key produce the same set in the same order; Go map iteration order
// alone would not guarantee this.
//
// Every version's public key must be a PEM block whose bytes are a PKIX public
// key. The first version that fails to decode makes the whole call fail, and
// the error names that version.
func JWKS(vaultClient *api.Client, transitPath, keyName string) (*jose.JSONWebKeySet, error) {
	// Check arguments
	if vaultClient == nil {
		return nil, fmt.Errorf("vault client must not be nil")
	}
	if transitPath == "" {
		return nil, fmt.Errorf("transit path must not be blank")
	}
	if keyName == "" {
		return nil, fmt.Errorf("key name must not be blank")
	}

	// Retrieve transit key
	tk, err := vaultClient.Logical().Read(path.Join(transitPath, "keys", keyName))
	if err != nil {
		return nil, fmt.Errorf("unable to retrieve key details: %w", err)
	}
	if tk == nil {
		return nil, fmt.Errorf("returned key details are nil")
	}

	// Decode data
	var transitKey keyResponse
	if err = mapstructure.Decode(tk.Data, &transitKey); err != nil {
		return nil, fmt.Errorf("unable to decode transit key response: %w", err)
	}

	// Collect the version identifiers and order them by ascending decimal
	// value. A shorter digit string is the smaller number, and equal-length
	// strings compare lexicographically. Any non-numeric identifier still gets
	// a stable, deterministic position.
	kids := make([]string, 0, len(transitKey.Keys))
	for kid := range transitKey.Keys {
		kids = append(kids, kid)
	}
	sort.Slice(kids, func(i, j int) bool {
		if len(kids[i]) != len(kids[j]) {
			return len(kids[i]) < len(kids[j])
		}
		return kids[i] < kids[j]
	})

	// Prepare key set. Keys stays a non-nil slice even when no versions exist,
	// matching the original empty literal (presumably so it serializes as an
	// empty list rather than null).
	jwks := &jose.JSONWebKeySet{
		Keys: make([]jose.JSONWebKey, 0, len(kids)),
	}

	for _, kid := range kids {
		// Decode PEM; only the first block is used.
		block, _ := pem.Decode([]byte(transitKey.Keys[kid].PublicKey))
		if block == nil {
			return nil, fmt.Errorf("unable to decode public key PEM block for version %q", kid)
		}

		// Parse key
		pub, parseErr := x509.ParsePKIXPublicKey(block.Bytes)
		if parseErr != nil {
			return nil, fmt.Errorf("unable to decode public key for version %q: %w", kid, parseErr)
		}

		// Prepare JWK
		jwks.Keys = append(jwks.Keys, jose.JSONWebKey{
			KeyID: fmt.Sprintf("vault:%s:%s:v%s", transitPath, keyName, kid),
			Key:   pub,
		})
	}

	// No error
	return jwks, nil
}