transport/client.go (190 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package transport
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/testing"
"github.com/elastic/elastic-agent-libs/transport/tlscommon"
)
// Client wraps a single connection obtained from a Dialer and exposes the
// net.Conn method set (Read, Write, Close, LocalAddr, RemoteAddr and the
// deadline setters). The connection is (re)established via Connect or
// ConnectContext, and is dropped automatically when an I/O or deadline call
// fails with a non-timeout network error.
type Client struct {
	log    *logp.Logger
	dialer Dialer
	config Config

	// Dial target. NewClientWithDialer validates that host contains a port.
	network string
	host    string

	// mutex guards conn; conn is nil whenever the client is disconnected.
	mutex sync.Mutex
	conn  net.Conn
}
// Config holds the connection options used to build a Client. It is passed
// as a whole to MakeDialer by NewClient.
type Config struct {
	// Proxy, when non-nil, requests dialing through a proxy. NewClient rejects
	// a non-nil Proxy for UDP networks.
	Proxy *ProxyConfig
	// TLS, when non-nil, enables TLS. NewClient rejects a non-nil TLS for UDP
	// networks.
	TLS *tlscommon.TLSConfig
	// Timeout is used by Client.Test for its probe dials.
	// NOTE(review): presumably also applied by MakeDialer — confirm there.
	Timeout time.Duration
	// Stats is not read in this file.
	// NOTE(review): presumably consumed by MakeDialer to collect I/O stats — confirm.
	Stats IOStatser
}
// NewClient builds a Client for the given network and host, creating its
// Dialer from c via MakeDialer.
//
// Supported networks are "tcp", "tcp4", "tcp6" and, only when neither TLS
// nor a proxy is configured, "udp", "udp4", "udp6". Any other combination
// yields an "unsupported network type" error. host is completed with
// defaultPort and validated by NewClientWithDialer.
func NewClient(c Config, network, host string, defaultPort int) (*Client, error) {
	supported := false
	switch network {
	case "tcp", "tcp4", "tcp6":
		supported = true
	case "udp", "udp4", "udp6":
		// TLS and proxying are stream-only; a datagram network can only be
		// used with a plain dialer.
		supported = c.TLS == nil && c.Proxy == nil
	}
	if !supported {
		return nil, fmt.Errorf("unsupported network type %v", network)
	}

	d, err := MakeDialer(c)
	if err != nil {
		return nil, err
	}
	return NewClientWithDialer(d, c, network, host, defaultPort)
}
// NewClientWithDialer builds a Client that will dial through d. host is
// completed with defaultPort via fullAddress and must then be splittable
// into host and port by net.SplitHostPort; otherwise that error is
// returned. No connection is opened here — call Connect to dial.
func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort int) (*Client, error) {
	addr := fullAddress(host, defaultPort)
	// Reject malformed addresses at construction time rather than on dial.
	if _, _, err := net.SplitHostPort(addr); err != nil {
		return nil, err
	}

	return &Client{
		log:     logp.NewLogger(logSelector),
		dialer:  d,
		network: network,
		host:    addr,
		config:  c,
	}, nil
}
// Connect is ConnectContext with a background context, i.e. without any
// caller-supplied cancellation or deadline.
func (c *Client) Connect() error {
	ctx := context.Background()
	return c.ConnectContext(ctx)
}
// ConnectContext (re)dials the configured network/host using ctx.
//
// Any existing connection is closed first (its Close error is ignored). On
// dial failure the error is returned and the client is left disconnected.
// The mutex is held for the whole dial, so concurrent calls that need it
// (Close, IsConnected, Read, Write, ...) wait until dialing finishes.
func (c *Client) ConnectContext(ctx context.Context) error {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if old := c.conn; old != nil {
		c.conn = nil
		_ = old.Close() // best-effort: we are replacing this connection anyway
	}

	fresh, err := c.dialer.DialContext(ctx, c.network, c.host)
	if err == nil {
		c.conn = fresh
	}
	return err
}
// IsConnected reports whether the client currently holds a connection.
func (c *Client) IsConnected() bool {
	return c.getConn() != nil
}
// Close closes the current connection, if any, and marks the client as
// disconnected. It returns the underlying Close error, or nil when there
// was nothing to close.
func (c *Client) Close() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	conn := c.conn
	if conn == nil {
		return nil
	}

	c.log.Debug("closing")
	c.conn = nil
	return conn.Close()
}
// getConn returns a snapshot of the current connection (nil when
// disconnected), read under the mutex. Callers must use the returned value
// rather than re-reading c.conn, which may change concurrently.
func (c *Client) getConn() net.Conn {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	return c.conn
}
// Read reads from the current connection into b. It returns ErrNotConnected
// when there is no connection. Errors go through handleError, which closes
// the client on non-timeout network errors.
func (c *Client) Read(b []byte) (int, error) {
	if conn := c.getConn(); conn != nil {
		n, err := conn.Read(b)
		return n, c.handleError(err)
	}
	return 0, ErrNotConnected
}
// Write writes b to the current connection. It returns ErrNotConnected when
// there is no connection. Errors go through handleError, which closes the
// client on non-timeout network errors.
func (c *Client) Write(b []byte) (int, error) {
	conn := c.getConn()
	if conn == nil {
		return 0, ErrNotConnected
	}
	// Use the snapshot taken under the mutex: re-reading c.conn here would
	// race with Close/ConnectContext and could dereference nil.
	n, err := conn.Write(b)
	return n, c.handleError(err)
}
// LocalAddr returns the local address of the current connection, or nil
// when the client is not connected.
func (c *Client) LocalAddr() net.Addr {
	conn := c.getConn()
	if conn == nil {
		return nil
	}
	// Use the locked snapshot, not c.conn, to avoid a race/nil dereference
	// with a concurrent Close.
	return conn.LocalAddr()
}
// RemoteAddr returns the remote address of the current connection, or nil
// when the client is not connected.
func (c *Client) RemoteAddr() net.Addr {
	conn := c.getConn()
	if conn == nil {
		return nil
	}
	// Use the locked snapshot, not c.conn, to avoid a race/nil dereference
	// with a concurrent Close.
	return conn.RemoteAddr()
}
// Host returns the configured dial address, which NewClientWithDialer
// completes with the default port when none was given.
func (c *Client) Host() string {
	addr := c.host
	return addr
}
// SetDeadline sets the read and write deadlines on the current connection.
// It returns ErrNotConnected when there is no connection.
func (c *Client) SetDeadline(t time.Time) error {
	return c.setConnDeadline(t, net.Conn.SetDeadline)
}

