github.com/alloyci/alloy-runner@v1.0.1-0.20180222164613-925503ccafd6/network/client.go (about)

     1  package network
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"mime"
    15  	"net"
    16  	"net/http"
    17  	"net/url"
    18  	"os"
    19  	"path/filepath"
    20  	"strings"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/Sirupsen/logrus"
    25  	"github.com/jpillora/backoff"
    26  
    27  	"gitlab.com/gitlab-org/gitlab-runner/common"
    28  )
    29  
    30  type requestCredentials interface {
    31  	GetURL() string
    32  	GetToken() string
    33  	GetTLSCAFile() string
    34  	GetTLSCertFile() string
    35  	GetTLSKeyFile() string
    36  }
    37  
    38  var (
    39  	dialer = net.Dialer{
    40  		Timeout:   30 * time.Second,
    41  		KeepAlive: 30 * time.Second,
    42  	}
    43  
    44  	backOffDelayMin    = 100 * time.Millisecond
    45  	backOffDelayMax    = 60 * time.Second
    46  	backOffDelayFactor = 2.0
    47  	backOffDelayJitter = true
    48  )
    49  
    50  type client struct {
    51  	http.Client
    52  	url             *url.URL
    53  	caFile          string
    54  	certFile        string
    55  	keyFile         string
    56  	caData          []byte
    57  	skipVerify      bool
    58  	updateTime      time.Time
    59  	lastUpdate      string
    60  	requestBackOffs map[string]*backoff.Backoff
    61  	lock            sync.Mutex
    62  }
    63  
    64  type ResponseTLSData struct {
    65  	CAChain  string
    66  	CertFile string
    67  	KeyFile  string
    68  }
    69  
    70  func (n *client) getLastUpdate() string {
    71  	return n.lastUpdate
    72  }
    73  
    74  func (n *client) setLastUpdate(headers http.Header) {
    75  	if lu := headers.Get("X-GitLab-Last-Update"); len(lu) > 0 {
    76  		n.lastUpdate = lu
    77  	}
    78  }
    79  
    80  func (n *client) ensureTLSConfig() {
    81  	// certificate got modified
    82  	if stat, err := os.Stat(n.caFile); err == nil && n.updateTime.Before(stat.ModTime()) {
    83  		n.Transport = nil
    84  	}
    85  
    86  	// client certificate got modified
    87  	if stat, err := os.Stat(n.certFile); err == nil && n.updateTime.Before(stat.ModTime()) {
    88  		n.Transport = nil
    89  	}
    90  
    91  	// client private key got modified
    92  	if stat, err := os.Stat(n.keyFile); err == nil && n.updateTime.Before(stat.ModTime()) {
    93  		n.Transport = nil
    94  	}
    95  
    96  	// create or update transport
    97  	if n.Transport == nil {
    98  		n.updateTime = time.Now()
    99  		n.createTransport()
   100  	}
   101  }
   102  
   103  func (n *client) addTLSCA(tlsConfig *tls.Config) {
   104  	// load TLS CA certificate
   105  	if file := n.caFile; file != "" && !n.skipVerify {
   106  		logrus.Debugln("Trying to load", file, "...")
   107  
   108  		data, err := ioutil.ReadFile(file)
   109  		if err == nil {
   110  			pool, err := x509.SystemCertPool()
   111  			if err != nil {
   112  				logrus.Warningln("Failed to load system CertPool:", err)
   113  			}
   114  			if pool == nil {
   115  				pool = x509.NewCertPool()
   116  			}
   117  			if pool.AppendCertsFromPEM(data) {
   118  				tlsConfig.RootCAs = pool
   119  				n.caData = data
   120  			} else {
   121  				logrus.Errorln("Failed to parse PEM in", n.caFile)
   122  			}
   123  		} else {
   124  			if !os.IsNotExist(err) {
   125  				logrus.Errorln("Failed to load", n.caFile, err)
   126  			}
   127  		}
   128  	}
   129  }
   130  
   131  func (n *client) addTLSAuth(tlsConfig *tls.Config) {
   132  	// load TLS client keypair
   133  	if cert, key := n.certFile, n.keyFile; cert != "" && key != "" {
   134  		logrus.Debugln("Trying to load", cert, "and", key, "pair...")
   135  
   136  		certificate, err := tls.LoadX509KeyPair(cert, key)
   137  		if err == nil {
   138  			tlsConfig.Certificates = []tls.Certificate{certificate}
   139  			tlsConfig.BuildNameToCertificate()
   140  		} else {
   141  			if !os.IsNotExist(err) {
   142  				logrus.Errorln("Failed to load", cert, key, err)
   143  			}
   144  		}
   145  	}
   146  }
   147  
   148  func (n *client) createTransport() {
   149  	// create reference TLS config
   150  	tlsConfig := tls.Config{
   151  		MinVersion:         tls.VersionTLS10,
   152  		InsecureSkipVerify: n.skipVerify,
   153  	}
   154  
   155  	n.addTLSCA(&tlsConfig)
   156  	n.addTLSAuth(&tlsConfig)
   157  
   158  	// create transport
   159  	n.Transport = &http.Transport{
   160  		Proxy: http.ProxyFromEnvironment,
   161  		Dial: func(network, addr string) (net.Conn, error) {
   162  			logrus.Debugln("Dialing:", network, addr, "...")
   163  			return dialer.Dial(network, addr)
   164  		},
   165  		TLSClientConfig:       &tlsConfig,
   166  		MaxIdleConns:          100,
   167  		IdleConnTimeout:       90 * time.Second,
   168  		TLSHandshakeTimeout:   10 * time.Second,
   169  		ExpectContinueTimeout: 1 * time.Second,
   170  		ResponseHeaderTimeout: 10 * time.Minute,
   171  	}
   172  	n.Timeout = common.DefaultNetworkClientTimeout
   173  }
   174  
   175  func (n *client) getCAChain(tls *tls.ConnectionState) string {
   176  	if len(n.caData) != 0 {
   177  		return string(n.caData)
   178  	}
   179  
   180  	if tls == nil {
   181  		return ""
   182  	}
   183  
   184  	// Don't reorder certificates by putting them directly into the map
   185  	var certificates []*x509.Certificate
   186  	seenCertificates := make(map[string]bool, 0)
   187  
   188  	for _, verifiedChain := range tls.VerifiedChains {
   189  		for _, certificate := range verifiedChain {
   190  			signature := hex.EncodeToString(certificate.Signature)
   191  			if seenCertificates[signature] {
   192  				continue
   193  			}
   194  
   195  			seenCertificates[signature] = true
   196  			certificates = append(certificates, certificate)
   197  		}
   198  	}
   199  
   200  	out := bytes.NewBuffer(nil)
   201  	for _, certificate := range certificates {
   202  		if err := pem.Encode(out, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw}); err != nil {
   203  			logrus.Warn("Failed to encode certificate from chain:", err)
   204  		}
   205  	}
   206  
   207  	return out.String()
   208  }
   209  
   210  func (n *client) ensureBackoff(method, uri string) *backoff.Backoff {
   211  	n.lock.Lock()
   212  	defer n.lock.Unlock()
   213  
   214  	key := fmt.Sprintf("%s_%s", method, uri)
   215  	if n.requestBackOffs[key] == nil {
   216  		n.requestBackOffs[key] = &backoff.Backoff{
   217  			Min:    backOffDelayMin,
   218  			Max:    backOffDelayMax,
   219  			Factor: backOffDelayFactor,
   220  			Jitter: backOffDelayJitter,
   221  		}
   222  	}
   223  
   224  	return n.requestBackOffs[key]
   225  }
   226  
   227  func (n *client) backoffRequired(res *http.Response) bool {
   228  	return res.StatusCode >= 400 && res.StatusCode < 600
   229  }
   230  
   231  func (n *client) doBackoffRequest(req *http.Request) (res *http.Response, err error) {
   232  	res, err = n.Do(req)
   233  	if err != nil {
   234  		err = fmt.Errorf("couldn't execute %v against %s: %v", req.Method, req.URL, err)
   235  		return
   236  	}
   237  
   238  	backoffDelay := n.ensureBackoff(req.Method, req.RequestURI)
   239  	if n.backoffRequired(res) {
   240  		time.Sleep(backoffDelay.Duration())
   241  	} else {
   242  		backoffDelay.Reset()
   243  	}
   244  
   245  	return
   246  }
   247  
   248  func (n *client) do(uri, method string, request io.Reader, requestType string, headers http.Header) (res *http.Response, err error) {
   249  	url, err := n.url.Parse(uri)
   250  	if err != nil {
   251  		return
   252  	}
   253  
   254  	req, err := http.NewRequest(method, url.String(), request)
   255  	if err != nil {
   256  		err = fmt.Errorf("failed to create NewRequest: %v", err)
   257  		return
   258  	}
   259  
   260  	if headers != nil {
   261  		req.Header = headers
   262  	}
   263  
   264  	if request != nil {
   265  		req.Header.Set("Content-Type", requestType)
   266  		req.Header.Set("User-Agent", common.AppVersion.UserAgent())
   267  	}
   268  
   269  	n.ensureTLSConfig()
   270  
   271  	res, err = n.doBackoffRequest(req)
   272  	return
   273  }
   274  
   275  func (n *client) doJSON(uri, method string, statusCode int, request interface{}, response interface{}) (int, string, ResponseTLSData) {
   276  	var body io.Reader
   277  
   278  	if request != nil {
   279  		requestBody, err := json.Marshal(request)
   280  		if err != nil {
   281  			return -1, fmt.Sprintf("failed to marshal project object: %v", err), ResponseTLSData{}
   282  		}
   283  		body = bytes.NewReader(requestBody)
   284  	}
   285  
   286  	headers := make(http.Header)
   287  	if response != nil {
   288  		headers.Set("Accept", "application/json")
   289  	}
   290  
   291  	res, err := n.do(uri, method, body, "application/json", headers)
   292  	if err != nil {
   293  		return -1, err.Error(), ResponseTLSData{}
   294  	}
   295  	defer res.Body.Close()
   296  	defer io.Copy(ioutil.Discard, res.Body)
   297  
   298  	if res.StatusCode == statusCode {
   299  		if response != nil {
   300  			isApplicationJSON, err := isResponseApplicationJSON(res)
   301  			if !isApplicationJSON {
   302  				return -1, err.Error(), ResponseTLSData{}
   303  			}
   304  
   305  			d := json.NewDecoder(res.Body)
   306  			err = d.Decode(response)
   307  			if err != nil {
   308  				return -1, fmt.Sprintf("Error decoding json payload %v", err), ResponseTLSData{}
   309  			}
   310  		}
   311  	}
   312  
   313  	n.setLastUpdate(res.Header)
   314  
   315  	TLSData := ResponseTLSData{
   316  		CAChain:  n.getCAChain(res.TLS),
   317  		CertFile: n.certFile,
   318  		KeyFile:  n.keyFile,
   319  	}
   320  
   321  	return res.StatusCode, res.Status, TLSData
   322  }
   323  
   324  func isResponseApplicationJSON(res *http.Response) (result bool, err error) {
   325  	contentType := res.Header.Get("Content-Type")
   326  
   327  	mimetype, _, err := mime.ParseMediaType(contentType)
   328  	if err != nil {
   329  		return false, fmt.Errorf("Content-Type parsing error: %v", err)
   330  	}
   331  
   332  	if mimetype != "application/json" {
   333  		return false, fmt.Errorf("Server should return application/json. Got: %v", contentType)
   334  	}
   335  
   336  	return true, nil
   337  }
   338  
   339  func fixCIURL(url string) string {
   340  	url = strings.TrimRight(url, "/")
   341  	if strings.HasSuffix(url, "/ci") {
   342  		url = strings.TrimSuffix(url, "/ci")
   343  	}
   344  	return url
   345  }
   346  
   347  func (n *client) findCertificate(certificate *string, base string, name string) {
   348  	if *certificate != "" {
   349  		return
   350  	}
   351  	path := filepath.Join(base, name)
   352  	if _, err := os.Stat(path); err == nil {
   353  		*certificate = path
   354  	}
   355  }
   356  
   357  func newClient(requestCredentials requestCredentials) (c *client, err error) {
   358  	url, err := url.Parse(fixCIURL(requestCredentials.GetURL()) + "/api/v4/")
   359  	if err != nil {
   360  		return
   361  	}
   362  
   363  	if url.Scheme != "http" && url.Scheme != "https" {
   364  		err = errors.New("only http or https scheme supported")
   365  		return
   366  	}
   367  
   368  	c = &client{
   369  		url:             url,
   370  		caFile:          requestCredentials.GetTLSCAFile(),
   371  		certFile:        requestCredentials.GetTLSCertFile(),
   372  		keyFile:         requestCredentials.GetTLSKeyFile(),
   373  		requestBackOffs: make(map[string]*backoff.Backoff),
   374  	}
   375  
   376  	host := strings.Split(url.Host, ":")[0]
   377  	if CertificateDirectory != "" {
   378  		c.findCertificate(&c.caFile, CertificateDirectory, host+".crt")
   379  		c.findCertificate(&c.certFile, CertificateDirectory, host+".auth.crt")
   380  		c.findCertificate(&c.keyFile, CertificateDirectory, host+".auth.key")
   381  	}
   382  
   383  	return
   384  }