elastictransport/elastictransport.go (421 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package elastictransport
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
)
const (
userAgentHeader = "User-Agent"
)
var (
defaultMaxRetries = 3
defaultRetryOnStatus = [...]int{502, 503, 504}
)
// Interface defines the interface for HTTP client.
type Interface interface {
Perform(*http.Request) (*http.Response, error)
}
// Instrumented allows to retrieve the current transport Instrumentation
type Instrumented interface {
InstrumentationEnabled() Instrumentation
}
// Config represents the configuration of HTTP client.
type Config struct {
UserAgent string
URLs []*url.URL
Username string
Password string
APIKey string
ServiceToken string
Header http.Header
CACert []byte
// DisableRetry disables retrying requests.
//
// If DisableRetry is true, then RetryOnStatus, RetryOnError, MaxRetries, and RetryBackoff will be ignored.
DisableRetry bool
// RetryOnStatus holds an optional list of HTTP response status codes that should trigger a retry.
//
// If RetryOnStatus is nil, then the defaults will be used:
// 502 (Bad Gateway), 503 (Service Unavailable), 504 (Gateway Timeout).
RetryOnStatus []int
// RetryOnError holds an optional function that will be called when a request fails due to an
// HTTP transport error, to indicate whether the request should be retried, e.g. timeouts.
RetryOnError func(*http.Request, error) bool
MaxRetries int
RetryBackoff func(attempt int) time.Duration
CompressRequestBody bool
CompressRequestBodyLevel int
// If PoolCompressor is true, a sync.Pool based gzip writer is used. Should be enabled with CompressRequestBody.
PoolCompressor bool
EnableMetrics bool
EnableDebugLogger bool
Instrumentation Instrumentation
DiscoverNodesInterval time.Duration
DiscoverNodeTimeout *time.Duration
Transport http.RoundTripper
Logger Logger
Selector Selector
ConnectionPoolFunc func([]*Connection, Selector) ConnectionPool
CertificateFingerprint string
}
// Client represents the HTTP client.
type Client struct {
sync.Mutex
userAgent string
urls []*url.URL
username string
password string
apikey string
servicetoken string
fingerprint string
header http.Header
retryOnStatus []int
disableRetry bool
enableRetryOnTimeout bool
maxRetries int
retryOnError func(*http.Request, error) bool
retryBackoff func(attempt int) time.Duration
discoverNodesInterval time.Duration
discoverNodesTimer *time.Timer
discoverNodeTimeout *time.Duration
compressRequestBody bool
compressRequestBodyLevel int
gzipCompressor gzipCompressor
instrumentation Instrumentation
metrics *metrics
transport http.RoundTripper
logger Logger
selector Selector
pool ConnectionPool
poolFunc func([]*Connection, Selector) ConnectionPool
}
// New creates new transport client.
//
// http.DefaultTransport will be used if no transport is passed in the configuration.
func New(cfg Config) (*Client, error) {
if cfg.Transport == nil {
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, errors.New("cannot clone http.DefaultTransport")
}
cfg.Transport = defaultTransport.Clone()
}
if transport, ok := cfg.Transport.(*http.Transport); ok {
if cfg.CertificateFingerprint != "" {
transport.DialTLS = func(network, addr string) (net.Conn, error) {
fingerprint, _ := hex.DecodeString(cfg.CertificateFingerprint)
c, err := tls.Dial(network, addr, &tls.Config{InsecureSkipVerify: true})
if err != nil {
return nil, err
}
// Retrieve the connection state from the remote server.
cState := c.ConnectionState()
for _, cert := range cState.PeerCertificates {
// Compute digest for each certificate.
digest := sha256.Sum256(cert.Raw)
// Provided fingerprint should match at least one certificate from remote before we continue.
if bytes.Compare(digest[0:], fingerprint) == 0 {
return c, nil
}
}
return nil, fmt.Errorf("fingerprint mismatch, provided: %s", cfg.CertificateFingerprint)
}
}
}
if cfg.CACert != nil {
httpTransport, ok := cfg.Transport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unable to set CA certificate for transport of type %T", cfg.Transport)
}
httpTransport = httpTransport.Clone()
httpTransport.TLSClientConfig.RootCAs = x509.NewCertPool()
if ok := httpTransport.TLSClientConfig.RootCAs.AppendCertsFromPEM(cfg.CACert); !ok {
return nil, errors.New("unable to add CA certificate")
}
cfg.Transport = httpTransport
}
if len(cfg.RetryOnStatus) == 0 {
cfg.RetryOnStatus = defaultRetryOnStatus[:]
}
if cfg.MaxRetries == 0 {
cfg.MaxRetries = defaultMaxRetries
}
var conns []*Connection
for _, u := range cfg.URLs {
conns = append(conns, &Connection{URL: u})
}
client := Client{
userAgent: cfg.UserAgent,
urls: cfg.URLs,
username: cfg.Username,
password: cfg.Password,
apikey: cfg.APIKey,
servicetoken: cfg.ServiceToken,
header: cfg.Header,
retryOnStatus: cfg.RetryOnStatus,
disableRetry: cfg.DisableRetry,
maxRetries: cfg.MaxRetries,
retryOnError: cfg.RetryOnError,
retryBackoff: cfg.RetryBackoff,
discoverNodesInterval: cfg.DiscoverNodesInterval,
compressRequestBody: cfg.CompressRequestBody,
compressRequestBodyLevel: cfg.CompressRequestBodyLevel,
transport: cfg.Transport,
logger: cfg.Logger,
selector: cfg.Selector,
poolFunc: cfg.ConnectionPoolFunc,
instrumentation: cfg.Instrumentation,
}
if cfg.DiscoverNodeTimeout != nil {
client.discoverNodeTimeout = cfg.DiscoverNodeTimeout
}
if client.poolFunc != nil {
client.pool = client.poolFunc(conns, client.selector)
} else {
client.pool, _ = NewConnectionPool(conns, client.selector)
}
if cfg.EnableDebugLogger {
debugLogger = &debuggingLogger{Output: os.Stdout}
}
if cfg.EnableMetrics {
client.metrics = &metrics{responses: make(map[int]int)}
// TODO(karmi): Type assertion to interface
if pool, ok := client.pool.(*singleConnectionPool); ok {
pool.metrics = client.metrics
}
if pool, ok := client.pool.(*statusConnectionPool); ok {
pool.metrics = client.metrics
}
}
if client.discoverNodesInterval > 0 {
time.AfterFunc(client.discoverNodesInterval, func() {
client.scheduleDiscoverNodes(client.discoverNodesInterval)
})
}
if client.compressRequestBodyLevel == 0 {
client.compressRequestBodyLevel = gzip.DefaultCompression
}
if cfg.PoolCompressor {
client.gzipCompressor = newPooledGzipCompressor(client.compressRequestBodyLevel)
} else {
client.gzipCompressor = newSimpleGzipCompressor(client.compressRequestBodyLevel)
}
return &client, nil
}
// Perform executes the request and returns a response or error.
func (c *Client) Perform(req *http.Request) (*http.Response, error) {
var (
res *http.Response
err error
)
// Record metrics, when enabled
if c.metrics != nil {
c.metrics.Lock()
c.metrics.requests++
c.metrics.Unlock()
}
// Update request
c.setReqUserAgent(req)
c.setReqGlobalHeader(req)
if req.Body != nil && req.Body != http.NoBody {
if c.compressRequestBody {
buf, err := c.gzipCompressor.compress(req.Body)
if err != nil {
return nil, err
}
defer c.gzipCompressor.collectBuffer(buf)
req.GetBody = func() (io.ReadCloser, error) {
// Copy value of buf so it's not destroyed on first read
r := *buf
return ioutil.NopCloser(&r), nil
}
req.Body, _ = req.GetBody()
req.Header.Set("Content-Encoding", "gzip")
req.ContentLength = int64(buf.Len())
} else if req.GetBody == nil {
if !c.disableRetry || (c.logger != nil && c.logger.RequestBodyEnabled()) {
var buf bytes.Buffer
buf.ReadFrom(req.Body)
req.GetBody = func() (io.ReadCloser, error) {
// Copy value of buf so it's not destroyed on first read
r := buf
return ioutil.NopCloser(&r), nil
}
req.Body, _ = req.GetBody()
}
}
}
originalPath := req.URL.Path
for i := 0; i <= c.maxRetries; i++ {
var (
conn *Connection
shouldRetry bool
shouldCloseBody bool
)
// Get connection from the pool
c.Lock()
conn, err = c.pool.Next()
c.Unlock()
if err != nil {
if c.logger != nil {
c.logRoundTrip(req, nil, err, time.Time{}, time.Duration(0))
}
return nil, fmt.Errorf("cannot get connection: %s", err)
}
// Update request
c.setReqURL(conn.URL, req)
c.setReqAuth(conn.URL, req)
if !c.disableRetry && i > 0 && req.Body != nil && req.Body != http.NoBody {
body, err := req.GetBody()
if err != nil {
return nil, fmt.Errorf("cannot get request body: %s", err)
}
req.Body = body
}
// Set up time measures and execute the request
start := time.Now().UTC()
res, err = c.transport.RoundTrip(req)
dur := time.Since(start)
// Log request and response
if c.logger != nil {
if c.logger.RequestBodyEnabled() && req.Body != nil && req.Body != http.NoBody {
req.Body, _ = req.GetBody()
}
c.logRoundTrip(req, res, err, start, dur)
}
if err != nil {
// Record metrics, when enabled
if c.metrics != nil {
c.metrics.Lock()
c.metrics.failures++
c.metrics.Unlock()
}
// Report the connection as unsuccessful
c.Lock()
c.pool.OnFailure(conn)
c.Unlock()
// Retry upon decision by the user
if !c.disableRetry && (c.retryOnError == nil || c.retryOnError(req, err)) {
shouldRetry = true
}
} else {
// Report the connection as succesfull
c.Lock()
c.pool.OnSuccess(conn)
c.Unlock()
}
if res != nil && c.metrics != nil {
c.metrics.Lock()
c.metrics.responses[res.StatusCode]++
c.metrics.Unlock()
}
if res != nil && c.instrumentation != nil {
c.instrumentation.AfterResponse(req.Context(), res)
}
// Retry on configured response statuses
if res != nil && !c.disableRetry {
for _, code := range c.retryOnStatus {
if res.StatusCode == code {
shouldRetry = true
shouldCloseBody = true
}
}
}
// Break if retry should not be performed
if !shouldRetry {
break
}
// Drain and close body when retrying after response
if shouldCloseBody && i < c.maxRetries {
if res.Body != nil {
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
}
}
// Delay the retry if a backoff function is configured
if c.retryBackoff != nil {
var cancelled bool
backoff := c.retryBackoff(i + 1)
timer := time.NewTimer(backoff)
select {
case <-req.Context().Done():
err = req.Context().Err()
cancelled = true
timer.Stop()
case <-timer.C:
}
if cancelled {
break
}
}
// Re-init the path of the request to its original state
// This will be re-enriched by the connection upon retry
req.URL.Path = originalPath
}
// TODO(karmi): Wrap error
return res, err
}
// URLs returns a list of transport URLs.
func (c *Client) URLs() []*url.URL {
return c.pool.URLs()
}
func (c *Client) InstrumentationEnabled() Instrumentation {
return c.instrumentation
}
func (c *Client) setReqURL(u *url.URL, req *http.Request) *http.Request {
req.URL.Scheme = u.Scheme
req.URL.Host = u.Host
if u.Path != "" {
var b strings.Builder
b.Grow(len(u.Path) + len(req.URL.Path))
b.WriteString(u.Path)
b.WriteString(req.URL.Path)
req.URL.Path = b.String()
}
return req
}
func (c *Client) setReqAuth(u *url.URL, req *http.Request) *http.Request {
if _, ok := req.Header["Authorization"]; !ok {
if u.User != nil {
password, _ := u.User.Password()
req.SetBasicAuth(u.User.Username(), password)
return req
}
if c.apikey != "" {
var b bytes.Buffer
b.Grow(len("APIKey ") + len(c.apikey))
b.WriteString("APIKey ")
b.WriteString(c.apikey)
req.Header.Set("Authorization", b.String())
return req
}
if c.servicetoken != "" {
var b bytes.Buffer
b.Grow(len("Bearer ") + len(c.servicetoken))
b.WriteString("Bearer ")
b.WriteString(c.servicetoken)
req.Header.Set("Authorization", b.String())
return req
}
if c.username != "" && c.password != "" {
req.SetBasicAuth(c.username, c.password)
return req
}
}
return req
}
func (c *Client) setReqUserAgent(req *http.Request) *http.Request {
if len(c.header) > 0 {
ua := c.header.Get(userAgentHeader)
if ua != "" {
req.Header.Set(userAgentHeader, ua)
return req
}
}
req.Header.Set(userAgentHeader, c.userAgent)
return req
}
func (c *Client) setReqGlobalHeader(req *http.Request) *http.Request {
if len(c.header) > 0 {
for k, v := range c.header {
if req.Header.Get(k) != k {
for _, vv := range v {
req.Header.Add(k, vv)
}
}
}
}
return req
}
func (c *Client) logRoundTrip(
req *http.Request,
res *http.Response,
err error,
start time.Time,
dur time.Duration,
) {
var dupRes http.Response
if res != nil {
dupRes = *res
}
if c.logger.ResponseBodyEnabled() {
if res != nil && res.Body != nil && res.Body != http.NoBody {
b1, b2, _ := duplicateBody(res.Body)
dupRes.Body = b1
res.Body = b2
}
}
c.logger.LogRoundTrip(req, &dupRes, err, start, dur) // errcheck exclude
}