github.com/xmidt-org/webpa-common@v1.11.9/xhttp/fanout/serviceEndpoints_test.go (about)

     1  package fanout
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"testing"
     9  
    10  	"github.com/go-kit/kit/sd"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/mock"
    13  	"github.com/stretchr/testify/require"
    14  	"github.com/xmidt-org/webpa-common/device"
    15  	"github.com/xmidt-org/webpa-common/service"
    16  	"github.com/xmidt-org/webpa-common/service/monitor"
    17  )
    18  
    19  func testNewServiceEndpointsHashError(t *testing.T) {
    20  	var (
    21  		assert  = assert.New(t)
    22  		require = require.New(t)
    23  
    24  		request = httptest.NewRequest("GET", "/", nil)
    25  
    26  		se = NewServiceEndpoints()
    27  	)
    28  
    29  	require.NotNil(se)
    30  	request.Header.Set(device.DeviceNameHeader, "mac:112233445566")
    31  
    32  	urls, err := se.FanoutURLs(request)
    33  	assert.Empty(urls)
    34  	assert.Error(err)
    35  }
    36  
    37  func testNewServiceEndpointsKeyFuncError(t *testing.T) {
    38  	var (
    39  		assert  = assert.New(t)
    40  		require = require.New(t)
    41  
    42  		request = httptest.NewRequest("GET", "/", nil)
    43  
    44  		expectedError = errors.New("expected error from KeyFunc")
    45  		keyFunc       = func(r *http.Request) ([]byte, error) {
    46  			return nil, expectedError
    47  		}
    48  
    49  		se = NewServiceEndpoints(WithKeyFunc(keyFunc))
    50  	)
    51  
    52  	require.NotNil(se)
    53  	urls, err := se.FanoutURLs(request)
    54  	assert.Empty(urls)
    55  	assert.Equal(expectedError, err)
    56  }
    57  
    58  func testNewServiceEndpointsDefault(t *testing.T, se *ServiceEndpoints) {
    59  	var (
    60  		assert  = assert.New(t)
    61  		request = httptest.NewRequest("GET", "/", nil)
    62  	)
    63  
    64  	request.Header.Set(device.DeviceNameHeader, "mac:112233445566")
    65  
    66  	urls, err := se.FanoutURLs(request)
    67  	assert.Empty(urls)
    68  	assert.Error(err)
    69  
    70  	se.MonitorEvent(monitor.Event{Key: "key1"})
    71  	urls, err = se.FanoutURLs(request)
    72  	assert.Empty(urls)
    73  	assert.Error(err)
    74  
    75  	se.MonitorEvent(monitor.Event{Key: "key1", Instances: []string{"http://localhost:8080"}})
    76  	urls, err = se.FanoutURLs(request)
    77  	assert.Len(urls, 1)
    78  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "localhost:8080"})
    79  	assert.NoError(err)
    80  
    81  	se.MonitorEvent(monitor.Event{Key: "key2", Instances: []string{"http://foobar.net:1234"}})
    82  	urls, err = se.FanoutURLs(request)
    83  	assert.Len(urls, 2)
    84  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "localhost:8080"})
    85  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "foobar.net:1234"})
    86  	assert.NoError(err)
    87  
    88  	se.MonitorEvent(monitor.Event{Key: "key1", Instances: []string{"https://somewhere.com"}})
    89  	urls, err = se.FanoutURLs(request)
    90  	assert.Len(urls, 2)
    91  	assert.Contains(urls, &url.URL{Scheme: "https", Host: "somewhere.com"})
    92  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "foobar.net:1234"})
    93  	assert.NoError(err)
    94  }
    95  
    96  func testNewServiceEndpointsCustom(t *testing.T) {
    97  	var (
    98  		assert  = assert.New(t)
    99  		require = require.New(t)
   100  
   101  		request = httptest.NewRequest("GET", "/", nil)
   102  
   103  		keyFuncCalled = false
   104  		keyFunc       = func(r *http.Request) ([]byte, error) {
   105  			keyFuncCalled = true
   106  			return device.IDHashParser(r)
   107  		}
   108  
   109  		accessorFactoryCalled = false
   110  		accessorFactory       = func(instances []string) service.Accessor {
   111  			accessorFactoryCalled = true
   112  			return service.DefaultAccessorFactory(instances)
   113  		}
   114  
   115  		se = NewServiceEndpoints(WithAccessorFactory(accessorFactory), WithKeyFunc(keyFunc))
   116  	)
   117  
   118  	require.NotNil(se)
   119  	request.Header.Set(device.DeviceNameHeader, "mac:112233445566")
   120  
   121  	urls, err := se.FanoutURLs(request)
   122  	assert.True(keyFuncCalled)
   123  	assert.Empty(urls)
   124  	assert.Error(err)
   125  
   126  	keyFuncCalled = false
   127  	accessorFactoryCalled = false
   128  	se.MonitorEvent(monitor.Event{Key: "key1"})
   129  	assert.True(accessorFactoryCalled)
   130  	urls, err = se.FanoutURLs(request)
   131  	assert.Empty(urls)
   132  	assert.Error(err)
   133  
   134  	keyFuncCalled = false
   135  	accessorFactoryCalled = false
   136  	se.MonitorEvent(monitor.Event{Key: "key1", Instances: []string{"http://localhost:8080"}})
   137  	urls, err = se.FanoutURLs(request)
   138  	assert.True(keyFuncCalled)
   139  	assert.True(accessorFactoryCalled)
   140  	assert.Len(urls, 1)
   141  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "localhost:8080"})
   142  	assert.NoError(err)
   143  
   144  	keyFuncCalled = false
   145  	accessorFactoryCalled = false
   146  	se.MonitorEvent(monitor.Event{Key: "key2", Instances: []string{"http://foobar.net:1234"}})
   147  	urls, err = se.FanoutURLs(request)
   148  	assert.True(keyFuncCalled)
   149  	assert.True(accessorFactoryCalled)
   150  	assert.Len(urls, 2)
   151  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "localhost:8080"})
   152  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "foobar.net:1234"})
   153  	assert.NoError(err)
   154  
   155  	keyFuncCalled = false
   156  	accessorFactoryCalled = false
   157  	se.MonitorEvent(monitor.Event{Key: "key1", Instances: []string{"https://somewhere.com"}})
   158  	urls, err = se.FanoutURLs(request)
   159  	assert.True(keyFuncCalled)
   160  	assert.True(accessorFactoryCalled)
   161  	assert.Len(urls, 2)
   162  	assert.Contains(urls, &url.URL{Scheme: "https", Host: "somewhere.com"})
   163  	assert.Contains(urls, &url.URL{Scheme: "http", Host: "foobar.net:1234"})
   164  	assert.NoError(err)
   165  }
   166  
   167  func TestNewServiceEndpoints(t *testing.T) {
   168  	t.Run("KeyFuncError", testNewServiceEndpointsKeyFuncError)
   169  
   170  	t.Run("Default", func(t *testing.T) {
   171  		testNewServiceEndpointsDefault(t, NewServiceEndpoints())
   172  		testNewServiceEndpointsDefault(t, NewServiceEndpoints(WithAccessorFactory(nil), WithKeyFunc(nil)))
   173  	})
   174  
   175  	t.Run("Custom", testNewServiceEndpointsCustom)
   176  }
   177  
   178  func TestServiceEndpointsAlternate(t *testing.T) {
   179  	var (
   180  		assert  = assert.New(t)
   181  		require = require.New(t)
   182  
   183  		e, err = ServiceEndpointsAlternate()()
   184  	)
   185  
   186  	require.NotNil(e)
   187  	assert.NoError(err)
   188  
   189  	se, ok := e.(*ServiceEndpoints)
   190  	require.True(ok)
   191  
   192  	assert.NotNil(se.keyFunc)
   193  	assert.NotNil(se.accessorFactory)
   194  }
   195  
   196  func testMonitorListenerWithNonListener(t *testing.T) {
   197  	var (
   198  		assert = assert.New(t)
   199  
   200  		fe     = FixedEndpoints{}
   201  		m, err = MonitorEndpoints(fe)
   202  	)
   203  
   204  	assert.Nil(m)
   205  	assert.NoError(err)
   206  }
   207  
   208  func testMonitorListenerWithListener(t *testing.T) {
   209  	var (
   210  		assert = assert.New(t)
   211  
   212  		deregisterWait = make(chan struct{})
   213  
   214  		i  = new(service.MockInstancer)
   215  		se = NewServiceEndpoints()
   216  	)
   217  
   218  	i.On("Register", mock.MatchedBy(func(chan<- sd.Event) bool { return true })).Once()
   219  	i.On("Deregister", mock.MatchedBy(func(chan<- sd.Event) bool { return true })).Once().Run(func(mock.Arguments) {
   220  		close(deregisterWait)
   221  	})
   222  
   223  	m, err := MonitorEndpoints(se, monitor.WithInstancers(service.Instancers{"key": i}))
   224  	assert.NotNil(m)
   225  	assert.NoError(err)
   226  	m.Stop()
   227  
   228  	<-deregisterWait
   229  	i.AssertExpectations(t)
   230  }
   231  
   232  func TestMonitorEndpoints(t *testing.T) {
   233  	t.Run("WithNonListener", testMonitorListenerWithNonListener)
   234  	t.Run("WithListener", testMonitorListenerWithListener)
   235  }