github.com/nilium/gitlab-runner@v12.5.0+incompatible/network/client_test.go (about)

     1  package network
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"os"
    17  	"path/filepath"
    18  	"strconv"
    19  	"testing"
    20  
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	. "gitlab.com/gitlab-org/gitlab-runner/common"
    26  )
    27  
    28  func clientHandler(w http.ResponseWriter, r *http.Request) {
    29  	body, _ := ioutil.ReadAll(r.Body)
    30  	logrus.Debugln(r.Method, r.URL.String(),
    31  		"Content-Type:", r.Header.Get("Content-Type"),
    32  		"Accept:", r.Header.Get("Accept"),
    33  		"Body:", string(body))
    34  
    35  	switch r.URL.Path {
    36  	case "/api/v4/test/ok":
    37  	case "/api/v4/test/auth":
    38  		w.WriteHeader(http.StatusForbidden)
    39  	case "/api/v4/test/json":
    40  		if r.Header.Get("Content-Type") != "application/json" {
    41  			w.WriteHeader(http.StatusBadRequest)
    42  		} else if r.Header.Get("Accept") != "application/json" {
    43  			w.WriteHeader(http.StatusNotAcceptable)
    44  		} else {
    45  			w.Header().Set("Content-Type", "application/json")
    46  			fmt.Fprint(w, "{\"key\":\"value\"}")
    47  		}
    48  	default:
    49  		w.WriteHeader(http.StatusNotFound)
    50  	}
    51  }
    52  
    // writeTLSCertificate PEM-encodes the first certificate served by s and writes
// it to file with 0600 permissions, so tests can hand it to the client as a
// trusted CA file.
//
// It returns an error rather than panicking when s has no TLS configuration
// (for example a server created with httptest.NewServer instead of
// NewTLSServer), no certificates, or an empty certificate chain.
func writeTLSCertificate(s *httptest.Server, file string) error {
	if s.TLS == nil || len(s.TLS.Certificates) == 0 {
		return errors.New("no predefined certificate")
	}

	c := s.TLS.Certificates[0]
	// Length checks instead of nil checks: indexing an empty chain would panic.
	if len(c.Certificate) == 0 || len(c.Certificate[0]) == 0 {
		return errors.New("no predefined certificate")
	}

	encoded := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: c.Certificate[0],
	})

	return ioutil.WriteFile(file, encoded, 0600)
}
    66  
    // writeTLSKeyPair writes the first certificate served by s to certFile and
