github.com/alloyci/alloy-runner@v1.0.1-0.20180222164613-925503ccafd6/network/client_test.go (about)

     1  package network
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"os"
    17  	"path/filepath"
    18  	"strconv"
    19  	"testing"
    20  
    21  	"github.com/Sirupsen/logrus"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	. "gitlab.com/gitlab-org/gitlab-runner/common"
    26  )
    27  
    28  func clientHandler(w http.ResponseWriter, r *http.Request) {
    29  	body, _ := ioutil.ReadAll(r.Body)
    30  	logrus.Debugln(r.Method, r.URL.String(),
    31  		"Content-Type:", r.Header.Get("Content-Type"),
    32  		"Accept:", r.Header.Get("Accept"),
    33  		"Body:", string(body))
    34  
    35  	switch r.URL.Path {
    36  	case "/api/v4/test/ok":
    37  	case "/api/v4/test/auth":
    38  		w.WriteHeader(http.StatusForbidden)
    39  	case "/api/v4/test/json":
    40  		if r.Header.Get("Content-Type") != "application/json" {
    41  			w.WriteHeader(http.StatusBadRequest)
    42  		} else if r.Header.Get("Accept") != "application/json" {
    43  			w.WriteHeader(http.StatusNotAcceptable)
    44  		} else {
    45  			w.Header().Set("Content-Type", "application/json")
    46  			fmt.Fprint(w, "{\"key\":\"value\"}")
    47  		}
    48  	default:
    49  		w.WriteHeader(http.StatusNotFound)
    50  	}
    51  }
    52  
    53  func writeTLSCertificate(s *httptest.Server, file string) error {
    54  	c := s.TLS.Certificates[0]
    55  	if c.Certificate == nil || c.Certificate[0] == nil {
    56  		return errors.New("no predefined certificate")
    57  	}
    58  
    59  	encoded := pem.EncodeToMemory(&pem.Block{
    60  		Type:  "CERTIFICATE",
    61  		Bytes: c.Certificate[0],
    62  	})
    63  
    64  	return ioutil.WriteFile(file, encoded, 0600)
    65  }
    66  
    67  func writeTLSKeyPair(s *httptest.Server, certFile string, keyFile string) error {
    68  	c := s.TLS.Certificates[0]
    69  	if c.Certificate == nil || c.Certificate[0] == nil {
    70  		return errors.New("no predefined certificate")
    71  	}
    72  
    73  	encodedCert := pem.EncodeToMemory(&pem.Block{
    74  		Type:  "CERTIFICATE",
    75  		Bytes: c.Certificate[0],
    76  	})
    77  
    78  	if err := ioutil.WriteFile(certFile, encodedCert, 0600); err != nil {
    79  		return err
    80  	}
    81  
    82  	switch k := c.PrivateKey.(type) {
    83  	case *rsa.PrivateKey:
    84  		encodedKey := pem.EncodeToMemory(&pem.Block{
    85  			Type:  "RSA PRIVATE KEY",
    86  			Bytes: x509.MarshalPKCS1PrivateKey(k),
    87  		})
    88  		return ioutil.WriteFile(keyFile, encodedKey, 0600)
    89  	default:
    90  		return errors.New("unexpected private key type")
    91  	}
    92  }
    93  
    94  func TestNewClient(t *testing.T) {
    95  	c, err := newClient(&RunnerCredentials{
    96  		URL: "http://test.example.com/ci///",
    97  	})
    98  	assert.NoError(t, err)
    99  	assert.NotNil(t, c)
   100  	assert.Equal(t, "http://test.example.com/api/v4/", c.url.String())
   101  }
   102  
   103  func TestInvalidUrl(t *testing.T) {
   104  	_, err := newClient(&RunnerCredentials{
   105  		URL: "address.com/ci///",
   106  	})
   107  	assert.Error(t, err)
   108  }
   109  
   110  func TestClientDo(t *testing.T) {
   111  	s := httptest.NewServer(http.HandlerFunc(clientHandler))
   112  	defer s.Close()
   113  
   114  	c, err := newClient(&RunnerCredentials{
   115  		URL: s.URL,
   116  	})
   117  	assert.NoError(t, err)
   118  	assert.NotNil(t, c)
   119  
   120  	statusCode, statusText, _ := c.doJSON("test/auth", "GET", http.StatusOK, nil, nil)
   121  	assert.Equal(t, http.StatusForbidden, statusCode, statusText)
   122  
   123  	req := struct {
   124  		Query bool `json:"query"`
   125  	}{
   126  		true,
   127  	}
   128  
   129  	res := struct {
   130  		Key string `json:"key"`
   131  	}{}
   132  
   133  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, &res)
   134  	assert.Equal(t, http.StatusBadRequest, statusCode, statusText)
   135  
   136  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, nil)
   137  	assert.Equal(t, http.StatusNotAcceptable, statusCode, statusText)
   138  
   139  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, nil)
   140  	assert.Equal(t, http.StatusBadRequest, statusCode, statusText)
   141  
   142  	statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, &res)
   143  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   144  	assert.Equal(t, "value", res.Key, statusText)
   145  }
   146  
   147  func TestClientInvalidSSL(t *testing.T) {
   148  	s := httptest.NewTLSServer(http.HandlerFunc(clientHandler))
   149  	defer s.Close()
   150  
   151  	c, _ := newClient(&RunnerCredentials{
   152  		URL: s.URL,
   153  	})
   154  	statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   155  	assert.Equal(t, -1, statusCode, statusText)
   156  	assert.Contains(t, statusText, "certificate signed by unknown authority")
   157  }
   158  
   159  func TestClientTLSCAFile(t *testing.T) {
   160  	s := httptest.NewTLSServer(http.HandlerFunc(clientHandler))
   161  	defer s.Close()
   162  
   163  	file, err := ioutil.TempFile("", "cert_")
   164  	assert.NoError(t, err)
   165  	file.Close()
   166  	defer os.Remove(file.Name())
   167  
   168  	err = writeTLSCertificate(s, file.Name())
   169  	assert.NoError(t, err)
   170  
   171  	c, _ := newClient(&RunnerCredentials{
   172  		URL:       s.URL,
   173  		TLSCAFile: file.Name(),
   174  	})
   175  	statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   176  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   177  	assert.NotEmpty(t, tlsData.CAChain)
   178  }
   179  
   180  func TestClientCertificateInPredefinedDirectory(t *testing.T) {
   181  	s := httptest.NewTLSServer(http.HandlerFunc(clientHandler))
   182  	defer s.Close()
   183  
   184  	serverURL, err := url.Parse(s.URL)
   185  	require.NoError(t, err)
   186  	hostname, _, err := net.SplitHostPort(serverURL.Host)
   187  	require.NoError(t, err)
   188  
   189  	tempDir, err := ioutil.TempDir("", "certs")
   190  	assert.NoError(t, err)
   191  	defer os.RemoveAll(tempDir)
   192  	CertificateDirectory = tempDir
   193  
   194  	err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt"))
   195  	assert.NoError(t, err)
   196  
   197  	c, _ := newClient(&RunnerCredentials{
   198  		URL: s.URL,
   199  	})
   200  	statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   201  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   202  	assert.NotEmpty(t, tlsData.CAChain)
   203  }
   204  
   205  func TestClientInvalidTLSAuth(t *testing.T) {
   206  	s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
   207  	s.TLS = new(tls.Config)
   208  	s.TLS.ClientAuth = tls.RequireAnyClientCert
   209  	s.StartTLS()
   210  	defer s.Close()
   211  
   212  	ca, err := ioutil.TempFile("", "cert_")
   213  	assert.NoError(t, err)
   214  	ca.Close()
   215  	defer os.Remove(ca.Name())
   216  
   217  	err = writeTLSCertificate(s, ca.Name())
   218  	assert.NoError(t, err)
   219  
   220  	c, _ := newClient(&RunnerCredentials{
   221  		URL:       s.URL,
   222  		TLSCAFile: ca.Name(),
   223  	})
   224  	statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   225  	assert.Equal(t, -1, statusCode, statusText)
   226  	assert.Contains(t, statusText, "tls: bad certificate")
   227  }
   228  
   229  func TestClientTLSAuth(t *testing.T) {
   230  	s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
   231  	s.TLS = new(tls.Config)
   232  	s.TLS.ClientAuth = tls.RequireAnyClientCert
   233  	s.StartTLS()
   234  	defer s.Close()
   235  
   236  	ca, err := ioutil.TempFile("", "cert_")
   237  	assert.NoError(t, err)
   238  	ca.Close()
   239  	defer os.Remove(ca.Name())
   240  
   241  	err = writeTLSCertificate(s, ca.Name())
   242  	assert.NoError(t, err)
   243  
   244  	cert, err := ioutil.TempFile("", "cert_")
   245  	assert.NoError(t, err)
   246  	cert.Close()
   247  	defer os.Remove(cert.Name())
   248  
   249  	key, err := ioutil.TempFile("", "key_")
   250  	assert.NoError(t, err)
   251  	key.Close()
   252  	defer os.Remove(key.Name())
   253  
   254  	err = writeTLSKeyPair(s, cert.Name(), key.Name())
   255  	assert.NoError(t, err)
   256  
   257  	c, _ := newClient(&RunnerCredentials{
   258  		URL:         s.URL,
   259  		TLSCAFile:   ca.Name(),
   260  		TLSCertFile: cert.Name(),
   261  		TLSKeyFile:  key.Name(),
   262  	})
   263  	statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   264  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   265  	assert.NotEmpty(t, tlsData.CAChain)
   266  	assert.Equal(t, cert.Name(), tlsData.CertFile)
   267  	assert.Equal(t, key.Name(), tlsData.KeyFile)
   268  }
   269  
   270  func TestClientTLSAuthCertificatesInPredefinedDirectory(t *testing.T) {
   271  	s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
   272  	s.TLS = new(tls.Config)
   273  	s.TLS.ClientAuth = tls.RequireAnyClientCert
   274  	s.StartTLS()
   275  	defer s.Close()
   276  
   277  	tempDir, err := ioutil.TempDir("", "certs")
   278  	assert.NoError(t, err)
   279  	defer os.RemoveAll(tempDir)
   280  	CertificateDirectory = tempDir
   281  
   282  	serverURL, err := url.Parse(s.URL)
   283  	require.NoError(t, err)
   284  	hostname, _, err := net.SplitHostPort(serverURL.Host)
   285  	require.NoError(t, err)
   286  
   287  	err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt"))
   288  	assert.NoError(t, err)
   289  
   290  	err = writeTLSKeyPair(s,
   291  		filepath.Join(tempDir, hostname+".auth.crt"),
   292  		filepath.Join(tempDir, hostname+".auth.key"))
   293  	assert.NoError(t, err)
   294  
   295  	c, _ := newClient(&RunnerCredentials{
   296  		URL: s.URL,
   297  	})
   298  	statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil)
   299  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   300  	assert.NotEmpty(t, tlsData.CAChain)
   301  	assert.NotEmpty(t, tlsData.CertFile)
   302  	assert.NotEmpty(t, tlsData.KeyFile)
   303  }
   304  
   305  func TestUrlFixing(t *testing.T) {
   306  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci///"))
   307  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci/"))
   308  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci"))
   309  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/"))
   310  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com///"))
   311  	assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com"))
   312  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci/"))
   313  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci///"))
   314  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci"))
   315  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/"))
   316  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab///"))
   317  	assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab"))
   318  }
   319  
   320  func charsetTestClientHandler(w http.ResponseWriter, r *http.Request) {
   321  	switch r.URL.Path {
   322  	case "/api/v4/with-charset":
   323  		w.Header().Set("Content-Type", "application/json; charset=utf-8")
   324  		w.WriteHeader(http.StatusOK)
   325  		fmt.Fprint(w, "{\"key\":\"value\"}")
   326  	case "/api/v4/without-charset":
   327  		w.Header().Set("Content-Type", "application/json")
   328  		w.WriteHeader(http.StatusOK)
   329  		fmt.Fprint(w, "{\"key\":\"value\"}")
   330  	case "/api/v4/without-json":
   331  		w.Header().Set("Content-Type", "application/octet-stream")
   332  		w.WriteHeader(http.StatusOK)
   333  		fmt.Fprint(w, "{\"key\":\"value\"}")
   334  	case "/api/v4/invalid-header":
   335  		w.Header().Set("Content-Type", "application/octet-stream, test, a=b")
   336  		w.WriteHeader(http.StatusOK)
   337  		fmt.Fprint(w, "{\"key\":\"value\"}")
   338  	}
   339  }
   340  
   341  func TestClientHandleCharsetInContentType(t *testing.T) {
   342  	s := httptest.NewServer(http.HandlerFunc(charsetTestClientHandler))
   343  	defer s.Close()
   344  
   345  	c, _ := newClient(&RunnerCredentials{
   346  		URL: s.URL,
   347  	})
   348  
   349  	res := struct {
   350  		Key string `json:"key"`
   351  	}{}
   352  
   353  	statusCode, statusText, _ := c.doJSON("with-charset", "GET", http.StatusOK, nil, &res)
   354  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   355  
   356  	statusCode, statusText, _ = c.doJSON("without-charset", "GET", http.StatusOK, nil, &res)
   357  	assert.Equal(t, http.StatusOK, statusCode, statusText)
   358  
   359  	statusCode, statusText, _ = c.doJSON("without-json", "GET", http.StatusOK, nil, &res)
   360  	assert.Equal(t, -1, statusCode, statusText)
   361  
   362  	statusCode, statusText, _ = c.doJSON("invalid-header", "GET", http.StatusOK, nil, &res)
   363  	assert.Equal(t, -1, statusCode, statusText)
   364  }
   365  
