github.com/xmidt-org/webpa-common@v1.11.9/xhttp/fanout/handler_test.go (about)

     1  package fanout
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	gokithttp "github.com/go-kit/kit/transport/http"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/mock"
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/xmidt-org/webpa-common/logging"
    19  	"github.com/xmidt-org/webpa-common/xhttp"
    20  	"github.com/xmidt-org/webpa-common/xhttp/xhttptest"
    21  )
    22  
    23  func testHandlerBodyError(t *testing.T) {
    24  	var (
    25  		assert  = assert.New(t)
    26  		require = require.New(t)
    27  
    28  		expectedError = &xhttp.Error{Code: 599, Text: "body read error"}
    29  		body          = new(xhttptest.MockBody)
    30  		logger        = logging.NewTestLogger(nil, t)
    31  		ctx           = logging.WithLogger(context.Background(), logger)
    32  		original      = httptest.NewRequest("POST", "/something", body).WithContext(ctx)
    33  		response      = httptest.NewRecorder()
    34  
    35  		handler = New(FixedEndpoints{})
    36  	)
    37  
    38  	require.NotNil(handler)
    39  	body.OnReadError(expectedError).Once()
    40  
    41  	handler.ServeHTTP(response, original)
    42  	assert.Equal(599, response.Code)
    43  
    44  	body.AssertExpectations(t)
    45  }
    46  
    47  func testHandlerNoEndpoints(t *testing.T) {
    48  	var (
    49  		assert  = assert.New(t)
    50  		require = require.New(t)
    51  
    52  		body     = new(xhttptest.MockBody)
    53  		logger   = logging.NewTestLogger(nil, t)
    54  		ctx      = logging.WithLogger(context.Background(), logger)
    55  		original = httptest.NewRequest("POST", "/something", body).WithContext(ctx)
    56  		response = httptest.NewRecorder()
    57  
    58  		handler = New(FixedEndpoints{}, WithErrorEncoder(func(_ context.Context, err error, response http.ResponseWriter) {
    59  			response.WriteHeader(599)
    60  		}))
    61  	)
    62  
    63  	require.NotNil(handler)
    64  	body.OnReadError(io.EOF).Once()
    65  
    66  	handler.ServeHTTP(response, original)
    67  	assert.Equal(599, response.Code)
    68  
    69  	body.AssertExpectations(t)
    70  }
    71  
    72  func testHandlerEndpointsError(t *testing.T) {
    73  	var (
    74  		assert  = assert.New(t)
    75  		require = require.New(t)
    76  
    77  		expectedError = errors.New("endpoints error")
    78  		body          = new(xhttptest.MockBody)
    79  		endpoints     = new(mockEndpoints)
    80  
    81  		logger   = logging.NewTestLogger(nil, t)
    82  		ctx      = logging.WithLogger(context.Background(), logger)
    83  		original = httptest.NewRequest("POST", "/something", body).WithContext(ctx)
    84  		response = httptest.NewRecorder()
    85  
    86  		handler = New(endpoints, WithErrorEncoder(func(_ context.Context, err error, response http.ResponseWriter) {
    87  			response.WriteHeader(599)
    88  		}))
    89  	)
    90  
    91  	require.NotNil(handler)
    92  	body.OnReadError(io.EOF).Once()
    93  	endpoints.On("FanoutURLs", original).Once().Return(nil, expectedError)
    94  
    95  	handler.ServeHTTP(response, original)
    96  	assert.Equal(599, response.Code)
    97  
    98  	body.AssertExpectations(t)
    99  }
   100  
   101  func testHandlerBadTransactor(t *testing.T) {
   102  	var (
   103  		assert  = assert.New(t)
   104  		require = require.New(t)
   105  
   106  		logger   = logging.NewTestLogger(nil, t)
   107  		ctx      = logging.WithLogger(context.Background(), logger)
   108  		original = httptest.NewRequest("GET", "/api/v2/something", nil).WithContext(ctx)
   109  		response = httptest.NewRecorder()
   110  
   111  		endpoints  = generateEndpoints(1)
   112  		transactor = new(xhttptest.MockTransactor)
   113  		complete   = make(chan struct{}, 1)
   114  		handler    = New(endpoints, WithTransactor(transactor.Do))
   115  	)
   116  
   117  	require.NotNil(handler)
   118  	transactor.OnDo(
   119  		xhttptest.MatchMethod("GET"),
   120  		xhttptest.MatchURLString(endpoints[0].String()+"/api/v2/something"),
   121  	).Respond(nil, nil).Once().Run(func(mock.Arguments) { complete <- struct{}{} })
   122  
   123  	handler.ServeHTTP(response, original)
   124  	assert.Equal(http.StatusServiceUnavailable, response.Code)
   125  
   126  	select {
   127  	case <-complete:
   128  		// passing
   129  	case <-time.After(5 * time.Second):
   130  		assert.Fail("Not all transactors completed")
   131  	}
   132  
   133  	transactor.AssertExpectations(t)
   134  }
   135  
   136  func testHandlerGet(t *testing.T, expectedResponses []xhttptest.ExpectedResponse, expectedStatusCode int, expectedResponseBody string, expectAfter bool, expectedFailedCalled bool) {
   137  	var (
   138  		assert  = assert.New(t)
   139  		require = require.New(t)
   140  
   141  		logger   = logging.NewTestLogger(nil, t)
   142  		ctx      = logging.WithLogger(context.Background(), logger)
   143  		original = httptest.NewRequest("GET", "/api/v2/something", nil).WithContext(ctx)
   144  		response = httptest.NewRecorder()
   145  
   146  		fanoutAfterCalled = false
   147  		fanoutAfter       = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context {
   148  			assert.False(fanoutAfterCalled)
   149  			fanoutAfterCalled = true
   150  			assert.Equal(ctx, actualCtx)
   151  			assert.Equal(response, actualResponse)
   152  			if assert.NotNil(result.Response) {
   153  				assert.Equal(expectedStatusCode, result.Response.StatusCode)
   154  			}
   155  
   156  			return actualCtx
   157  		}
   158  
   159  		clientAfterCalled = false
   160  		clientAfter       = func(actualCtx context.Context, actualResponse *http.Response) context.Context {
   161  			assert.False(clientAfterCalled)
   162  			clientAfterCalled = true
   163  			assert.Equal(ctx, actualCtx)
   164  			assert.Equal(expectedStatusCode, actualResponse.StatusCode)
   165  			return actualCtx
   166  		}
   167  
   168  		fanoutFailedCalled = false
   169  		fanoutFail         = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context {
   170  			assert.False(fanoutFailedCalled)
   171  			fanoutFailedCalled = true
   172  			assert.Equal(ctx, actualCtx)
   173  			return ctx
   174  		}
   175  
   176  		endpoints  = generateEndpoints(len(expectedResponses))
   177  		transactor = new(xhttptest.MockTransactor)
   178  		complete   = make(chan struct{}, len(expectedResponses))
   179  
   180  		handler = New(endpoints,
   181  			WithTransactor(transactor.Do),
   182  			WithClientBefore(gokithttp.SetRequestHeader("X-Test", "foobar")),
   183  			WithFanoutAfter(fanoutAfter),
   184  			WithClientAfter(clientAfter),
   185  			WithFanoutFailure(fanoutFail),
   186  		)
   187  	)
   188  
   189  	require.NotNil(handler)
   190  	for i, er := range expectedResponses {
   191  		transactor.OnDo(
   192  			xhttptest.MatchMethod("GET"),
   193  			xhttptest.MatchURLString(endpoints[i].String()+"/api/v2/something"),
   194  			xhttptest.MatchHeader("X-Test", "foobar"),
   195  		).RespondWith(er).Once().Run(func(mock.Arguments) { complete <- struct{}{} })
   196  	}
   197  
   198  	handler.ServeHTTP(response, original)
   199  	assert.Equal(expectedStatusCode, response.Code)
   200  
   201  	after := time.After(5 * time.Second)
   202  	for i := 0; i < len(expectedResponses); i++ {
   203  		select {
   204  		case <-complete:
   205  			// passing
   206  		case <-after:
   207  			assert.Fail("Not all transactors completed")
   208  			i = len(expectedResponses)
   209  		}
   210  	}
   211  
   212  	assert.Equal(expectAfter, clientAfterCalled)
   213  	assert.Equal(expectedFailedCalled, fanoutFailedCalled)
   214  	transactor.AssertExpectations(t)
   215  }
   216  
   217  func testHandlerPost(t *testing.T, expectedResponses []xhttptest.ExpectedResponse, expectedStatusCode int, expectedResponseBody string, expectAfter bool, expectedFailedCalled bool) {
   218  	var (
   219  		assert  = assert.New(t)
   220  		require = require.New(t)
   221  
   222  		logger              = logging.NewTestLogger(nil, t)
   223  		ctx                 = logging.WithLogger(context.Background(), logger)
   224  		expectedRequestBody = "posted body"
   225  		original            = httptest.NewRequest("POST", "/api/v2/something", strings.NewReader(expectedRequestBody)).WithContext(ctx)
   226  		response            = httptest.NewRecorder()
   227  
   228  		fanoutAfterCalled = false
   229  		fanoutAfter       = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context {
   230  			assert.False(fanoutAfterCalled)
   231  			fanoutAfterCalled = true
   232  			assert.Equal(ctx, actualCtx)
   233  			assert.Equal(response, actualResponse)
   234  			if assert.NotNil(result.Response) {
   235  				assert.Equal(expectedStatusCode, result.Response.StatusCode)
   236  			}
   237  
   238  			return actualCtx
   239  		}
   240  
   241  		clientAfterCalled = false
   242  		clientAfter       = func(actualCtx context.Context, actualResponse *http.Response) context.Context {
   243  			assert.False(clientAfterCalled)
   244  			clientAfterCalled = true
   245  			assert.Equal(ctx, actualCtx)
   246  			assert.Equal(expectedStatusCode, actualResponse.StatusCode)
   247  			return actualCtx
   248  		}
   249  		fanoutFailedCalled = false
   250  		fanoutFail         = func(actualCtx context.Context, actualResponse http.ResponseWriter, result Result) context.Context {
   251  			assert.False(fanoutFailedCalled)
   252  			fanoutFailedCalled = true
   253  			assert.Equal(ctx, actualCtx)
   254  			return ctx
   255  		}
   256  
   257  		endpoints  = generateEndpoints(len(expectedResponses))
   258  		transactor = new(xhttptest.MockTransactor)
   259  		complete   = make(chan struct{}, len(expectedResponses))
   260  		handler    = New(endpoints,
   261  			WithTransactor(transactor.Do),
   262  			WithFanoutBefore(ForwardBody(true)),
   263  			WithClientBefore(gokithttp.SetRequestHeader("X-Test", "foobar")),
   264  			WithFanoutAfter(fanoutAfter),
   265  			WithClientAfter(clientAfter),
   266  			WithFanoutFailure(fanoutFail),
   267  		)
   268  	)
   269  
   270  	require.NotNil(handler)
   271  	for i, er := range expectedResponses {
   272  		transactor.OnDo(
   273  			xhttptest.MatchMethod("POST"),
   274  			xhttptest.MatchURLString(endpoints[i].String()+"/api/v2/something"),
   275  			xhttptest.MatchHeader("X-Test", "foobar"),
   276  			xhttptest.MatchBodyString(expectedRequestBody),
   277  		).RespondWith(er).Once().Run(func(mock.Arguments) { complete <- struct{}{} })
   278  	}
   279  
   280  	handler.ServeHTTP(response, original)
   281  	assert.Equal(expectedStatusCode, response.Code)
   282  	assert.Equal(expectedResponseBody, response.Body.String())
   283  	assert.Equal(expectAfter, clientAfterCalled)
   284  	assert.Equal(expectedFailedCalled, fanoutFailedCalled)
   285  
   286  	after := time.After(2 * time.Second)
   287  	for i := 0; i < len(expectedResponses); i++ {
   288  		select {
   289  		case <-complete:
   290  			// passing
   291  		case <-after:
   292  			assert.Fail("Not all transactors completed")
   293  			i = len(expectedResponses)
   294  		}
   295  	}
   296  
   297  	transactor.AssertExpectations(t)
   298  }
   299  
   300  func testHandlerTimeout(t *testing.T, endpointCount int) {
   301  	var (
   302  		assert  = assert.New(t)
   303  		require = require.New(t)
   304  
   305  		logger      = logging.NewTestLogger(nil, t)
   306  		ctx, cancel = context.WithCancel(logging.WithLogger(context.Background(), logger))
   307  		original    = httptest.NewRequest("GET", "/api/v2/something", nil).WithContext(ctx)
   308  		response    = httptest.NewRecorder()
   309  
   310  		endpoints      = generateEndpoints(endpointCount)
   311  		transactor     = new(xhttptest.MockTransactor)
   312  		transactorWait = make(chan time.Time)
   313  		complete       = make(chan struct{}, endpointCount)
   314  		handlerWait    = make(chan struct{})
   315  		handler        = New(endpoints,
   316  			WithTransactor(transactor.Do),
   317  		)
   318  	)
   319  
   320  	require.NotNil(handler)
   321  	for i := 0; i < endpointCount; i++ {
   322  		transactor.OnDo(
   323  			xhttptest.MatchMethod("GET"),
   324  			xhttptest.MatchURLString(endpoints[i].String()+"/api/v2/something"),
   325  		).Respond(nil, nil).Once().WaitUntil(transactorWait).Run(func(mock.Arguments) { complete <- struct{}{} })
   326  	}
   327  
   328  	go func() {
   329  		defer close(handlerWait)
   330  		handler.ServeHTTP(response, original)
   331  	}()
   332  
   333  	// simulate a context timeout
   334  	cancel()
   335  	select {
   336  	case <-handlerWait:
   337  		assert.Equal(http.StatusGatewayTimeout, response.Code)
   338  	case <-time.After(2 * time.Second):
   339  		assert.Fail("ServeHTTP did not return")
   340  	}
   341  
   342  	close(transactorWait)
   343  	after := time.After(2 * time.Second)
   344  	for i := 0; i < endpointCount; i++ {
   345  		select {
   346  		case <-complete:
   347  			// passing
   348  		case <-after:
   349  			assert.Fail("Not all transactors completed")
   350  			i = endpointCount
   351  		}
   352  	}
   353  
   354  	transactor.AssertExpectations(t)
   355  }
   356  
   357  func TestHandler(t *testing.T) {
   358  	t.Run("BodyError", testHandlerBodyError)
   359  	t.Run("NoEndpoints", testHandlerNoEndpoints)
   360  	t.Run("EndpointsError", testHandlerEndpointsError)
   361  	t.Run("BadTransactor", testHandlerBadTransactor)
   362  
   363  	t.Run("Fanout", func(t *testing.T) {
   364  		testData := []struct {
   365  			statusCodes          []xhttptest.ExpectedResponse
   366  			expectedStatusCode   int
   367  			expectedResponseBody string
   368  			expectAfter          bool
   369  			expectedFailedCalled bool
   370  		}{
   371  			{
   372  				[]xhttptest.ExpectedResponse{
   373  					{StatusCode: 504},
   374  				},
   375  				504,
   376  				"",
   377  				false,
   378  				true,
   379  			},
   380  			{
   381  				[]xhttptest.ExpectedResponse{
   382  					{StatusCode: 500}, {StatusCode: 501}, {StatusCode: 502}, {StatusCode: 503}, {StatusCode: 504},
   383  				},
   384  				504,
   385  				"",
   386  				false,
   387  				true,
   388  			},
   389  			{
   390  				[]xhttptest.ExpectedResponse{
   391  					{StatusCode: 504}, {StatusCode: 503}, {StatusCode: 502}, {StatusCode: 501}, {StatusCode: 500},
   392  				},
   393  				504,
   394  				"",
   395  				false,
   396  				true,
   397  			},
   398  			{
   399  				[]xhttptest.ExpectedResponse{
   400  					{Err: errors.New("expected")},
   401  				},
   402  				http.StatusServiceUnavailable,
   403  				"expected",
   404  				false,
   405  				true,
   406  			},
   407  			{
   408  				[]xhttptest.ExpectedResponse{
   409  					{StatusCode: 500}, {Err: errors.New("expected")},
   410  				},
   411  				http.StatusServiceUnavailable,
   412  				"expected",
   413  				false,
   414  				true,
   415  			},
   416  			{
   417  				[]xhttptest.ExpectedResponse{
   418  					{StatusCode: 599}, {Err: errors.New("expected")},
   419  				},
   420  				599,
   421  				"",
   422  				false,
   423  				true,
   424  			},
   425  			{
   426  				[]xhttptest.ExpectedResponse{
   427  					{StatusCode: 200, Body: []byte("expected body")},
   428  				},
   429  				200,
   430  				"expected body",
   431  				true,
   432  				false,
   433  			},
   434  			{
   435  				[]xhttptest.ExpectedResponse{
   436  					{StatusCode: 404}, {StatusCode: 200, Body: []byte("expected body")}, {StatusCode: 503},
   437  				},
   438  				200,
   439  				"expected body",
   440  				true,
   441  				false,
   442  			},
   443  		}
   444  
   445  		t.Run("GET", func(t *testing.T) {
   446  			for _, record := range testData {
   447  				testHandlerGet(t, record.statusCodes, record.expectedStatusCode, record.expectedResponseBody, record.expectAfter, record.expectedFailedCalled)
   448  			}
   449  		})
   450  
   451  		t.Run("POST", func(t *testing.T) {
   452  			for _, record := range testData {
   453  				testHandlerPost(t, record.statusCodes, record.expectedStatusCode, record.expectedResponseBody, record.expectAfter, record.expectedFailedCalled)
   454  			}
   455  		})
   456  	})
   457  
   458  	t.Run("Timeout", func(t *testing.T) {
   459  		for _, endpointCount := range []int{1, 2, 3, 5} {
   460  			t.Run(fmt.Sprintf("EndpointCount=%d", endpointCount), func(t *testing.T) {
   461  				testHandlerTimeout(t, endpointCount)
   462  			})
   463  		}
   464  	})
   465  }
   466  
   467  func testNewNilEndpoints(t *testing.T) {
   468  	assert := assert.New(t)
   469  	assert.Panics(func() {
   470  		New(nil)
   471  	})
   472  }
   473  
   474  func testNewNilConfiguration(t *testing.T) {
   475  	var (
   476  		assert  = assert.New(t)
   477  		require = require.New(t)
   478  
   479  		handler = New(FixedEndpoints{},
   480  			WithShouldTerminate(nil),
   481  			WithErrorEncoder(nil),
   482  			WithTransactor(nil),
   483  			WithFanoutBefore(),
   484  			WithClientBefore(),
   485  			WithFanoutAfter(),
   486  			WithFanoutFailure(),
   487  			WithClientFailure(),
   488  		)
   489  	)
   490  
   491  	require.NotNil(handler)
   492  	assert.NotNil(handler.shouldTerminate)
   493  	assert.NotNil(handler.errorEncoder)
   494  	assert.NotNil(handler.transactor)
   495  	assert.Empty(handler.before)
   496  	assert.Empty(handler.after)
   497  	assert.Empty(handler.failure)
   498  }
   499  
   500  func testNewNoConfiguration(t *testing.T) {
   501  	var (
   502  		assert  = assert.New(t)
   503  		require = require.New(t)
   504  
   505  		handler = New(FixedEndpoints{})
   506  	)
   507  
   508  	require.NotNil(handler)
   509  	assert.NotNil(handler.shouldTerminate)
   510  	assert.NotNil(handler.errorEncoder)
   511  	assert.NotNil(handler.transactor)
   512  	assert.Empty(handler.before)
   513  	assert.Empty(handler.after)
   514  }
   515  
   516  func testNewShouldTerminate(t *testing.T) {
   517  	var (
   518  		assert  = assert.New(t)
   519  		require = require.New(t)
   520  
   521  		shouldTerminateCalled = false
   522  		shouldTerminate       = func(Result) bool {
   523  			assert.False(shouldTerminateCalled)
   524  			shouldTerminateCalled = true
   525  			return true
   526  		}
   527  
   528  		handler = New(FixedEndpoints{}, WithShouldTerminate(shouldTerminate))
   529  	)
   530  
   531  	require.NotNil(handler)
   532  	assert.True(handler.shouldTerminate(Result{}))
   533  	assert.True(shouldTerminateCalled)
   534  }
   535  
   536  func testNewWithInjectedConfiguration(t *testing.T) {
   537  	var (
   538  		assert  = assert.New(t)
   539  		require = require.New(t)
   540  
   541  		expectedEndpoints = MustParseURLs("http://foobar.com:8080")
   542  
   543  		handler = New(
   544  			expectedEndpoints,
   545  			WithConfiguration(Configuration{
   546  				Endpoints:     []string{"localhost:1234"},
   547  				Authorization: "deadbeef",
   548  			}),
   549  		)
   550  	)
   551  
   552  	require.NotNil(handler)
   553  	assert.NotNil(handler.transactor)
   554  	assert.Len(handler.before, 1)
   555  	assert.Equal(expectedEndpoints, handler.endpoints)
   556  }
   557  
   558  func TestNew(t *testing.T) {
   559  	t.Run("NilEndpoints", testNewNilEndpoints)
   560  	t.Run("NilConfiguration", testNewNilConfiguration)
   561  	t.Run("NoConfiguration", testNewNoConfiguration)
   562  	t.Run("ShouldTerminate", testNewShouldTerminate)
   563  	t.Run("WithInjectedConfiguration", testNewWithInjectedConfiguration)
   564  }