utils/httputil/tls.go (155 lines of code) (raw):

// Copyright (c) 2016-2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package httputil

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"io/ioutil"

	"github.com/uber/kraken/utils/log"
)

// ErrEmptyCommonName is returned when common name is not provided for key generation.
// NOTE(review): not referenced in this file; presumably used by key-generation
// helpers elsewhere in the package.
var ErrEmptyCommonName = errors.New("empty common name")

// TLSConfig defines TLS configuration loaded from YAML. All certificate, key,
// passphrase and CA material is referenced by file path (see Secret) and read
// from disk when a tls.Config is built.
type TLSConfig struct {
	// Name is used as tls.Config.ServerName by BuildClient, i.e. the host name
	// the server certificate is verified against.
	Name string `yaml:"name"`
	// Server is the server-side certificate/key pair.
	// NOTE(review): not consumed by BuildClient; presumably used elsewhere.
	Server X509Pair `yaml:"server"`
	// Client is the certificate/key pair BuildClient presents to servers.
	// If Client.Disabled is set, BuildClient returns a nil config.
	Client X509Pair `yaml:"client"`
	// CAs are PEM files whose certificates are added on top of the system
	// root pool for verifying servers, and which WriteCABundle emits.
	CAs []Secret `yaml:"cas"`

	// Lazy init.
	// Cached result of BuildClient. It is read and written without any
	// locking, so a TLSConfig must not have BuildClient called on it
	// concurrently before the first call completes.
	tls *tls.Config
}

// X509Pair contains x509 cert configuration.
// Both Cert and Key should be already in pem format.
type X509Pair struct {
	// Disabled turns off TLS for this side (BuildClient returns nil, nil
	// when Client.Disabled is set).
	Disabled bool `yaml:"disabled"`
	// Cert is the PEM-encoded certificate file. An empty path means no
	// certificate is presented.
	Cert Secret `yaml:"cert"`
	// Key is the PEM-encoded private key file matching Cert.
	Key Secret `yaml:"key"`
	// Passphrase, if its path is set, names a file whose entire contents
	// (not trimmed) are used to decrypt an encrypted Key.
	Passphrase Secret `yaml:"passphrase"`
}

// Secret contains secret path configuration.
type Secret struct {
	// Path is the filesystem path the secret is read from.
	Path string `yaml:"path"`
}

// BuildClient builds a tls.Config for an HTTP client.
//
// It returns (nil, nil) when client TLS is disabled. Otherwise the config is
// built on the first call and cached on c, so later calls return the same
// *tls.Config.
func (c *TLSConfig) BuildClient() (*tls.Config, error) { if c.Client.Disabled { log.Infof("Client TLS is disabled") return nil, nil } if c.tls != nil { return c.tls, nil } var caPool *x509.CertPool var certs []tls.Certificate var err error if len(c.CAs) > 0 { caPool, err = createCertPool(c.CAs) if err != nil { return nil, fmt.Errorf("create cert pool: %s", err) } } if c.Client.Cert.Path != "" { certPEM, err := parseCert(c.Client.Cert.Path) if err != nil { return nil, fmt.Errorf("parse client cert: %s", err) } keyPEM, err := parseKey(c.Client.Key.Path, c.Client.Passphrase.Path) if err != nil { return nil, fmt.Errorf("parse client key: %s", err) } cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { return nil, fmt.Errorf("load client x509 key pair: %s", err) } certs = []tls.Certificate{cert} } c.tls = &tls.Config{ Certificates: certs, RootCAs: caPool, ServerName: c.Name, PreferServerCipherSuites: true, InsecureSkipVerify: false, // This is important to enforce verification of server. } return c.tls, nil } // WriteCABundle writes a list of CA to a writer. func (c *TLSConfig) WriteCABundle(w io.Writer) error { pems, err := concatSecrets(c.CAs) if err != nil { return fmt.Errorf("concat secrets: %s", err) } if _, err := w.Write(pems); err != nil { return fmt.Errorf("write cas: %s", err) } return nil } func createCertPool(secrets []Secret) (*x509.CertPool, error) { pool, err := x509.SystemCertPool() if err != nil { return nil, fmt.Errorf("create system cert pool: %s", err) } // No system certs provided. Create an empty cert pool. 
if pool == nil { pool = x509.NewCertPool() } pems, err := concatSecrets(secrets) if err != nil { return nil, fmt.Errorf("concat secrets: %s", err) } if ok := pool.AppendCertsFromPEM(pems); !ok { return nil, fmt.Errorf("cannot append cert") } return pool, nil } func concatSecrets(secrets []Secret) ([]byte, error) { result := bytes.Buffer{} for _, s := range secrets { pem, err := parseCert(s.Path) if err != nil { return nil, fmt.Errorf("parse cert: %s", err) } result.Write(pem) } return result.Bytes(), nil } func parseCert(path string) ([]byte, error) { certBytes, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("read file: %s", err) } return certBytes, nil } // parseKey reads key from file and decrypts if passphrase is provided. func parseKey(path, passphrasePath string) ([]byte, error) { keyPEM, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("read file: %s", err) } if passphrasePath != "" { passphrase, err := ioutil.ReadFile(passphrasePath) if err != nil { return nil, fmt.Errorf("read passphrase file: %s", err) } keyBytes, err := decryptPEMBlock(keyPEM, passphrase) if err != nil { return nil, fmt.Errorf("decrypt key: %s", err) } keyPEM, err = encodePEMKey(keyBytes) if err != nil { return nil, fmt.Errorf("encode key: %s", err) } } return keyPEM, nil } // decryptPEMBlock decrypts the block of data. func decryptPEMBlock(data, secret []byte) ([]byte, error) { block, _ := pem.Decode(data) if block == nil || len(block.Bytes) < 1 { return nil, errors.New("empty block") } decoded, err := x509.DecryptPEMBlock(block, secret) if err != nil { return nil, fmt.Errorf("decrypt block: %s", err) } return decoded, nil } // encodePEMKey marshals the DER-encoded private key. func encodePEMKey(data []byte) ([]byte, error) { buf := new(bytes.Buffer) err := pem.Encode(buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: data}) if err != nil { return nil, fmt.Errorf("encode key: %s", err) } return buf.Bytes(), nil }