// its private key to keyFile, both PEM-encoded with 0600 permissions, so tests
// can use them as a client certificate/key pair.
//
// RSA keys are written as PKCS#1 ("RSA PRIVATE KEY"). Any other key type that
// x509.MarshalPKCS8PrivateKey accepts (e.g. ECDSA) is written as PKCS#8
// ("PRIVATE KEY"). Unsupported key types yield an error. It returns an error,
// rather than panicking, when s carries no certificate.
func writeTLSKeyPair(s *httptest.Server, certFile string, keyFile string) error {
	if s.TLS == nil || len(s.TLS.Certificates) == 0 {
		return errors.New("no predefined certificate")
	}

	c := s.TLS.Certificates[0]
	// Length checks instead of nil checks: indexing an empty chain would panic.
	if len(c.Certificate) == 0 || len(c.Certificate[0]) == 0 {
		return errors.New("no predefined certificate")
	}

	encodedCert := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: c.Certificate[0],
	})

	if err := ioutil.WriteFile(certFile, encodedCert, 0600); err != nil {
		return err
	}

	var keyBlock *pem.Block
	switch k := c.PrivateKey.(type) {
	case *rsa.PrivateKey:
		keyBlock = &pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(k),
		}
	default:
		der, err := x509.MarshalPKCS8PrivateKey(k)
		if err != nil {
			return fmt.Errorf("unexpected private key type %T: %v", c.PrivateKey, err)
		}
		keyBlock = &pem.Block{
			Type:  "PRIVATE KEY",
			Bytes: der,
		}
	}

	return ioutil.WriteFile(keyFile, pem.EncodeToMemory(keyBlock), 0600)
}
    93  
    94  func TestNewClient(t *testing.T) {
    95  	c, err := newClient(&RunnerCredentials{
    96  		URL: "http://test.example.com/ci///",
    97  	})
    98  	assert.NoError(t, err)
    99  	assert.NotNil(t, c)
   100  	assert.Equal(t, "http://test.example.com/api/v4/", c.url.String())
   101  }
   102  
   103  func TestInvalidUrl(t *testing.T) {
   104  	_, err := newClient(&RunnerCredentials{
   105  		URL: "address.com/ci///",
   106  	})
   107  	assert.Error(t, err)
   108  }
   109  
   110  func TestClientDo(t *testing.T) {
   111  	s := httptest.NewServer(http.HandlerFunc(clientHandler))
   112  	defer s.Close()
   113  
   114  	c, err := newClient(&RunnerCredentials{
   115  		URL: s.URL,
   116  	})
   117  	assert.NoError(t, err)
   118  	assert.NotNil(t, c)
   119  
   120  	statusCode, statusText, _ := c.doJSON("test/auth", "GET", http.StatusOK, nil, nil)
   121  	assert.Equal(t, http.StatusForbidden, statusCode, statusText)
   122  
   123  	req := struct {
   124  		Query bool `json:"query"`
   125  	}{
   126  		true,
   127  	}
   128  
   129  	res := struct {
   130  		Key string `json:"key"`
   131  	}{}
   132  
   133  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, &res)
   134  	assert.Equal(t, http.StatusBadRequest, statusCode, statusText)
   135  
   136  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, nil)
   137  	assert.Equal(t, http.StatusNotAcceptable, statusCode, statusText)
   138  
   139  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, nil)
   140  	assert.Equal(t, http.StatusBadRequest, statusCode, statusText)
   141  
   142  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, &res)
   143  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   144  	assert.Equal(t, "value", res.Key, statusText)
   145  }
   146  
   147  func TestClientInvalidSSL(t *testing.T) {
   148  	s := httptest.NewTLSServer(http.HandlerFunc(clientHandler))
   149  	defer s.Close()
   150  
   151  	c, _ := newClient(&RunnerCredentials{
   152  		URL: s.URL,
   153  	})
   154  	statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   155  	assert.Equal(t, -1, statusCode, statusText)
   156  	assert.Contains(t, statusText, "certificate signed by unknown authority")
   157  }
   158  
   159  func TestClientTLSCAFile(t *testing.T) {
   160  	s := httptest.NewTLSServer(http.HandlerFunc(clientHandler))
   161  	defer s.Close()
   162  
   163  	file, err := ioutil.TempFile("", "cert_")
   164  	assert.NoError(t, err)
   165  	file.Close()
   166  	defer os.Remove(file.Name())
   167  
   168  	err = writeTLSCertificate(s, file.Name())
   169  	assert.NoError(t, err)
   170  
   171  	c, _ := newClient(&RunnerCredentials{
   172  		URL:       s.URL,
   173  		TLSCAFile: file.Name(),
   174  	})
   175  	statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   176  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   177  
   178  	tlsData, err := c.getResponseTLSData(resp.TLS)
   179  	assert.NoError(t, err)
   180  	assert.NotEmpty(t, tlsData.CAChain)
   181  }
   182  
   183  func TestClientCertificateInPredefinedDirectory(t *testing.T) {
   184  	s := httptest.NewTLSServer(http.HandlerFunc(clientHandler))
   185  	defer s.Close()
   186  
   187  	serverURL, err := url.Parse(s.URL)
   188  	require.NoError(t, err)
   189  	hostname, _, err := net.SplitHostPort(serverURL.Host)
   190  	require.NoError(t, err)
   191  
   192  	tempDir, err := ioutil.TempDir("", "certs")
   193  	assert.NoError(t, err)
   194  	defer os.RemoveAll(tempDir)
   195  	CertificateDirectory = tempDir
   196  
   197  	err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt"))
   198  	assert.NoError(t, err)
   199  
   200  	c, _ := newClient(&RunnerCredentials{
   201  		URL: s.URL,
   202  	})
   203  	statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   204  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   205  
   206  	tlsData, err := c.getResponseTLSData(resp.TLS)
   207  	assert.NoError(t, err)
   208  	assert.NotEmpty(t, tlsData.CAChain)
   209  }
   210  
   211  func TestClientInvalidTLSAuth(t *testing.T) {
   212  	s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
   213  	s.TLS = new(tls.Config)
   214  	s.TLS.ClientAuth = tls.RequireAnyClientCert
   215  	s.StartTLS()
   216  	defer s.Close()
   217  
   218  	ca, err := ioutil.TempFile("", "cert_")
   219  	assert.NoError(t, err)
   220  	ca.Close()
   221  	defer os.Remove(ca.Name())
   222  
   223  	err = writeTLSCertificate(s, ca.Name())
   224  	assert.NoError(t, err)
   225  
   226  	c, _ := newClient(&RunnerCredentials{
   227  		URL:       s.URL,
   228  		TLSCAFile: ca.Name(),
   229  	})
   230  	statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   231  	assert.Equal(t, -1, statusCode, statusText)
   232  	assert.Contains(t, statusText, "tls: bad certificate")
   233  }
   234  
   235  func TestClientTLSAuth(t *testing.T) {
   236  	s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
   237  	s.TLS = new(tls.Config)
   238  	s.TLS.ClientAuth = tls.RequireAnyClientCert
   239  	s.StartTLS()
   240  	defer s.Close()
   241  
   242  	ca, err := ioutil.TempFile("", "cert_")
   243  	assert.NoError(t, err)
   244  	ca.Close()
   245  	defer os.Remove(ca.Name())
   246  
   247  	err = writeTLSCertificate(s, ca.Name())
   248  	assert.NoError(t, err)
   249  
   250  	cert, err := ioutil.TempFile("", "cert_")
   251  	assert.NoError(t, err)
   252  	cert.Close()
   253  	defer os.Remove(cert.Name())
   254  
   255  	key, err := ioutil.TempFile("", "key_")
   256  	assert.NoError(t, err)
   257  	key.Close()
   258  	defer os.Remove(key.Name())
   259  
   260  	err = writeTLSKeyPair(s, cert.Name(), key.Name())
   261  	assert.NoError(t, err)
   262  
   263  	c, _ := newClient(&RunnerCredentials{
   264  		URL:         s.URL,
   265  		TLSCAFile:   ca.Name(),
   266  		TLSCertFile: cert.Name(),
   267  		TLSKeyFile:  key.Name(),
   268  	})
   269  
   270  	statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   271  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   272  
   273  	tlsData, err := c.getResponseTLSData(resp.TLS)
   274  	assert.NoError(t, err)
   275  	assert.NotEmpty(t, tlsData.CAChain)
   276  	assert.Equal(t, cert.Name(), tlsData.CertFile)
   277  	assert.Equal(t, key.Name(), tlsData.KeyFile)
   278  }
   279  
   280  func TestClientTLSAuthCertificatesInPredefinedDirectory(t *testing.T) {
   281  	s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
   282  	s.TLS = new(tls.Config)
   283  	s.TLS.ClientAuth = tls.RequireAnyClientCert
   284  	s.StartTLS()
   285  	defer s.Close()
   286  
   287  	tempDir, err := ioutil.TempDir("", "certs")
   288  	assert.NoError(t, err)
   289  	defer os.RemoveAll(tempDir)
   290  	CertificateDirectory = tempDir
   291  
   292  	serverURL, err := url.Parse(s.URL)
   293  	require.NoError(t, err)
   294  	hostname, _, err := net.SplitHostPort(serverURL.Host)
   295  	require.NoError(t, err)
   296  
   297  	err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt"))
   298  	assert.NoError(t, err)
   299  
   300  	err = writeTLSKeyPair(s,
   301  		filepath.Join(tempDir, hostname+".auth.crt"),
   302  		filepath.Join(tempDir, hostname+".auth.key"))
   303  	assert.NoError(t, err)
   304  
   305  	c, _ := newClient(&RunnerCredentials{
   306  		URL: s.URL,
   307  	})
   308  	statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   309  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   310  
   311  	tlsData, err := c.getResponseTLSData(resp.TLS)
   312  	assert.NoError(t, err)
   313  	assert.NotEmpty(t, tlsData.CAChain)
   314  	assert.NotEmpty(t, tlsData.CertFile)
   315  	assert.NotEmpty(t, tlsData.KeyFile)
   316  }
   317  
   318  func TestUrlFixing(t *testing.T) {
   319  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci///"))
   320  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci/"))
   321  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci"))
   322  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/"))
   323  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com///"))
   324  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com"))
   325  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci/"))
   326  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci///"))
   327  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci"))
   328  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/"))
   329  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab///"))
   330  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab"))
   331  }
   332  
   // charsetTestClientHandler serves the same body, {"key":"value"}, under