// backoffTestCase describes one request issued by TestRequestsBackOff.
// Field order matters: the cases are built with positional struct literals.
type backoffTestCase struct {
	// responseStatus is the status the fake server is asked to reply with
	// (sent to tooManyRequestsHandler via the "responseStatus" header).
	responseStatus int
	// mustBackoff is true when backoff.Attempt() is expected to be 1 after
	// the request, and false when it is expected to stay 0.
	mustBackoff    bool
}
   370  
   371  func tooManyRequestsHandler(w http.ResponseWriter, r *http.Request) {
   372  	status, err := strconv.Atoi(r.Header.Get("responseStatus"))
   373  	if err != nil {
   374  		w.WriteHeader(599)
   375  	} else {
   376  		w.WriteHeader(status)
   377  	}
   378  }
   379  
   380  func TestRequestsBackOff(t *testing.T) {
   381  	s := httptest.NewServer(http.HandlerFunc(tooManyRequestsHandler))
   382  	defer s.Close()
   383  
   384  	c, _ := newClient(&RunnerCredentials{
   385  		URL: s.URL,
   386  	})
   387  
   388  	testCases := []backoffTestCase{
   389  		{http.StatusCreated, false},
   390  		{http.StatusInternalServerError, true},
   391  		{http.StatusBadGateway, true},
   392  		{http.StatusServiceUnavailable, true},
   393  		{http.StatusOK, false},
   394  		{http.StatusConflict, true},
   395  		{http.StatusTooManyRequests, true},
   396  		{http.StatusCreated, false},
   397  		{http.StatusInternalServerError, true},
   398  		{http.StatusTooManyRequests, true},
   399  		{599, true},
   400  		{499, true},
   401  	}
   402  
   403  	backoff := c.ensureBackoff("POST", "")
   404  	for id, testCase := range testCases {
   405  		t.Run(fmt.Sprintf("%d-%d", id, testCase.responseStatus), func(t *testing.T) {
   406  			backoff.Reset()
   407  			assert.Zero(t, backoff.Attempt())
   408  
   409  			var body io.Reader
   410  			headers := make(http.Header)
   411  			headers.Add("responseStatus", strconv.Itoa(testCase.responseStatus))
   412  
   413  			res, err := c.do("/", "POST", body, "application/json", headers)
   414  
   415  			assert.NoError(t, err)
   416  			assert.Equal(t, testCase.responseStatus, res.StatusCode)
   417  
   418  			var expected float64
   419  			if testCase.mustBackoff {
   420  				expected = 1.0
   421  			}
   422  			assert.Equal(t, expected, backoff.Attempt())
   423  		})
   424  	}
   425  }