network/client.go (308 lines of code) (raw):
package network
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
"mime"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Sirupsen/logrus"
"github.com/jpillora/backoff"
"gitlab.com/gitlab-org/gitlab-ci-multi-runner/common"
)
type requestCredentials interface {
GetURL() string
GetToken() string
GetTLSCAFile() string
GetTLSCertFile() string
GetTLSKeyFile() string
}
var (
dialer = net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
backOffDelayMin = 100 * time.Millisecond
backOffDelayMax = 60 * time.Second
backOffDelayFactor = 2.0
backOffDelayJitter = true
)
type client struct {
http.Client
url *url.URL
caFile string
certFile string
keyFile string
caData []byte
skipVerify bool
updateTime time.Time
lastUpdate string
requestBackOffs map[string]*backoff.Backoff
lock sync.Mutex
}
type ResponseTLSData struct {
CAChain string
CertFile string
KeyFile string
}
func (n *client) getLastUpdate() string {
return n.lastUpdate
}
func (n *client) setLastUpdate(headers http.Header) {
if lu := headers.Get("X-GitLab-Last-Update"); len(lu) > 0 {
n.lastUpdate = lu
}
}
func (n *client) ensureTLSConfig() {
// certificate got modified
if stat, err := os.Stat(n.caFile); err == nil && n.updateTime.Before(stat.ModTime()) {
n.Transport = nil
}
// client certificate got modified
if stat, err := os.Stat(n.certFile); err == nil && n.updateTime.Before(stat.ModTime()) {
n.Transport = nil
}
// client private key got modified
if stat, err := os.Stat(n.keyFile); err == nil && n.updateTime.Before(stat.ModTime()) {
n.Transport = nil
}
// create or update transport
if n.Transport == nil {
n.updateTime = time.Now()
n.createTransport()
}
}
func (n *client) addTLSCA(tlsConfig *tls.Config) {
// load TLS CA certificate
if file := n.caFile; file != "" && !n.skipVerify {
logrus.Debugln("Trying to load", file, "...")
data, err := ioutil.ReadFile(file)
if err == nil {
pool := x509.NewCertPool()
if pool.AppendCertsFromPEM(data) {
tlsConfig.RootCAs = pool
n.caData = data
} else {
logrus.Errorln("Failed to parse PEM in", n.caFile)
}
} else {
if !os.IsNotExist(err) {
logrus.Errorln("Failed to load", n.caFile, err)
}
}
}
}
func (n *client) addTLSAuth(tlsConfig *tls.Config) {
// load TLS client keypair
if cert, key := n.certFile, n.keyFile; cert != "" && key != "" {
logrus.Debugln("Trying to load", cert, "and", key, "pair...")
certificate, err := tls.LoadX509KeyPair(cert, key)
if err == nil {
tlsConfig.Certificates = []tls.Certificate{certificate}
tlsConfig.BuildNameToCertificate()
} else {
if !os.IsNotExist(err) {
logrus.Errorln("Failed to load", cert, key, err)
}
}
}
}
func (n *client) createTransport() {
// create reference TLS config
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: n.skipVerify,
}
n.addTLSCA(&tlsConfig)
n.addTLSAuth(&tlsConfig)
// create transport
n.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: func(network, addr string) (net.Conn, error) {
logrus.Debugln("Dialing:", network, addr, "...")
return dialer.Dial(network, addr)
},
TLSClientConfig: &tlsConfig,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Minute,
}
n.Timeout = common.DefaultNetworkClientTimeout
}
func (n *client) getCAChain(tls *tls.ConnectionState) string {
if len(n.caData) != 0 {
return string(n.caData)
}
if tls == nil {
return ""
}
// Don't reorder certificates by putting them directly into the map
var certificates []*x509.Certificate
seenCertificates := make(map[string]bool, 0)
for _, verifiedChain := range tls.VerifiedChains {
for _, certificate := range verifiedChain {
signature := hex.EncodeToString(certificate.Signature)
if seenCertificates[signature] {
continue
}
seenCertificates[signature] = true
certificates = append(certificates, certificate)
}
}
out := bytes.NewBuffer(nil)
for _, certificate := range certificates {
if err := pem.Encode(out, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw}); err != nil {
logrus.Warn("Failed to encode certificate from chain:", err)
}
}
return out.String()
}
func (n *client) ensureBackoff(method, uri string) *backoff.Backoff {
n.lock.Lock()
defer n.lock.Unlock()
key := fmt.Sprintf("%s_%s", method, uri)
if n.requestBackOffs[key] == nil {
n.requestBackOffs[key] = &backoff.Backoff{
Min: backOffDelayMin,
Max: backOffDelayMax,
Factor: backOffDelayFactor,
Jitter: backOffDelayJitter,
}
}
return n.requestBackOffs[key]
}
func (n *client) backoffRequired(res *http.Response) bool {
return res.StatusCode >= 400 && res.StatusCode < 600
}
func (n *client) doBackoffRequest(req *http.Request) (res *http.Response, err error) {
res, err = n.Do(req)
if err != nil {
err = fmt.Errorf("couldn't execute %v against %s: %v", req.Method, req.URL, err)
return
}
backoffDelay := n.ensureBackoff(req.Method, req.RequestURI)
if n.backoffRequired(res) {
time.Sleep(backoffDelay.Duration())
} else {
backoffDelay.Reset()
}
return
}
func (n *client) do(uri, method string, request io.Reader, requestType string, headers http.Header) (res *http.Response, err error) {
url, err := n.url.Parse(uri)
if err != nil {
return
}
req, err := http.NewRequest(method, url.String(), request)
if err != nil {
err = fmt.Errorf("failed to create NewRequest: %v", err)
return
}
if headers != nil {
req.Header = headers
}
if request != nil {
req.Header.Set("Content-Type", requestType)
req.Header.Set("User-Agent", common.AppVersion.UserAgent())
}
n.ensureTLSConfig()
res, err = n.doBackoffRequest(req)
return
}
func (n *client) doJSON(uri, method string, statusCode int, request interface{}, response interface{}) (int, string, ResponseTLSData) {
var body io.Reader
if request != nil {
requestBody, err := json.Marshal(request)
if err != nil {
return -1, fmt.Sprintf("failed to marshal project object: %v", err), ResponseTLSData{}
}
body = bytes.NewReader(requestBody)
}
headers := make(http.Header)
if response != nil {
headers.Set("Accept", "application/json")
}
res, err := n.do(uri, method, body, "application/json", headers)
if err != nil {
return -1, err.Error(), ResponseTLSData{}
}
defer res.Body.Close()
defer io.Copy(ioutil.Discard, res.Body)
if res.StatusCode == statusCode {
if response != nil {
isApplicationJSON, err := isResponseApplicationJSON(res)
if !isApplicationJSON {
return -1, err.Error(), ResponseTLSData{}
}
d := json.NewDecoder(res.Body)
err = d.Decode(response)
if err != nil {
return -1, fmt.Sprintf("Error decoding json payload %v", err), ResponseTLSData{}
}
}
}
n.setLastUpdate(res.Header)
TLSData := ResponseTLSData{
CAChain: n.getCAChain(res.TLS),
CertFile: n.certFile,
KeyFile: n.keyFile,
}
return res.StatusCode, res.Status, TLSData
}
func isResponseApplicationJSON(res *http.Response) (result bool, err error) {
contentType := res.Header.Get("Content-Type")
mimetype, _, err := mime.ParseMediaType(contentType)
if err != nil {
return false, fmt.Errorf("Content-Type parsing error: %v", err)
}
if mimetype != "application/json" {
return false, fmt.Errorf("Server should return application/json. Got: %v", contentType)
}
return true, nil
}
func fixCIURL(url string) string {
url = strings.TrimRight(url, "/")
if strings.HasSuffix(url, "/ci") {
url = strings.TrimSuffix(url, "/ci")
}
return url
}
func (n *client) findCertificate(certificate *string, base string, name string) {
if *certificate != "" {
return
}
path := filepath.Join(base, name)
if _, err := os.Stat(path); err == nil {
*certificate = path
}
}
func newClient(requestCredentials requestCredentials) (c *client, err error) {
url, err := url.Parse(fixCIURL(requestCredentials.GetURL()) + "/api/v4/")
if err != nil {
return
}
if url.Scheme != "http" && url.Scheme != "https" {
err = errors.New("only http or https scheme supported")
return
}
c = &client{
url: url,
caFile: requestCredentials.GetTLSCAFile(),
certFile: requestCredentials.GetTLSCertFile(),
keyFile: requestCredentials.GetTLSKeyFile(),
requestBackOffs: make(map[string]*backoff.Backoff),
}
host := strings.Split(url.Host, ":")[0]
if CertificateDirectory != "" {
c.findCertificate(&c.caFile, CertificateDirectory, host+".crt")
c.findCertificate(&c.certFile, CertificateDirectory, host+".auth.crt")
c.findCertificate(&c.keyFile, CertificateDirectory, host+".auth.key")
}
return
}