github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/serviceregistration/checks/client_test.go (about) 1 package checks 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/hashicorp/nomad/ci" 15 "github.com/hashicorp/nomad/helper/freeport" 16 "github.com/hashicorp/nomad/helper/testlog" 17 "github.com/hashicorp/nomad/nomad/mock" 18 "github.com/hashicorp/nomad/nomad/structs" 19 "github.com/shoenig/test/must" 20 "golang.org/x/exp/maps" 21 "oss.indeed.com/go/libtime/libtimetest" 22 ) 23 24 func splitURL(u string) (string, string) { 25 // get the address and port for http server 26 tokens := strings.Split(u, ":") 27 addr, port := strings.TrimPrefix(tokens[1], "//"), tokens[2] 28 return addr, port 29 } 30 31 func TestChecker_Do_HTTP(t *testing.T) { 32 ci.Parallel(t) 33 34 // an example response that will be truncated 35 tooLong, truncate := bigResponse() 36 37 // create an http server with various responses 38 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 switch r.URL.Path { 40 case "/fail": 41 w.WriteHeader(500) 42 _, _ = io.WriteString(w, "500 problem") 43 case "/hang": 44 time.Sleep(1 * time.Second) 45 _, _ = io.WriteString(w, "too slow") 46 case "/long-fail": 47 w.WriteHeader(500) 48 _, _ = io.WriteString(w, tooLong) 49 case "/long-not-fail": 50 w.WriteHeader(201) 51 _, _ = io.WriteString(w, tooLong) 52 default: 53 w.WriteHeader(200) 54 _, _ = io.WriteString(w, "200 ok") 55 } 56 })) 57 defer ts.Close() 58 59 // get the address and port for http server 60 addr, port := splitURL(ts.URL) 61 62 // create a mock clock so we can assert time is set 63 now := time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC) 64 clock := libtimetest.NewClockMock(t).NowMock.Return(now) 65 66 makeQueryContext := func() *QueryContext { 67 return &QueryContext{ 68 ID: "abc123", 69 CustomAddress: addr, 70 ServicePortLabel: port, 71 Networks: nil, 72 NetworkStatus: mock.NewNetworkStatus(addr), 73 Ports: nil, 74 Group: "group", 75 Task: "task", 76 Service: "service", 77 Check: "check", 78 } 79 } 80 81 makeQuery := func( 82 kind structs.CheckMode, 83 path string, 84 ) *Query { 85 return &Query{ 86 Mode: kind, 87 Type: "http", 88 Timeout: 100 * time.Millisecond, 89 AddressMode: "auto", 90 PortLabel: port, 91 Protocol: "http", 92 Path: path, 93 Method: "GET", 94 } 95 } 96 97 makeExpResult := func( 98 kind structs.CheckMode, 99 status structs.CheckStatus, 100 code int, 101 output string, 102 ) *structs.CheckQueryResult { 103 return &structs.CheckQueryResult{ 104 ID: "abc123", 105 Mode: kind, 106 Status: status, 107 StatusCode: code, 108 Output: output, 109 Timestamp: now.Unix(), 110 Group: "group", 111 Task: "task", 112 Service: "service", 113 Check: "check", 114 } 115 } 116 117 cases := []struct { 118 name string 119 qc *QueryContext 120 q *Query 121 expResult *structs.CheckQueryResult 122 }{{ 123 name: "200 healthiness", 124 qc: makeQueryContext(), 125 q: makeQuery(structs.Healthiness, "/"), 126 expResult: makeExpResult( 127 structs.Healthiness, 128 structs.CheckSuccess, 129 http.StatusOK, 130 "nomad: http ok", 131 ), 132 }, { 133 name: "200 readiness", 134 qc: makeQueryContext(), 135 q: makeQuery(structs.Readiness, "/"), 136 expResult: makeExpResult( 137 structs.Readiness, 138 structs.CheckSuccess, 139 http.StatusOK, 140 "nomad: http ok", 141 ), 142 }, { 143 name: "500 healthiness", 144 qc: makeQueryContext(), 145 q: makeQuery(structs.Healthiness, "fail"), 146 expResult: makeExpResult( 147 structs.Healthiness, 148 structs.CheckFailure, 149 http.StatusInternalServerError, 150 "500 problem", 151 ), 152 }, { 153 name: "hang", 154 qc: makeQueryContext(), 155 q: makeQuery(structs.Healthiness, "hang"), 156 expResult: makeExpResult( 157 structs.Healthiness, 158 structs.CheckFailure, 159 0, 160 fmt.Sprintf(`nomad: Get "%s/hang": context deadline exceeded`, ts.URL), 161 ), 162 }, { 163 name: "500 truncate", 164 qc: makeQueryContext(), 165 q: makeQuery(structs.Healthiness, "long-fail"), 166 expResult: makeExpResult( 167 structs.Healthiness, 168 structs.CheckFailure, 169 http.StatusInternalServerError, 170 truncate, 171 ), 172 }, { 173 name: "201 truncate", 174 qc: makeQueryContext(), 175 q: makeQuery(structs.Healthiness, "long-not-fail"), 176 expResult: makeExpResult( 177 structs.Healthiness, 178 structs.CheckSuccess, 179 http.StatusCreated, 180 truncate, 181 ), 182 }} 183 184 for _, tc := range cases { 185 t.Run(tc.name, func(t *testing.T) { 186 logger := testlog.HCLogger(t) 187 188 c := New(logger) 189 c.(*checker).clock = clock 190 191 ctx := context.Background() 192 result := c.Do(ctx, tc.qc, tc.q) 193 must.Eq(t, tc.expResult, result) 194 }) 195 } 196 } 197 198 // bigResponse creates a response payload larger than the maximum outputSizeLimit 199 // as well as the same response but truncated to length of outputSizeLimit 200 func bigResponse() (string, string) { 201 size := outputSizeLimit + 5 202 b := make([]byte, size, size) 203 for i := 0; i < size; i++ { 204 b[i] = 'a' 205 } 206 s := string(b) 207 return s, s[:outputSizeLimit] 208 } 209 210 func TestChecker_Do_HTTP_extras(t *testing.T) { 211 ci.Parallel(t) 212 213 // record the method, body, and headers of the request 214 var ( 215 method string 216 body []byte 217 headers map[string][]string 218 host string 219 ) 220 221 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 222 method = r.Method 223 body, _ = io.ReadAll(r.Body) 224 headers = maps.Clone(r.Header) 225 host = r.Host 226 w.WriteHeader(http.StatusOK) 227 })) 228 defer ts.Close() 229 230 // get the address and port for http server 231 addr, port := splitURL(ts.URL) 232 233 // make headers from key-value pairs 234 makeHeaders := func(more ...[2]string) http.Header { 235 h := make(http.Header) 236 for _, extra := range more { 237 h.Set(extra[0], extra[1]) 238 } 239 return h 240 } 241 242 encoding := [2]string{"Accept-Encoding", "gzip"} 243 agent := [2]string{"User-Agent", "Go-http-client/1.1"} 244 245 cases := []struct { 246 name string 247 method string 248 body string 249 headers http.Header 250 }{ 251 { 252 name: "method GET", 253 method: "GET", 254 headers: makeHeaders(encoding, agent), 255 }, 256 { 257 name: "method Get", 258 method: "Get", 259 headers: makeHeaders(encoding, agent), 260 }, 261 { 262 name: "method HEAD", 263 method: "HEAD", 264 headers: makeHeaders(agent), 265 }, 266 { 267 name: "extra headers", 268 method: "GET", 269 headers: makeHeaders(encoding, agent, 270 [2]string{"X-My-Header", "hello"}, 271 [2]string{"Authorization", "Basic ZWxhc3RpYzpjaGFuZ2VtZQ=="}, 272 ), 273 }, 274 { 275 name: "host header", 276 method: "GET", 277 headers: makeHeaders(encoding, agent, 278 [2]string{"Host", "hello"}, 279 [2]string{"Test-Abc", "hello"}, 280 ), 281 }, 282 { 283 name: "host header without normalization", 284 method: "GET", 285 body: "", 286 // This is needed to prevent header normalization by http.Header.Set 287 headers: func() map[string][]string { 288 h := makeHeaders(encoding, agent, [2]string{"Test-Abc", "hello"}) 289 h["hoST"] = []string{"heLLO"} 290 return h 291 }(), 292 }, 293 { 294 name: "with body", 295 method: "POST", 296 headers: makeHeaders(encoding, agent), 297 body: "some payload", 298 }, 299 } 300 301 for _, tc := range cases { 302 qc := &QueryContext{ 303 ID: "abc123", 304 CustomAddress: addr, 305 ServicePortLabel: port, 306 Networks: nil, 307 NetworkStatus: mock.NewNetworkStatus(addr), 308 Ports: nil, 309 Group: "group", 310 Task: "task", 311 Service: "service", 312 Check: "check", 313 } 314 315 q := &Query{ 316 Mode: structs.Healthiness, 317 Type: "http", 318 Timeout: 1 * time.Second, 319 AddressMode: "auto", 320 PortLabel: port, 321 Protocol: "http", 322 Path: "/", 323 Method: tc.method, 324 Headers: tc.headers, 325 Body: tc.body, 326 } 327 328 t.Run(tc.name, func(t *testing.T) { 329 logger := testlog.HCLogger(t) 330 c := New(logger) 331 ctx := context.Background() 332 result := c.Do(ctx, qc, q) 333 must.Eq(t, http.StatusOK, result.StatusCode, 334 must.Sprintf("test.URL: %s", ts.URL), 335 must.Sprintf("headers: %v", tc.headers), 336 must.Sprintf("received headers: %v", tc.headers), 337 ) 338 must.Eq(t, tc.method, method) 339 must.Eq(t, tc.body, string(body)) 340 341 hostSent := false 342 343 for key, values := range tc.headers { 344 if strings.EqualFold(key, "Host") && len(values) > 0 { 345 must.Eq(t, values[0], host) 346 hostSent = true 347 delete(tc.headers, key) 348 349 } 350 } 351 if !hostSent { 352 must.Eq(t, nil, tc.headers["Host"]) 353 } 354 355 must.Eq(t, tc.headers, headers) 356 }) 357 } 358 } 359 360 func TestChecker_Do_TCP(t *testing.T) { 361 ci.Parallel(t) 362 363 // create a mock clock so we can assert time is set 364 now := time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC) 365 clock := libtimetest.NewClockMock(t).NowMock.Return(now) 366 367 makeQueryContext := func(address string, port int) *QueryContext { 368 return &QueryContext{ 369 ID: "abc123", 370 CustomAddress: address, 371 ServicePortLabel: fmt.Sprintf("%d", port), 372 Networks: nil, 373 NetworkStatus: mock.NewNetworkStatus(address), 374 Ports: nil, 375 Group: "group", 376 Task: "task", 377 Service: "service", 378 Check: "check", 379 } 380 } 381 382 makeQuery := func( 383 kind structs.CheckMode, 384 port int, 385 ) *Query { 386 return &Query{ 387 Mode: kind, 388 Type: "tcp", 389 Timeout: 100 * time.Millisecond, 390 AddressMode: "auto", 391 PortLabel: fmt.Sprintf("%d", port), 392 } 393 } 394 395 makeExpResult := func( 396 kind structs.CheckMode, 397 status structs.CheckStatus, 398 output string, 399 ) *structs.CheckQueryResult { 400 return &structs.CheckQueryResult{ 401 ID: "abc123", 402 Mode: kind, 403 Status: status, 404 Output: output, 405 Timestamp: now.Unix(), 406 Group: "group", 407 Task: "task", 408 Service: "service", 409 Check: "check", 410 } 411 } 412 413 ports := freeport.MustTake(3) 414 defer freeport.Return(ports) 415 416 cases := []struct { 417 name string 418 qc *QueryContext 419 q *Query 420 tcpMode string // "ok", "off", "hang" 421 tcpPort int 422 expResult *structs.CheckQueryResult 423 }{{ 424 name: "tcp ok", 425 qc: makeQueryContext("localhost", ports[0]), 426 q: makeQuery(structs.Healthiness, ports[0]), 427 tcpMode: "ok", 428 tcpPort: ports[0], 429 expResult: makeExpResult( 430 structs.Healthiness, 431 structs.CheckSuccess, 432 "nomad: tcp ok", 433 ), 434 }, { 435 name: "tcp not listening", 436 qc: makeQueryContext("127.0.0.1", ports[1]), 437 q: makeQuery(structs.Healthiness, ports[1]), 438 tcpMode: "off", 439 tcpPort: ports[1], 440 expResult: makeExpResult( 441 structs.Healthiness, 442 structs.CheckFailure, 443 fmt.Sprintf("dial tcp 127.0.0.1:%d: connect: connection refused", ports[1]), 444 ), 445 }, { 446 name: "tcp slow accept", 447 qc: makeQueryContext("localhost", ports[2]), 448 q: makeQuery(structs.Healthiness, ports[2]), 449 tcpMode: "hang", 450 tcpPort: ports[2], 451 expResult: makeExpResult( 452 structs.Healthiness, 453 structs.CheckFailure, 454 "dial tcp: lookup localhost: i/o timeout", 455 ), 456 }} 457 458 for _, tc := range cases { 459 t.Run(tc.name, func(t *testing.T) { 460 logger := testlog.HCLogger(t) 461 462 ctx, cancel := context.WithCancel(context.Background()) 463 defer cancel() 464 465 c := New(logger) 466 c.(*checker).clock = clock 467 468 switch tc.tcpMode { 469 case "ok": 470 // simulate tcp server by listening 471 tcpServer(t, ctx, tc.tcpPort) 472 case "hang": 473 // simulate tcp hang by setting an already expired context 474 timeout, stop := context.WithDeadline(ctx, now.Add(-1*time.Second)) 475 defer stop() 476 ctx = timeout 477 case "off": 478 // simulate tcp dead connection by not listening 479 } 480 481 result := c.Do(ctx, tc.qc, tc.q) 482 must.Eq(t, tc.expResult, result) 483 }) 484 } 485 } 486 487 // tcpServer will start a tcp listener that accepts connections and closes them. 488 // The caller can close the listener by cancelling ctx. 489 func tcpServer(t *testing.T, ctx context.Context, port int) { 490 var lc net.ListenConfig 491 l, err := lc.Listen(ctx, "tcp", net.JoinHostPort( 492 "localhost", fmt.Sprintf("%d", port), 493 )) 494 must.NoError(t, err, must.Sprint("port", port)) 495 t.Cleanup(func() { 496 _ = l.Close() 497 }) 498 499 go func() { 500 // caller can stop us by cancelling ctx 501 for { 502 _, acceptErr := l.Accept() 503 if acceptErr != nil { 504 return 505 } 506 } 507 }() 508 }