keyvault/_example/main.go (120 lines of code) (raw):
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/Azure/azure-container-networking/keyvault"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
const serverAddr = "127.0.0.1:9005"
var logger *zap.Logger
func mustArgs() (kvURL string, kvCert string) {
flag.StringVar(&kvURL, "keyvault-url", "", "keyvault url")
flag.StringVar(&kvCert, "keyvault-cert-name", "", "keyvault certificate name")
flag.Parse()
if kvURL == "" || kvCert == "" {
flag.Usage()
os.Exit(1)
}
core := zapcore.NewCore(zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), os.Stdout, zap.DebugLevel)
logger = zap.New(core)
return
}
// you must be logged in via the az cli and have proper permissions to a keyvault to run this example
func main() {
kvURL, kvCert := mustArgs()
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
logger.Fatal("could not create credentials", zap.Error(err))
}
kvs, err := keyvault.NewShim(kvURL, cred)
if err != nil {
logger.Fatal("could not create keyvault client", zap.Error(err))
}
tlsCert, err := kvs.GetLatestTLSCertificate(context.TODO(), kvCert)
if err != nil {
logger.Fatal("could not get tls cert from keyvault", zap.Error(err))
}
clientTLSConfig, err := createClientTLSConfig(tlsCert)
if err != nil {
logger.Fatal("could not create client tls config", zap.Error(err))
}
server := http.Server{
Addr: serverAddr,
Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte("hello"))
}),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{tlsCert},
ClientCAs: clientTLSConfig.RootCAs,
ClientAuth: tls.RequireAndVerifyClientCert,
},
}
go func() {
if err := server.ListenAndServeTLS("", ""); err != nil {
logger.Fatal("could not serve tls", zap.Error(err))
}
}()
// wait for a short time to allow server to start
time.Sleep(time.Second)
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: clientTLSConfig,
},
}
addr := fmt.Sprintf("https://%s", serverAddr)
resp, err := client.Get(addr)
if err != nil {
logger.Fatal("could not get response", zap.String("host", addr), zap.Error(err))
}
printTLSConnState(resp.TLS)
bs, _ := io.ReadAll(resp.Body)
logger.Info("response from tls server", zap.String("body bytes", string(bs)))
}
func createClientTLSConfig(tlsCert tls.Certificate) (*tls.Config, error) {
certs := x509.NewCertPool()
if len(tlsCert.Certificate) == 1 { // self signed
cer, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return nil, err
}
certs.AddCert(cer)
return &tls.Config{RootCAs: certs, ServerName: tlsCert.Leaf.Subject.CommonName}, nil
}
for i, bytes := range tlsCert.Certificate {
if i == 0 {
continue // skip leaf
}
cer, err := x509.ParseCertificate(bytes)
if err != nil {
return nil, err
}
certs.AddCert(cer)
}
return &tls.Config{Certificates: []tls.Certificate{tlsCert}, RootCAs: certs, ServerName: tlsCert.Leaf.Subject.CommonName}, nil
}
func printTLSConnState(connState *tls.ConnectionState) {
logger.Info("response tls connection state", zap.Object("conn state", loggableConnState(*connState)))
for i, cert := range connState.PeerCertificates {
logger.Info(fmt.Sprintf("peer certificate %d:", i), zap.Stringer("subject", cert.Subject), zap.Stringer("issuer", cert.Issuer))
}
for i, chain := range connState.VerifiedChains {
for j, cert := range chain {
logger.Info(fmt.Sprintf("chain %d, cert %d:", i, j), zap.Stringer("subject", cert.Subject), zap.Stringer("issuer", cert.Issuer))
}
}
}
type loggableConnState tls.ConnectionState
func (l loggableConnState) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
encoder.AddString("server name", l.ServerName)
encoder.AddBool("handshake complete", l.HandshakeComplete)
encoder.AddInt("peer certificates", len(l.PeerCertificates))
encoder.AddInt("verified certificates", len(l.VerifiedChains))
return nil
}