github.com/xmidt-org/webpa-common@v1.11.9/xhttp/fanout/handler_test.go (about) 1 package fanout 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 "time" 13 14 gokithttp "github.com/go-kit/kit/transport/http" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/mock" 17 "github.com/stretchr/testify/require" 18 "github.com/xmidt-org/webpa-common/logging" 19 "github.com/xmidt-org/webpa-common/xhttp" 20 "github.com/xmidt-org/webpa-common/xhttp/xhttptest" 21 ) 22 23 func testHandlerBodyError(t *testing.T) { 24 var ( 25 assert = assert.New(t) 26 require = require.New(t) 27 28 expectedError = &xhttp.Error{Code: 599, Text: "body read error"} 29 body = new(xhttptest.MockBody) 30 logger = logging.NewTestLogger(nil, t) 31 ctx = logging.WithLogger(context.Background(), logger) 32 original = httptest.NewRequest("POST", "/something", body).WithContext(ctx) 33 response = httptest.NewRecorder() 34 35 handler = New(FixedEndpoints{}) 36 ) 37 38 require.NotNil(handler) 39 body.OnReadError(expectedError).Once() 40 41 handler.ServeHTTP(response, original) 42 assert.Equal(599, response.Code) 43 44 body.AssertExpectations(t) 45 } 46 47 func testHandlerNoEndpoints(t *testing.T) { 48 var ( 49 assert = assert.New(t) 50 require = require.New(t) 51 52 body = new(xhttptest.MockBody) 53 logger = logging.NewTestLogger(nil, t) 54 ctx = logging.WithLogger(context.Background(), logger) 55 original = httptest.NewRequest("POST", "/something", body).WithContext(ctx) 56 response = httptest.NewRecorder() 57 58 handler = New(FixedEndpoints{}, WithErrorEncoder(func(_ context.Context, err error, response http.ResponseWriter) { 59 response.WriteHeader(599) 60 })) 61 ) 62 63 require.NotNil(handler) 64 body.OnReadError(io.EOF).Once() 65 66 handler.ServeHTTP(response, original) 67 assert.Equal(599, response.Code) 68 69 body.AssertExpectations(t) 70 } 71 72 func testHandlerEndpointsError(t *testing.T) { 73 var ( 74 assert = assert.New(t) 75 require = require.New(t) 76 77 expectedError = errors.New("endpoints error") 78 body = new(xhttptest.MockBody) 79 endpoints = new(mockEndpoints) 80 81 logger = logging.NewTestLogger(nil, t) 82 ctx = logging.WithLogger(context.Background(), logger) 83 original = httptest.NewRequest("POST", "/something", body).WithContext(ctx) 84 response = httptest.NewRecorder() 85 86 handler = New(endpoints, WithErrorEncoder(func(_ context.Context, err error, response http.ResponseWriter) { 87 response.WriteHeader(599) 88 })) 89 ) 90 91 require.NotNil(handler) 92 body.OnReadError(io.EOF).Once() 93 endpoints.On("FanoutURLs", original).Once().Return(nil, expectedError) 94 95 handler.ServeHTTP(response, original) 96 assert.Equal(599, response.Code) 97 98 body.AssertExpectations(t) 99 } 100 101 func testHandlerBadTransactor(t *testing.T) { 102 var ( 103 assert = assert.New(t) 104 require = require.New(t) 105 106 logger = logging.NewTestLogger(nil, t) 107 ctx = logging.WithLogger(context.Background(), logger) 108 original = httptest.NewRequest("GET", "/api/v2/something", nil).WithContext(ctx) 109 response = httptest.NewRecorder() 110 111 endpoints = generateEndpoints(1) 112 transactor = new(xhttptest.MockTransactor) 113 complete = make(chan struct{}, 1) 114 handler = New(endpoints, WithTransactor(transactor.Do)) 115 ) 116 117 require.NotNil(handler) 118 transactor.OnDo( 119 xhttptest.MatchMethod("GET"), 120 xhttptest.MatchURLString(endpoints[0].String()+"/api/v2/something"), 121 ).Respond(nil, nil).Once().Run(func(mock.Arguments) { complete <- struct{}{} }) 122 123 handler.ServeHTTP(response, original) 124 assert.Equal(http.StatusServiceUnavailable, response.Code) 125 126 select { 127 case <-complete: 128 // passing 129 case <-time.After(5 * time.Second): 130 assert.Fail("Not all transactors completed") 131 } 132 133 transactor.AssertExpectations(t) 134 } 135 136 func testHandlerGet(t *testing.T, expectedResponses []xhttptest.ExpectedResponse, expectedStatusCode int, expectedResponseBody string, expectAfter bool, expectedFailedCalled bool) { 137 var ( 138 assert = assert.New(t) 139 require = require.New(t) 140 141 logger = logging.NewTestLogger(nil, t) 142 ctx = logging.WithLogger(context.Background(), logger) 143 original = httptest.NewRequest("GET", "/api/v2/something", nil).WithContext(ctx) 144 response = httptest.NewRecorder() 145 146 fanoutAfterCalled = false 147 fanoutAfter = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context { 148 assert.False(fanoutAfterCalled) 149 fanoutAfterCalled = true 150 assert.Equal(ctx, actualCtx) 151 assert.Equal(response, actualResponse) 152 if assert.NotNil(result.Response) { 153 assert.Equal(expectedStatusCode, result.Response.StatusCode) 154 } 155 156 return actualCtx 157 } 158 159 clientAfterCalled = false 160 clientAfter = func(actualCtx context.Context, actualResponse *http.Response) context.Context { 161 assert.False(clientAfterCalled) 162 clientAfterCalled = true 163 assert.Equal(ctx, actualCtx) 164 assert.Equal(expectedStatusCode, actualResponse.StatusCode) 165 return actualCtx 166 } 167 168 fanoutFailedCalled = false 169 fanoutFail = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context { 170 assert.False(fanoutFailedCalled) 171 fanoutFailedCalled = true 172 assert.Equal(ctx, actualCtx) 173 return ctx 174 } 175 176 endpoints = generateEndpoints(len(expectedResponses)) 177 transactor = new(xhttptest.MockTransactor) 178 complete = make(chan struct{}, len(expectedResponses)) 179 180 handler = New(endpoints, 181 WithTransactor(transactor.Do), 182 WithClientBefore(gokithttp.SetRequestHeader("X-Test", "foobar")), 183 WithFanoutAfter(fanoutAfter), 184 WithClientAfter(clientAfter), 185 WithFanoutFailure(fanoutFail), 186 ) 187 ) 188 189 require.NotNil(handler) 190 for i, er := range expectedResponses { 191 transactor.OnDo( 192 xhttptest.MatchMethod("GET"), 193 xhttptest.MatchURLString(endpoints[i].String()+"/api/v2/something"), 194 xhttptest.MatchHeader("X-Test", "foobar"), 195 ).RespondWith(er).Once().Run(func(mock.Arguments) { complete <- struct{}{} }) 196 } 197 198 handler.ServeHTTP(response, original) 199 assert.Equal(expectedStatusCode, response.Code) 200 201 after := time.After(5 * time.Second) 202 for i := 0; i < len(expectedResponses); i++ { 203 select { 204 case <-complete: 205 // passing 206 case <-after: 207 assert.Fail("Not all transactors completed") 208 i = len(expectedResponses) 209 } 210 } 211 212 assert.Equal(expectAfter, clientAfterCalled) 213 assert.Equal(expectedFailedCalled, fanoutFailedCalled) 214 transactor.AssertExpectations(t) 215 } 216 217 func testHandlerPost(t *testing.T, expectedResponses []xhttptest.ExpectedResponse, expectedStatusCode int, expectedResponseBody string, expectAfter bool, expectedFailedCalled bool) { 218 var ( 219 assert = assert.New(t) 220 require = require.New(t) 221 222 logger = logging.NewTestLogger(nil, t) 223 ctx = logging.WithLogger(context.Background(), logger) 224 expectedRequestBody = "posted body" 225 original = httptest.NewRequest("POST", "/api/v2/something", strings.NewReader(expectedRequestBody)).WithContext(ctx) 226 response = httptest.NewRecorder() 227 228 fanoutAfterCalled = false 229 fanoutAfter = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context { 230 assert.False(fanoutAfterCalled) 231 fanoutAfterCalled = true 232 assert.Equal(ctx, actualCtx) 233 assert.Equal(response, actualResponse) 234 if assert.NotNil(result.Response) { 235 assert.Equal(expectedStatusCode, result.Response.StatusCode) 236 } 237 238 return actualCtx 239 } 240 241 clientAfterCalled = false 242 clientAfter = func(actualCtx context.Context, actualResponse *http.Response) context.Context { 243 assert.False(clientAfterCalled) 244 clientAfterCalled = true 245 assert.Equal(ctx, actualCtx) 246 assert.Equal(expectedStatusCode, actualResponse.StatusCode) 247 return actualCtx 248 } 249 fanoutFailedCalled = false 250 fanoutFail = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context { 251 assert.False(fanoutFailedCalled) 252 fanoutFailedCalled = true 253 assert.Equal(ctx, actualCtx) 254 return ctx 255 } 256 257 endpoints = generateEndpoints(len(expectedResponses)) 258 transactor = new(xhttptest.MockTransactor) 259 complete = make(chan struct{}, len(expectedResponses)) 260 handler = New(endpoints, 261 WithTransactor(transactor.Do), 262 WithFanoutBefore(ForwardBody(true)), 263 WithClientBefore(gokithttp.SetRequestHeader("X-Test", "foobar")), 264 WithFanoutAfter(fanoutAfter), 265 WithClientAfter(clientAfter), 266 WithFanoutFailure(fanoutFail), 267 ) 268 ) 269 270 require.NotNil(handler) 271 for i, er := range expectedResponses { 272 transactor.OnDo( 273 xhttptest.MatchMethod("POST"), 274 xhttptest.MatchURLString(endpoints[i].String()+"/api/v2/something"), 275 xhttptest.MatchHeader("X-Test", "foobar"), 276 xhttptest.MatchBodyString(expectedRequestBody), 277 ).RespondWith(er).Once().Run(func(mock.Arguments) { complete <- struct{}{} }) 278 } 279 280 handler.ServeHTTP(response, original) 281 assert.Equal(expectedStatusCode, response.Code) 282 assert.Equal(expectedResponseBody, response.Body.String()) 283 assert.Equal(expectAfter, clientAfterCalled) 284 assert.Equal(expectedFailedCalled, fanoutFailedCalled) 285 286 after := time.After(2 * time.Second) 287 for i := 0; i < len(expectedResponses); i++ { 288 select { 289 case <-complete: 290 // passing 291 case <-after: 292 assert.Fail("Not all transactors completed") 293 i = len(expectedResponses) 294 } 295 } 296 297 transactor.AssertExpectations(t) 298 } 299 300 func testHandlerTimeout(t *testing.T, endpointCount int) { 301 var ( 302 assert = assert.New(t) 303 require = require.New(t) 304 305 logger = logging.NewTestLogger(nil, t) 306 ctx, cancel = context.WithCancel(logging.WithLogger(context.Background(), logger)) 307 original = httptest.NewRequest("GET", "/api/v2/something", nil).WithContext(ctx) 308 response = httptest.NewRecorder() 309 310 endpoints = generateEndpoints(endpointCount) 311 transactor = new(xhttptest.MockTransactor) 312 transactorWait = make(chan time.Time) 313 complete = make(chan struct{}, endpointCount) 314 handlerWait = make(chan struct{}) 315 handler = New(endpoints, 316 WithTransactor(transactor.Do), 317 ) 318 ) 319 320 require.NotNil(handler) 321 for i := 0; i < endpointCount; i++ { 322 transactor.OnDo( 323 xhttptest.MatchMethod("GET"), 324 xhttptest.MatchURLString(endpoints[i].String()+"/api/v2/something"), 325 ).Respond(nil, nil).Once().WaitUntil(transactorWait).Run(func(mock.Arguments) { complete <- struct{}{} }) 326 } 327 328 go func() { 329 defer close(handlerWait) 330 handler.ServeHTTP(response, original) 331 }() 332 333 // simulate a context timeout 334 cancel() 335 select { 336 case <-handlerWait: 337 assert.Equal(http.StatusGatewayTimeout, response.Code) 338 case <-time.After(2 * time.Second): 339 assert.Fail("ServeHTTP did not return") 340 } 341 342 close(transactorWait) 343 after := time.After(2 * time.Second) 344 for i := 0; i < endpointCount; i++ { 345 select { 346 case <-complete: 347 // passing 348 case <-after: 349 assert.Fail("Not all transactors completed") 350 i = endpointCount 351 } 352 } 353 354 transactor.AssertExpectations(t) 355 } 356 357 func TestHandler(t *testing.T) { 358 t.Run("BodyError", testHandlerBodyError) 359 t.Run("NoEndpoints", testHandlerNoEndpoints) 360 t.Run("EndpointsError", testHandlerEndpointsError) 361 t.Run("BadTransactor", testHandlerBadTransactor) 362 363 t.Run("Fanout", func(t *testing.T) { 364 testData := []struct { 365 statusCodes []xhttptest.ExpectedResponse 366 expectedStatusCode int 367 expectedResponseBody string 368 expectAfter bool 369 expectedFailedCalled bool 370 }{ 371 { 372 []xhttptest.ExpectedResponse{ 373 {StatusCode: 504}, 374 }, 375 504, 376 "", 377 false, 378 true, 379 }, 380 { 381 []xhttptest.ExpectedResponse{ 382 {StatusCode: 500}, {StatusCode: 501}, {StatusCode: 502}, {StatusCode: 503}, {StatusCode: 504}, 383 }, 384 504, 385 "", 386 false, 387 true, 388 }, 389 { 390 []xhttptest.ExpectedResponse{ 391 {StatusCode: 504}, {StatusCode: 503}, {StatusCode: 502}, {StatusCode: 501}, {StatusCode: 500}, 392 }, 393 504, 394 "", 395 false, 396 true, 397 }, 398 { 399 []xhttptest.ExpectedResponse{ 400 {Err: errors.New("expected")}, 401 }, 402 http.StatusServiceUnavailable, 403 "expected", 404 false, 405 true, 406 }, 407 { 408 []xhttptest.ExpectedResponse{ 409 {StatusCode: 500}, {Err: errors.New("expected")}, 410 }, 411 http.StatusServiceUnavailable, 412 "expected", 413 false, 414 true, 415 }, 416 { 417 []xhttptest.ExpectedResponse{ 418 {StatusCode: 599}, {Err: errors.New("expected")}, 419 }, 420 599, 421 "", 422 false, 423 true, 424 }, 425 { 426 []xhttptest.ExpectedResponse{ 427 {StatusCode: 200, Body: []byte("expected body")}, 428 }, 429 200, 430 "expected body", 431 true, 432 false, 433 }, 434 { 435 []xhttptest.ExpectedResponse{ 436 {StatusCode: 404}, {StatusCode: 200, Body: []byte("expected body")}, {StatusCode: 503}, 437 }, 438 200, 439 "expected body", 440 true, 441 false, 442 }, 443 } 444 445 t.Run("GET", func(t *testing.T) { 446 for _, record := range testData { 447 testHandlerGet(t, record.statusCodes, record.expectedStatusCode, record.expectedResponseBody, record.expectAfter, record.expectedFailedCalled) 448 } 449 }) 450 451 t.Run("POST", func(t *testing.T) { 452 for _, record := range testData { 453 testHandlerPost(t, record.statusCodes, record.expectedStatusCode, record.expectedResponseBody, record.expectAfter, record.expectedFailedCalled) 454 } 455 }) 456 }) 457 458 t.Run("Timeout", func(t *testing.T) { 459 for _, endpointCount := range []int{1, 2, 3, 5} { 460 t.Run(fmt.Sprintf("EndpointCount=%d", endpointCount), func(t *testing.T) { 461 testHandlerTimeout(t, endpointCount) 462 }) 463 } 464 }) 465 } 466 467 func testNewNilEndpoints(t *testing.T) { 468 assert := assert.New(t) 469 assert.Panics(func() { 470 New(nil) 471 }) 472 } 473 474 func testNewNilConfiguration(t *testing.T) { 475 var ( 476 assert = assert.New(t) 477 require = require.New(t) 478 479 handler = New(FixedEndpoints{}, 480 WithShouldTerminate(nil), 481 WithErrorEncoder(nil), 482 WithTransactor(nil), 483 WithFanoutBefore(), 484 WithClientBefore(), 485 WithFanoutAfter(), 486 WithFanoutFailure(), 487 WithClientFailure(), 488 ) 489 ) 490 491 require.NotNil(handler) 492 assert.NotNil(handler.shouldTerminate) 493 assert.NotNil(handler.errorEncoder) 494 assert.NotNil(handler.transactor) 495 assert.Empty(handler.before) 496 assert.Empty(handler.after) 497 assert.Empty(handler.failure) 498 } 499 500 func testNewNoConfiguration(t *testing.T) { 501 var ( 502 assert = assert.New(t) 503 require = require.New(t) 504 505 handler = New(FixedEndpoints{}) 506 ) 507 508 require.NotNil(handler) 509 assert.NotNil(handler.shouldTerminate) 510 assert.NotNil(handler.errorEncoder) 511 assert.NotNil(handler.transactor) 512 assert.Empty(handler.before) 513 assert.Empty(handler.after) 514 } 515 516 func testNewShouldTerminate(t *testing.T) { 517 var ( 518 assert = assert.New(t) 519 require = require.New(t) 520 521 shouldTerminateCalled = false 522 shouldTerminate = func(Result) bool { 523 assert.False(shouldTerminateCalled) 524 shouldTerminateCalled = true 525 return true 526 } 527 528 handler = New(FixedEndpoints{}, WithShouldTerminate(shouldTerminate)) 529 ) 530 531 require.NotNil(handler) 532 assert.True(handler.shouldTerminate(Result{})) 533 assert.True(shouldTerminateCalled) 534 } 535 536 func testNewWithInjectedConfiguration(t *testing.T) { 537 var ( 538 assert = assert.New(t) 539 require = require.New(t) 540 541 expectedEndpoints = MustParseURLs("http://foobar.com:8080") 542 543 handler = New( 544 expectedEndpoints, 545 WithConfiguration(Configuration{ 546 Endpoints: []string{"localhost:1234"}, 547 Authorization: "deadbeef", 548 }), 549 ) 550 ) 551 552 require.NotNil(handler) 553 assert.NotNil(handler.transactor) 554 assert.Len(handler.before, 1) 555 assert.Equal(expectedEndpoints, handler.endpoints) 556 } 557 558 func TestNew(t *testing.T) { 559 t.Run("NilEndpoints", testNewNilEndpoints) 560 t.Run("NilConfiguration", testNewNilConfiguration) 561 t.Run("NoConfiguration", testNewNoConfiguration) 562 t.Run("ShouldTerminate", testNewShouldTerminate) 563 t.Run("WithInjectedConfiguration", testNewWithInjectedConfiguration) 564 }