github.com/alloyci/alloy-runner@v1.0.1-0.20180222164613-925503ccafd6/network/client_test.go (about) 1 package network 2 3 import ( 4 "crypto/rsa" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/pem" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net" 13 "net/http" 14 "net/http/httptest" 15 "net/url" 16 "os" 17 "path/filepath" 18 "strconv" 19 "testing" 20 21 "github.com/Sirupsen/logrus" 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 25 . "gitlab.com/gitlab-org/gitlab-runner/common" 26 ) 27 28 func clientHandler(w http.ResponseWriter, r *http.Request) { 29 body, _ := ioutil.ReadAll(r.Body) 30 logrus.Debugln(r.Method, r.URL.String(), 31 "Content-Type:", r.Header.Get("Content-Type"), 32 "Accept:", r.Header.Get("Accept"), 33 "Body:", string(body)) 34 35 switch r.URL.Path { 36 case "/api/v4/test/ok": 37 case "/api/v4/test/auth": 38 w.WriteHeader(http.StatusForbidden) 39 case "/api/v4/test/json": 40 if r.Header.Get("Content-Type") != "application/json" { 41 w.WriteHeader(http.StatusBadRequest) 42 } else if r.Header.Get("Accept") != "application/json" { 43 w.WriteHeader(http.StatusNotAcceptable) 44 } else { 45 w.Header().Set("Content-Type", "application/json") 46 fmt.Fprint(w, "{\"key\":\"value\"}") 47 } 48 default: 49 w.WriteHeader(http.StatusNotFound) 50 } 51 } 52 53 func writeTLSCertificate(s *httptest.Server, file string) error { 54 c := s.TLS.Certificates[0] 55 if c.Certificate == nil || c.Certificate[0] == nil { 56 return errors.New("no predefined certificate") 57 } 58 59 encoded := pem.EncodeToMemory(&pem.Block{ 60 Type: "CERTIFICATE", 61 Bytes: c.Certificate[0], 62 }) 63 64 return ioutil.WriteFile(file, encoded, 0600) 65 } 66 67 func writeTLSKeyPair(s *httptest.Server, certFile string, keyFile string) error { 68 c := s.TLS.Certificates[0] 69 if c.Certificate == nil || c.Certificate[0] == nil { 70 return errors.New("no predefined certificate") 71 } 72 73 encodedCert := pem.EncodeToMemory(&pem.Block{ 74 Type: "CERTIFICATE", 75 Bytes: c.Certificate[0], 76 }) 77 78 if err := ioutil.WriteFile(certFile, encodedCert, 0600); err != nil { 79 return err 80 } 81 82 switch k := c.PrivateKey.(type) { 83 case *rsa.PrivateKey: 84 encodedKey := pem.EncodeToMemory(&pem.Block{ 85 Type: "RSA PRIVATE KEY", 86 Bytes: x509.MarshalPKCS1PrivateKey(k), 87 }) 88 return ioutil.WriteFile(keyFile, encodedKey, 0600) 89 default: 90 return errors.New("unexpected private key type") 91 } 92 } 93 94 func TestNewClient(t *testing.T) { 95 c, err := newClient(&RunnerCredentials{ 96 URL: "http://test.example.com/ci///", 97 }) 98 assert.NoError(t, err) 99 assert.NotNil(t, c) 100 assert.Equal(t, "http://test.example.com/api/v4/", c.url.String()) 101 } 102 103 func TestInvalidUrl(t *testing.T) { 104 _, err := newClient(&RunnerCredentials{ 105 URL: "address.com/ci///", 106 }) 107 assert.Error(t, err) 108 } 109 110 func TestClientDo(t *testing.T) { 111 s := httptest.NewServer(http.HandlerFunc(clientHandler)) 112 defer s.Close() 113 114 c, err := newClient(&RunnerCredentials{ 115 URL: s.URL, 116 }) 117 assert.NoError(t, err) 118 assert.NotNil(t, c) 119 120 statusCode, statusText, _ := c.doJSON("test/auth", "GET", http.StatusOK, nil, nil) 121 assert.Equal(t, http.StatusForbidden, statusCode, statusText) 122 123 req := struct { 124 Query bool `json:"query"` 125 }{ 126 true, 127 } 128 129 res := struct { 130 Key string `json:"key"` 131 }{} 132 133 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, &res) 134 assert.Equal(t, http.StatusBadRequest, statusCode, statusText) 135 136 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, nil) 137 assert.Equal(t, http.StatusNotAcceptable, statusCode, statusText) 138 139 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, nil) 140 assert.Equal(t, http.StatusBadRequest, statusCode, statusText) 141 142 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, &res) 143 assert.Equal(t, http.StatusOK, statusCode, statusText) 144 assert.Equal(t, "value", res.Key, statusText) 145 } 146 147 func TestClientInvalidSSL(t *testing.T) { 148 s := httptest.NewTLSServer(http.HandlerFunc(clientHandler)) 149 defer s.Close() 150 151 c, _ := newClient(&RunnerCredentials{ 152 URL: s.URL, 153 }) 154 statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 155 assert.Equal(t, -1, statusCode, statusText) 156 assert.Contains(t, statusText, "certificate signed by unknown authority") 157 } 158 159 func TestClientTLSCAFile(t *testing.T) { 160 s := httptest.NewTLSServer(http.HandlerFunc(clientHandler)) 161 defer s.Close() 162 163 file, err := ioutil.TempFile("", "cert_") 164 assert.NoError(t, err) 165 file.Close() 166 defer os.Remove(file.Name()) 167 168 err = writeTLSCertificate(s, file.Name()) 169 assert.NoError(t, err) 170 171 c, _ := newClient(&RunnerCredentials{ 172 URL: s.URL, 173 TLSCAFile: file.Name(), 174 }) 175 statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 176 assert.Equal(t, http.StatusOK, statusCode, statusText) 177 assert.NotEmpty(t, tlsData.CAChain) 178 } 179 180 func TestClientCertificateInPredefinedDirectory(t *testing.T) { 181 s := httptest.NewTLSServer(http.HandlerFunc(clientHandler)) 182 defer s.Close() 183 184 serverURL, err := url.Parse(s.URL) 185 require.NoError(t, err) 186 hostname, _, err := net.SplitHostPort(serverURL.Host) 187 require.NoError(t, err) 188 189 tempDir, err := ioutil.TempDir("", "certs") 190 assert.NoError(t, err) 191 defer os.RemoveAll(tempDir) 192 CertificateDirectory = tempDir 193 194 err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt")) 195 assert.NoError(t, err) 196 197 c, _ := newClient(&RunnerCredentials{ 198 URL: s.URL, 199 }) 200 statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 201 assert.Equal(t, http.StatusOK, statusCode, statusText) 202 assert.NotEmpty(t, tlsData.CAChain) 203 } 204 205 func TestClientInvalidTLSAuth(t *testing.T) { 206 s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler)) 207 s.TLS = new(tls.Config) 208 s.TLS.ClientAuth = tls.RequireAnyClientCert 209 s.StartTLS() 210 defer s.Close() 211 212 ca, err := ioutil.TempFile("", "cert_") 213 assert.NoError(t, err) 214 ca.Close() 215 defer os.Remove(ca.Name()) 216 217 err = writeTLSCertificate(s, ca.Name()) 218 assert.NoError(t, err) 219 220 c, _ := newClient(&RunnerCredentials{ 221 URL: s.URL, 222 TLSCAFile: ca.Name(), 223 }) 224 statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 225 assert.Equal(t, -1, statusCode, statusText) 226 assert.Contains(t, statusText, "tls: bad certificate") 227 } 228 229 func TestClientTLSAuth(t *testing.T) { 230 s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler)) 231 s.TLS = new(tls.Config) 232 s.TLS.ClientAuth = tls.RequireAnyClientCert 233 s.StartTLS() 234 defer s.Close() 235 236 ca, err := ioutil.TempFile("", "cert_") 237 assert.NoError(t, err) 238 ca.Close() 239 defer os.Remove(ca.Name()) 240 241 err = writeTLSCertificate(s, ca.Name()) 242 assert.NoError(t, err) 243 244 cert, err := ioutil.TempFile("", "cert_") 245 assert.NoError(t, err) 246 cert.Close() 247 defer os.Remove(cert.Name()) 248 249 key, err := ioutil.TempFile("", "key_") 250 assert.NoError(t, err) 251 key.Close() 252 defer os.Remove(key.Name()) 253 254 err = writeTLSKeyPair(s, cert.Name(), key.Name()) 255 assert.NoError(t, err) 256 257 c, _ := newClient(&RunnerCredentials{ 258 URL: s.URL, 259 TLSCAFile: ca.Name(), 260 TLSCertFile: cert.Name(), 261 TLSKeyFile: key.Name(), 262 }) 263 statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 264 assert.Equal(t, http.StatusOK, statusCode, statusText) 265 assert.NotEmpty(t, tlsData.CAChain) 266 assert.Equal(t, cert.Name(), tlsData.CertFile) 267 assert.Equal(t, key.Name(), tlsData.KeyFile) 268 } 269 270 func TestClientTLSAuthCertificatesInPredefinedDirectory(t *testing.T) { 271 s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler)) 272 s.TLS = new(tls.Config) 273 s.TLS.ClientAuth = tls.RequireAnyClientCert 274 s.StartTLS() 275 defer s.Close() 276 277 tempDir, err := ioutil.TempDir("", "certs") 278 assert.NoError(t, err) 279 defer os.RemoveAll(tempDir) 280 CertificateDirectory = tempDir 281 282 serverURL, err := url.Parse(s.URL) 283 require.NoError(t, err) 284 hostname, _, err := net.SplitHostPort(serverURL.Host) 285 require.NoError(t, err) 286 287 err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt")) 288 assert.NoError(t, err) 289 290 err = writeTLSKeyPair(s, 291 filepath.Join(tempDir, hostname+".auth.crt"), 292 filepath.Join(tempDir, hostname+".auth.key")) 293 assert.NoError(t, err) 294 295 c, _ := newClient(&RunnerCredentials{ 296 URL: s.URL, 297 }) 298 statusCode, statusText, tlsData := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 299 assert.Equal(t, http.StatusOK, statusCode, statusText) 300 assert.NotEmpty(t, tlsData.CAChain) 301 assert.NotEmpty(t, tlsData.CertFile) 302 assert.NotEmpty(t, tlsData.KeyFile) 303 } 304 305 func TestUrlFixing(t *testing.T) { 306 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci///")) 307 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci/")) 308 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci")) 309 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/")) 310 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com///")) 311 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com")) 312 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci/")) 313 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci///")) 314 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci")) 315 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/")) 316 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab///")) 317 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab")) 318 } 319 320 func charsetTestClientHandler(w http.ResponseWriter, r *http.Request) { 321 switch r.URL.Path { 322 case "/api/v4/with-charset": 323 w.Header().Set("Content-Type", "application/json; charset=utf-8") 324 w.WriteHeader(http.StatusOK) 325 fmt.Fprint(w, "{\"key\":\"value\"}") 326 case "/api/v4/without-charset": 327 w.Header().Set("Content-Type", "application/json") 328 w.WriteHeader(http.StatusOK) 329 fmt.Fprint(w, "{\"key\":\"value\"}") 330 case "/api/v4/without-json": 331 w.Header().Set("Content-Type", "application/octet-stream") 332 w.WriteHeader(http.StatusOK) 333 fmt.Fprint(w, "{\"key\":\"value\"}") 334 case "/api/v4/invalid-header": 335 w.Header().Set("Content-Type", "application/octet-stream, test, a=b") 336 w.WriteHeader(http.StatusOK) 337 fmt.Fprint(w, "{\"key\":\"value\"}") 338 } 339 } 340 341 func TestClientHandleCharsetInContentType(t *testing.T) { 342 s := httptest.NewServer(http.HandlerFunc(charsetTestClientHandler)) 343 defer s.Close() 344 345 c, _ := newClient(&RunnerCredentials{ 346 URL: s.URL, 347 }) 348 349 res := struct { 350 Key string `json:"key"` 351 }{} 352 353 statusCode, statusText, _ := c.doJSON("with-charset", "GET", http.StatusOK, nil, &res) 354 assert.Equal(t, http.StatusOK, statusCode, statusText) 355 356 statusCode, statusText, _ = c.doJSON("without-charset", "GET", http.StatusOK, nil, &res) 357 assert.Equal(t, http.StatusOK, statusCode, statusText) 358 359 statusCode, statusText, _ = c.doJSON("without-json", "GET", http.StatusOK, nil, &res) 360 assert.Equal(t, -1, statusCode, statusText) 361 362 statusCode, statusText, _ = c.doJSON("invalid-header", "GET", http.StatusOK, nil, &res) 363 assert.Equal(t, -1, statusCode, statusText) 364 } 365 366 type backoffTestCase struct { 367 responseStatus int 368 mustBackoff bool 369 } 370 371 func tooManyRequestsHandler(w http.ResponseWriter, r *http.Request) { 372 status, err := strconv.Atoi(r.Header.Get("responseStatus")) 373 if err != nil { 374 w.WriteHeader(599) 375 } else { 376 w.WriteHeader(status) 377 } 378 } 379 380 func TestRequestsBackOff(t *testing.T) { 381 s := httptest.NewServer(http.HandlerFunc(tooManyRequestsHandler)) 382 defer s.Close() 383 384 c, _ := newClient(&RunnerCredentials{ 385 URL: s.URL, 386 }) 387 388 testCases := []backoffTestCase{ 389 {http.StatusCreated, false}, 390 {http.StatusInternalServerError, true}, 391 {http.StatusBadGateway, true}, 392 {http.StatusServiceUnavailable, true}, 393 {http.StatusOK, false}, 394 {http.StatusConflict, true}, 395 {http.StatusTooManyRequests, true}, 396 {http.StatusCreated, false}, 397 {http.StatusInternalServerError, true}, 398 {http.StatusTooManyRequests, true}, 399 {599, true}, 400 {499, true}, 401 } 402 403 backoff := c.ensureBackoff("POST", "") 404 for id, testCase := range testCases { 405 t.Run(fmt.Sprintf("%d-%d", id, testCase.responseStatus), func(t *testing.T) { 406 backoff.Reset() 407 assert.Zero(t, backoff.Attempt()) 408 409 var body io.Reader 410 headers := make(http.Header) 411 headers.Add("responseStatus", strconv.Itoa(testCase.responseStatus)) 412 413 res, err := c.do("/", "POST", body, "application/json", headers) 414 415 assert.NoError(t, err) 416 assert.Equal(t, testCase.responseStatus, res.StatusCode) 417 418 var expected float64 419 if testCase.mustBackoff { 420 expected = 1.0 421 } 422 assert.Equal(t, expected, backoff.Attempt()) 423 }) 424 } 425 }