// several paths that differ only in the Content-Type header they advertise.
// For any other path it writes nothing, so net/http replies 200 with an empty
// body.
func charsetTestClientHandler(w http.ResponseWriter, r *http.Request) {
	contentTypeByPath := map[string]string{
		"/api/v4/with-charset":    "application/json; charset=utf-8",
		"/api/v4/without-charset": "application/json",
		"/api/v4/without-json":    "application/octet-stream",
		"/api/v4/invalid-header":  "application/octet-stream, test, a=b",
	}

	contentType, known := contentTypeByPath[r.URL.Path]
	if !known {
		return
	}

	w.Header().Set("Content-Type", contentType)
	w.WriteHeader(http.StatusOK)
	fmt.Fprint(w, `{"key":"value"}`)
}
   353  
   354  func TestClientHandleCharsetInContentType(t *testing.T) {
   355  	s := httptest.NewServer(http.HandlerFunc(charsetTestClientHandler))
   356  	defer s.Close()
   357  
   358  	c, _ := newClient(&RunnerCredentials{
   359  		URL: s.URL,
   360  	})
   361  
   362  	res := struct {
   363  		Key string `json:"key"`
   364  	}{}
   365  
   366  	statusCode, statusText, _ := c.doJSON("with-charset", "GET", http.StatusOK, nil, &res)
   367  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   368  
   369  	statusCode, statusText, _ = c.doJSON("without-charset", "GET", http.StatusOK, nil, &res)
   370  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   371  
   372  	statusCode, statusText, _ = c.doJSON("without-json", "GET", http.StatusOK, nil, &res)
   373  	assert.Equal(t, -1, statusCode, statusText)
   374  
   375  	statusCode, statusText, _ = c.doJSON("invalid-header", "GET", http.StatusOK, nil, &res)
   376  	assert.Equal(t, -1, statusCode, statusText)
   377  }
   378  
   // backoffTestCase is one row of the TestRequestsBackOff table. It holds the