// SetReadDeadline sets the read deadline on the current connection.
// It returns ErrNotConnected when there is no connection.
func (c *Client) SetReadDeadline(t time.Time) error {
	return c.setConnDeadline(t, net.Conn.SetReadDeadline)
}

// SetWriteDeadline sets the write deadline on the current connection.
// It returns ErrNotConnected when there is no connection.
func (c *Client) SetWriteDeadline(t time.Time) error {
	return c.setConnDeadline(t, net.Conn.SetWriteDeadline)
}

// setConnDeadline applies setter to a snapshot of the current connection,
// reporting ErrNotConnected when there is none and routing any failure
// through handleError.
func (c *Client) setConnDeadline(t time.Time, setter func(net.Conn, time.Time) error) error {
	conn := c.getConn()
	if conn == nil {
		return ErrNotConnected
	}
	return c.handleError(setter(conn, t))
}
// handleError logs err at debug level and closes the client when err is a
// net.Error that is not a timeout. Timeouts, non-net errors and nil are
// returned as-is without closing. The error is always returned unchanged.
func (c *Client) handleError(err error) error {
	if err == nil {
		return nil
	}

	c.log.Debugf("handle error: %+v", err)

	var netErr net.Error
	broken := errors.As(err, &netErr) && !netErr.Timeout()
	if broken {
		// The caller cares about err, not about a secondary Close failure.
		_ = c.Close()
	}
	return err
}
// Test runs connectivity diagnostics against c.host under a
// "logstash: <host>" step of the given driver:
//   - "connection": a plain dial with a test dialer;
//   - "TLS": a TLS dial when c.config.TLS is set (otherwise a warning);
//   - finally a real Connect of this client ("talk to server").
//
// The probe connections opened by the first two steps exist only to prove
// reachability, so they are closed right away instead of being leaked. The
// final Connect leaves c connected on success.
//
// NOTE(review): the probes always dial "tcp", regardless of c.network.
func (c *Client) Test(d testing.Driver) {
	d.Run("logstash: "+c.host, func(d testing.Driver) {
		d.Run("connection", func(d testing.Driver) {
			netDialer := TestNetDialer(d, c.config.Timeout)
			conn, err := netDialer.DialContext(context.Background(), "tcp", c.host)
			if conn != nil {
				_ = conn.Close() // probe only; release the socket
			}
			d.Fatal("dial up", err)
		})

		if c.config.TLS == nil {
			d.Warn("TLS", "secure connection disabled")
		} else {
			d.Run("TLS", func(d testing.Driver) {
				netDialer := NetDialer(c.config.Timeout)
				tlsDialer := TestTLSDialer(d, netDialer, c.config.TLS, c.config.Timeout)
				conn, err := tlsDialer.DialContext(context.Background(), "tcp", c.host)
				if conn != nil {
					_ = conn.Close() // probe only; release the TLS session
				}
				d.Fatal("dial up", err)
			})
		}

		err := c.Connect()
		d.Fatal("talk to server", err)
	})
}
// String renders the client's dial target as "<network>://<host>".
func (c *Client) String() string {
	return fmt.Sprintf("%s://%s", c.network, c.host)
}