github.com/nilium/gitlab-runner@v12.5.0+incompatible/network/client_test.go (about) 1 package network 2 3 import ( 4 "crypto/rsa" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/pem" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net" 13 "net/http" 14 "net/http/httptest" 15 "net/url" 16 "os" 17 "path/filepath" 18 "strconv" 19 "testing" 20 21 "github.com/sirupsen/logrus" 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 25 . "gitlab.com/gitlab-org/gitlab-runner/common" 26 ) 27 28 func clientHandler(w http.ResponseWriter, r *http.Request) { 29 body, _ := ioutil.ReadAll(r.Body) 30 logrus.Debugln(r.Method, r.URL.String(), 31 "Content-Type:", r.Header.Get("Content-Type"), 32 "Accept:", r.Header.Get("Accept"), 33 "Body:", string(body)) 34 35 switch r.URL.Path { 36 case "/api/v4/test/ok": 37 case "/api/v4/test/auth": 38 w.WriteHeader(http.StatusForbidden) 39 case "/api/v4/test/json": 40 if r.Header.Get("Content-Type") != "application/json" { 41 w.WriteHeader(http.StatusBadRequest) 42 } else if r.Header.Get("Accept") != "application/json" { 43 w.WriteHeader(http.StatusNotAcceptable) 44 } else { 45 w.Header().Set("Content-Type", "application/json") 46 fmt.Fprint(w, "{\"key\":\"value\"}") 47 } 48 default: 49 w.WriteHeader(http.StatusNotFound) 50 } 51 } 52 53 func writeTLSCertificate(s *httptest.Server, file string) error { 54 c := s.TLS.Certificates[0] 55 if c.Certificate == nil || c.Certificate[0] == nil { 56 return errors.New("no predefined certificate") 57 } 58 59 encoded := pem.EncodeToMemory(&pem.Block{ 60 Type: "CERTIFICATE", 61 Bytes: c.Certificate[0], 62 }) 63 64 return ioutil.WriteFile(file, encoded, 0600) 65 } 66 67 func writeTLSKeyPair(s *httptest.Server, certFile string, keyFile string) error { 68 c := s.TLS.Certificates[0] 69 if c.Certificate == nil || c.Certificate[0] == nil { 70 return errors.New("no predefined certificate") 71 } 72 73 encodedCert := pem.EncodeToMemory(&pem.Block{ 74 Type: "CERTIFICATE", 75 Bytes: c.Certificate[0], 76 }) 77 78 if err := ioutil.WriteFile(certFile, encodedCert, 0600); err != nil { 79 return err 80 } 81 82 switch k := c.PrivateKey.(type) { 83 case *rsa.PrivateKey: 84 encodedKey := pem.EncodeToMemory(&pem.Block{ 85 Type: "RSA PRIVATE KEY", 86 Bytes: x509.MarshalPKCS1PrivateKey(k), 87 }) 88 return ioutil.WriteFile(keyFile, encodedKey, 0600) 89 default: 90 return errors.New("unexpected private key type") 91 } 92 } 93 94 func TestNewClient(t *testing.T) { 95 c, err := newClient(&RunnerCredentials{ 96 URL: "http://test.example.com/ci///", 97 }) 98 assert.NoError(t, err) 99 assert.NotNil(t, c) 100 assert.Equal(t, "http://test.example.com/api/v4/", c.url.String()) 101 } 102 103 func TestInvalidUrl(t *testing.T) { 104 _, err := newClient(&RunnerCredentials{ 105 URL: "address.com/ci///", 106 }) 107 assert.Error(t, err) 108 } 109 110 func TestClientDo(t *testing.T) { 111 s := httptest.NewServer(http.HandlerFunc(clientHandler)) 112 defer s.Close() 113 114 c, err := newClient(&RunnerCredentials{ 115 URL: s.URL, 116 }) 117 assert.NoError(t, err) 118 assert.NotNil(t, c) 119 120 statusCode, statusText, _ := c.doJSON("test/auth", "GET", http.StatusOK, nil, nil) 121 assert.Equal(t, http.StatusForbidden, statusCode, statusText) 122 123 req := struct { 124 Query bool `json:"query"` 125 }{ 126 true, 127 } 128 129 res := struct { 130 Key string `json:"key"` 131 }{} 132 133 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, &res) 134 assert.Equal(t, http.StatusBadRequest, statusCode, statusText) 135 136 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, nil) 137 assert.Equal(t, http.StatusNotAcceptable, statusCode, statusText) 138 139 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, nil, nil) 140 assert.Equal(t, http.StatusBadRequest, statusCode, statusText) 141 142 statusCode, statusText, _ = c.doJSON("test/json", "GET", http.StatusOK, &req, &res) 143 assert.Equal(t, http.StatusOK, statusCode, statusText) 144 assert.Equal(t, "value", res.Key, statusText) 145 } 146 147 func TestClientInvalidSSL(t *testing.T) { 148 s := httptest.NewTLSServer(http.HandlerFunc(clientHandler)) 149 defer s.Close() 150 151 c, _ := newClient(&RunnerCredentials{ 152 URL: s.URL, 153 }) 154 statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 155 assert.Equal(t, -1, statusCode, statusText) 156 assert.Contains(t, statusText, "certificate signed by unknown authority") 157 } 158 159 func TestClientTLSCAFile(t *testing.T) { 160 s := httptest.NewTLSServer(http.HandlerFunc(clientHandler)) 161 defer s.Close() 162 163 file, err := ioutil.TempFile("", "cert_") 164 assert.NoError(t, err) 165 file.Close() 166 defer os.Remove(file.Name()) 167 168 err = writeTLSCertificate(s, file.Name()) 169 assert.NoError(t, err) 170 171 c, _ := newClient(&RunnerCredentials{ 172 URL: s.URL, 173 TLSCAFile: file.Name(), 174 }) 175 statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 176 assert.Equal(t, http.StatusOK, statusCode, statusText) 177 178 tlsData, err := c.getResponseTLSData(resp.TLS) 179 assert.NoError(t, err) 180 assert.NotEmpty(t, tlsData.CAChain) 181 } 182 183 func TestClientCertificateInPredefinedDirectory(t *testing.T) { 184 s := httptest.NewTLSServer(http.HandlerFunc(clientHandler)) 185 defer s.Close() 186 187 serverURL, err := url.Parse(s.URL) 188 require.NoError(t, err) 189 hostname, _, err := net.SplitHostPort(serverURL.Host) 190 require.NoError(t, err) 191 192 tempDir, err := ioutil.TempDir("", "certs") 193 assert.NoError(t, err) 194 defer os.RemoveAll(tempDir) 195 CertificateDirectory = tempDir 196 197 err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt")) 198 assert.NoError(t, err) 199 200 c, _ := newClient(&RunnerCredentials{ 201 URL: s.URL, 202 }) 203 statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 204 assert.Equal(t, http.StatusOK, statusCode, statusText) 205 206 tlsData, err := c.getResponseTLSData(resp.TLS) 207 assert.NoError(t, err) 208 assert.NotEmpty(t, tlsData.CAChain) 209 } 210 211 func TestClientInvalidTLSAuth(t *testing.T) { 212 s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler)) 213 s.TLS = new(tls.Config) 214 s.TLS.ClientAuth = tls.RequireAnyClientCert 215 s.StartTLS() 216 defer s.Close() 217 218 ca, err := ioutil.TempFile("", "cert_") 219 assert.NoError(t, err) 220 ca.Close() 221 defer os.Remove(ca.Name()) 222 223 err = writeTLSCertificate(s, ca.Name()) 224 assert.NoError(t, err) 225 226 c, _ := newClient(&RunnerCredentials{ 227 URL: s.URL, 228 TLSCAFile: ca.Name(), 229 }) 230 statusCode, statusText, _ := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 231 assert.Equal(t, -1, statusCode, statusText) 232 assert.Contains(t, statusText, "tls: bad certificate") 233 } 234 235 func TestClientTLSAuth(t *testing.T) { 236 s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler)) 237 s.TLS = new(tls.Config) 238 s.TLS.ClientAuth = tls.RequireAnyClientCert 239 s.StartTLS() 240 defer s.Close() 241 242 ca, err := ioutil.TempFile("", "cert_") 243 assert.NoError(t, err) 244 ca.Close() 245 defer os.Remove(ca.Name()) 246 247 err = writeTLSCertificate(s, ca.Name()) 248 assert.NoError(t, err) 249 250 cert, err := ioutil.TempFile("", "cert_") 251 assert.NoError(t, err) 252 cert.Close() 253 defer os.Remove(cert.Name()) 254 255 key, err := ioutil.TempFile("", "key_") 256 assert.NoError(t, err) 257 key.Close() 258 defer os.Remove(key.Name()) 259 260 err = writeTLSKeyPair(s, cert.Name(), key.Name()) 261 assert.NoError(t, err) 262 263 c, _ := newClient(&RunnerCredentials{ 264 URL: s.URL, 265 TLSCAFile: ca.Name(), 266 TLSCertFile: cert.Name(), 267 TLSKeyFile: key.Name(), 268 }) 269 270 statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 271 assert.Equal(t, http.StatusOK, statusCode, statusText) 272 273 tlsData, err := c.getResponseTLSData(resp.TLS) 274 assert.NoError(t, err) 275 assert.NotEmpty(t, tlsData.CAChain) 276 assert.Equal(t, cert.Name(), tlsData.CertFile) 277 assert.Equal(t, key.Name(), tlsData.KeyFile) 278 } 279 280 func TestClientTLSAuthCertificatesInPredefinedDirectory(t *testing.T) { 281 s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler)) 282 s.TLS = new(tls.Config) 283 s.TLS.ClientAuth = tls.RequireAnyClientCert 284 s.StartTLS() 285 defer s.Close() 286 287 tempDir, err := ioutil.TempDir("", "certs") 288 assert.NoError(t, err) 289 defer os.RemoveAll(tempDir) 290 CertificateDirectory = tempDir 291 292 serverURL, err := url.Parse(s.URL) 293 require.NoError(t, err) 294 hostname, _, err := net.SplitHostPort(serverURL.Host) 295 require.NoError(t, err) 296 297 err = writeTLSCertificate(s, filepath.Join(tempDir, hostname+".crt")) 298 assert.NoError(t, err) 299 300 err = writeTLSKeyPair(s, 301 filepath.Join(tempDir, hostname+".auth.crt"), 302 filepath.Join(tempDir, hostname+".auth.key")) 303 assert.NoError(t, err) 304 305 c, _ := newClient(&RunnerCredentials{ 306 URL: s.URL, 307 }) 308 statusCode, statusText, resp := c.doJSON("test/ok", "GET", http.StatusOK, nil, nil) 309 assert.Equal(t, http.StatusOK, statusCode, statusText) 310 311 tlsData, err := c.getResponseTLSData(resp.TLS) 312 assert.NoError(t, err) 313 assert.NotEmpty(t, tlsData.CAChain) 314 assert.NotEmpty(t, tlsData.CertFile) 315 assert.NotEmpty(t, tlsData.KeyFile) 316 } 317 318 func TestUrlFixing(t *testing.T) { 319 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci///")) 320 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci/")) 321 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/ci")) 322 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com/")) 323 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com///")) 324 assert.Equal(t, "https://gitlab.example.com", fixCIURL("https://gitlab.example.com")) 325 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci/")) 326 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci///")) 327 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/ci")) 328 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab/")) 329 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab///")) 330 assert.Equal(t, "https://example.com/gitlab", fixCIURL("https://example.com/gitlab")) 331 } 332 333 func charsetTestClientHandler(w http.ResponseWriter, r *http.Request) { 334 switch r.URL.Path { 335 case "/api/v4/with-charset": 336 w.Header().Set("Content-Type", "application/json; charset=utf-8") 337 w.WriteHeader(http.StatusOK) 338 fmt.Fprint(w, "{\"key\":\"value\"}") 339 case "/api/v4/without-charset": 340 w.Header().Set("Content-Type", "application/json") 341 w.WriteHeader(http.StatusOK) 342 fmt.Fprint(w, "{\"key\":\"value\"}") 343 case "/api/v4/without-json": 344 w.Header().Set("Content-Type", "application/octet-stream") 345 w.WriteHeader(http.StatusOK) 346 fmt.Fprint(w, "{\"key\":\"value\"}") 347 case "/api/v4/invalid-header": 348 w.Header().Set("Content-Type", "application/octet-stream, test, a=b") 349 w.WriteHeader(http.StatusOK) 350 fmt.Fprint(w, "{\"key\":\"value\"}") 351 } 352 } 353 354 func TestClientHandleCharsetInContentType(t *testing.T) { 355 s := httptest.NewServer(http.HandlerFunc(charsetTestClientHandler)) 356 defer s.Close() 357 358 c, _ := newClient(&RunnerCredentials{ 359 URL: s.URL, 360 }) 361 362 res := struct { 363 Key string `json:"key"` 364 }{} 365 366 statusCode, statusText, _ := c.doJSON("with-charset", "GET", http.StatusOK, nil, &res) 367 assert.Equal(t, http.StatusOK, statusCode, statusText) 368 369 statusCode, statusText, _ = c.doJSON("without-charset", "GET", http.StatusOK, nil, &res) 370 assert.Equal(t, http.StatusOK, statusCode, statusText) 371 372 statusCode, statusText, _ = c.doJSON("without-json", "GET", http.StatusOK, nil, &res) 373 assert.Equal(t, -1, statusCode, statusText) 374 375 statusCode, statusText, _ = c.doJSON("invalid-header", "GET", http.StatusOK, nil, &res) 376 assert.Equal(t, -1, statusCode, statusText) 377 } 378 379 type backoffTestCase struct { 380 responseStatus int 381 mustBackoff bool 382 } 383 384 func tooManyRequestsHandler(w http.ResponseWriter, r *http.Request) { 385 status, err := strconv.Atoi(r.Header.Get("responseStatus")) 386 if err != nil { 387 w.WriteHeader(599) 388 } else { 389 w.WriteHeader(status) 390 } 391 } 392 393 func TestRequestsBackOff(t *testing.T) { 394 s := httptest.NewServer(http.HandlerFunc(tooManyRequestsHandler)) 395 defer s.Close() 396 397 c, _ := newClient(&RunnerCredentials{ 398 URL: s.URL, 399 }) 400 401 testCases := []backoffTestCase{ 402 {http.StatusCreated, false}, 403 {http.StatusInternalServerError, true}, 404 {http.StatusBadGateway, true}, 405 {http.StatusServiceUnavailable, true}, 406 {http.StatusOK, false}, 407 {http.StatusConflict, true}, 408 {http.StatusTooManyRequests, true}, 409 {http.StatusCreated, false}, 410 {http.StatusInternalServerError, true}, 411 {http.StatusTooManyRequests, true}, 412 {599, true}, 413 {499, true}, 414 } 415 416 backoff := c.ensureBackoff("POST", "") 417 for id, testCase := range testCases { 418 t.Run(fmt.Sprintf("%d-%d", id, testCase.responseStatus), func(t *testing.T) { 419 backoff.Reset() 420 assert.Zero(t, backoff.Attempt()) 421 422 var body io.Reader 423 headers := make(http.Header) 424 headers.Add("responseStatus", strconv.Itoa(testCase.responseStatus)) 425 426 res, err := c.do("/", "POST", body, "application/json", headers) 427 428 assert.NoError(t, err) 429 assert.Equal(t, testCase.responseStatus, res.StatusCode) 430 431 var expected float64 432 if testCase.mustBackoff { 433 expected = 1.0 434 } 435 assert.Equal(t, expected, backoff.Attempt()) 436 }) 437 } 438 }