// status code the test server is asked to reply with, and whether that reply
// is expected to advance the client's backoff attempt counter.
   379  type backoffTestCase struct {
   380  	responseStatus int  // sent to tooManyRequestsHandler in the "responseStatus" request header
   381  	mustBackoff    bool // when true, backoff.Attempt() is expected to be 1 after the request, otherwise 0
   382  }
   383  
   // tooManyRequestsHandler replies with the status code the client put in the
// "responseStatus" request header. It replies 599 when that header is missing
// or is not an integer.
func tooManyRequestsHandler(w http.ResponseWriter, r *http.Request) {
	code := 599
	if requested, err := strconv.Atoi(r.Header.Get("responseStatus")); err == nil {
		code = requested
	}
	w.WriteHeader(code)
}
   392  
   393  func TestRequestsBackOff(t *testing.T) {
   394  	s := httptest.NewServer(http.HandlerFunc(tooManyRequestsHandler))
   395  	defer s.Close()
   396  
   397  	c, _ := newClient(&RunnerCredentials{
   398  		URL: s.URL,
   399  	})
   400  
   401  	testCases := []backoffTestCase{
   402  		{http.StatusCreated, false},
   403  		{http.StatusInternalServerError, true},
   404  		{http.StatusBadGateway, true},
   405  		{http.StatusServiceUnavailable, true},
   406  		{http.StatusOK, false},
   407  		{http.StatusConflict, true},
   408  		{http.StatusTooManyRequests, true},
   409  		{http.StatusCreated, false},
   410  		{http.StatusInternalServerError, true},
   411  		{http.StatusTooManyRequests, true},
   412  		{599, true},
   413  		{499, true},
   414  	}
   415  
   416  	backoff := c.ensureBackoff("POST", "")
   417  	for id, testCase := range testCases {
   418  		t.Run(fmt.Sprintf("%d-%d", id, testCase.responseStatus), func(t *testing.T) {
   419  			backoff.Reset()
   420  			assert.Zero(t, backoff.Attempt())
   421  
   422  			var body io.Reader
   423  			headers := make(http.Header)
   424  			headers.Add("responseStatus", strconv.Itoa(testCase.responseStatus))
   425  
   426  			res, err := c.do("/", "POST", body, "application/json", headers)
   427  
   428  			assert.NoError(t, err)
   429  			assert.Equal(t, testCase.responseStatus, res.StatusCode)
   430  
   431  			var expected float64
   432  			if testCase.mustBackoff {
   433  				expected = 1.0
   434  			}
   435  			assert.Equal(t, expected, backoff.Attempt())
   436  		})
   437  	}
   438  }