transport/tls.go (171 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package transport
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/elastic/elastic-agent-libs/testing"
"github.com/elastic/elastic-agent-libs/transport/tlscommon"
)
func TLSDialer(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration) Dialer {
return TestTLSDialer(testing.NullDriver, forward, config, timeout)
}
func TestTLSDialer(
d testing.Driver,
forward Dialer,
config *tlscommon.TLSConfig,
timeout time.Duration,
) Dialer {
var lastTLSConfig *tls.Config
var lastNetwork string
var lastAddress string
var m sync.Mutex
return DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("unsupported network type %v", network)
}
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var tlsConfig *tls.Config
m.Lock()
if network == lastNetwork && address == lastAddress {
tlsConfig = lastTLSConfig
}
if tlsConfig == nil {
tlsConfig = config.BuildModuleClientConfig(host)
lastNetwork = network
lastAddress = address
lastTLSConfig = tlsConfig
}
m.Unlock()
return tlsDialWith(ctx, d, forward, network, address, timeout, tlsConfig, config)
})
}
type DialerH2 interface {
DialContext(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error)
}
type DialerFuncH2 func(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error)
func (d DialerFuncH2) DialContext(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error) {
return d(ctx, network, address, cfg)
}
func TLSDialerH2(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration) (DialerH2, error) {
return TestTLSDialerH2(testing.NullDriver, forward, config, timeout)
}
func TestTLSDialerH2(
d testing.Driver,
forward Dialer,
config *tlscommon.TLSConfig,
timeout time.Duration,
) (DialerH2, error) {
var lastTLSConfig *tls.Config
var lastNetwork string
var lastAddress string
var m sync.Mutex
return DialerFuncH2(func(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("unsupported network type %v", network)
}
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var tlsConfig *tls.Config
m.Lock()
if network == lastNetwork && address == lastAddress {
tlsConfig = lastTLSConfig
}
if tlsConfig == nil {
tlsConfig = config.BuildModuleClientConfig(host)
lastNetwork = network
lastAddress = address
lastTLSConfig = tlsConfig
}
m.Unlock()
// NextProtos must be set from the passed h2 connection or it will fail
tlsConfig.NextProtos = cfg.NextProtos
return tlsDialWith(ctx, d, forward, network, address, timeout, tlsConfig, config)
}), nil
}
func tlsDialWith(
ctx context.Context,
d testing.Driver,
dialer Dialer,
network, address string,
timeout time.Duration,
tlsConfig *tls.Config,
config *tlscommon.TLSConfig,
) (net.Conn, error) {
socket, err := dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
conn := tls.Client(socket, tlsConfig)
withTimeout := timeout > 0
if withTimeout {
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
d.Fatal("timeout", err)
_ = conn.Close()
return nil, err
}
}
// config might be nil, so get the zero-value and then read what is in config.
// We assume that the zero-value is the default value
var verification tlscommon.TLSVerificationMode
if config != nil {
verification = config.Verification
}
// We only check the status of config.Verification (`ssl.verification_mode`
// in the configuration file) because we have a custom verification logic
// implemented by setting tlsConfig.VerifyConnection that runs regardless of
// the status of tlsConfig.InsecureSkipVerify.
// For verification modes VerifyFull and VerifyCeritifcate we set
// tlsConfig.InsecureSkipVerify to true, hence it's not an indicator of
// whether TLS verification is enabled or not.
if verification == tlscommon.VerifyNone {
d.Warn("security", "server's certificate chain verification is disabled")
} else {
d.Info("security", "server's certificate chain verification is enabled")
}
err = conn.Handshake()
d.Fatal("handshake", err)
if err != nil {
_ = conn.Close()
return nil, err
}
// remove timeout if handshake was subject to timeout:
if withTimeout {
err := conn.SetDeadline(time.Time{})
if err != nil {
return nil, err
}
}
if err := postVerifyTLSConnection(d, conn, config); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func postVerifyTLSConnection(d testing.Driver, conn *tls.Conn, config *tlscommon.TLSConfig) error {
st := conn.ConnectionState()
if !st.HandshakeComplete {
err := errors.New("incomplete handshake")
d.Fatal("incomplete handshake", err)
return err
}
d.Info("TLS version", fmt.Sprintf("%v", tlscommon.TLSVersion(st.Version)))
// no more checks if no extra configs available
if config == nil {
return nil
}
versions := config.Versions
if versions == nil {
versions = tlscommon.TLSDefaultVersions
}
versionOK := false
for _, version := range versions {
versionOK = versionOK || st.Version == uint16(version)
}
if !versionOK {
err := fmt.Errorf("tls version %v not configured", tlscommon.TLSVersion(st.Version))
d.Fatal("TLS version", err)
return err
}
return nil
}