github.com/xmidt-org/webpa-common@v1.11.9/server/webpa_test.go (about)

     1  package server
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"net/http"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/justinas/alice"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/mock"
    14  	"github.com/stretchr/testify/require"
    15  	"github.com/xmidt-org/webpa-common/xmetrics"
    16  )
    17  
    18  func TestListenAndServeNonSecure(t *testing.T) {
    19  	var (
    20  		simpleError = errors.New("expected")
    21  		testData    = []struct {
    22  			certificateFile, keyFile string
    23  			expectedError            error
    24  			shouldCallFinal          bool
    25  		}{
    26  			{"", "", http.ErrServerClosed, true},
    27  			{"", "", simpleError, false},
    28  			{"file.cert", "", http.ErrServerClosed, true},
    29  			{"file.cert", "", simpleError, false},
    30  			{"", "file.key", http.ErrServerClosed, true},
    31  			{"", "file.key", simpleError, false},
    32  		}
    33  	)
    34  
    35  	for _, record := range testData {
    36  		t.Logf("%#v", record)
    37  		var (
    38  			assert = assert.New(t)
    39  
    40  			_, logger      = newTestLogger()
    41  			executorCalled = make(chan struct{}, 1)
    42  			mockExecutor   = new(mockExecutor)
    43  
    44  			finalizerCalled = make(chan struct{})
    45  			finalizer       = func() {
    46  				close(finalizerCalled)
    47  			}
    48  		)
    49  
    50  		mockExecutor.On("ListenAndServe").
    51  			Return(record.expectedError).
    52  			Run(func(mock.Arguments) { executorCalled <- struct{}{} })
    53  
    54  		ListenAndServe(logger, mockExecutor, finalizer)
    55  		select {
    56  		case <-executorCalled:
    57  			// passing
    58  		case <-time.After(time.Second):
    59  			assert.Fail("the executor was not called")
    60  		}
    61  
    62  		select {
    63  		case <-finalizerCalled:
    64  			// passing
    65  		case <-time.After(time.Second):
    66  			if record.shouldCallFinal {
    67  				assert.Fail("the finalizer was not called")
    68  			}
    69  		}
    70  
    71  		mockExecutor.AssertExpectations(t)
    72  	}
    73  }
    74  
    75  func TestListenAndServeSecure(t *testing.T) {
    76  	var (
    77  		testData = []struct {
    78  			expectedError   error
    79  			shouldCallFinal bool
    80  		}{
    81  			{http.ErrServerClosed, true},
    82  			{errors.New("expected"), false},
    83  		}
    84  	)
    85  
    86  	for _, record := range testData {
    87  		t.Logf("%#v", record)
    88  		var (
    89  			assert = assert.New(t)
    90  
    91  			_, logger      = newTestLogger()
    92  			executorCalled = make(chan struct{}, 1)
    93  			mockExecutor   = new(mockExecutor)
    94  
    95  			finalizerCalled = make(chan struct{})
    96  			finalizer       = func() {
    97  				close(finalizerCalled)
    98  			}
    99  		)
   100  
   101  		mockExecutor.On("ListenAndServe").
   102  			Return(record.expectedError).
   103  			Run(func(mock.Arguments) { executorCalled <- struct{}{} })
   104  
   105  		ListenAndServe(logger, mockExecutor, finalizer)
   106  		select {
   107  		case <-executorCalled:
   108  			// passing
   109  		case <-time.After(time.Second):
   110  			assert.Fail("the executor was not called")
   111  		}
   112  
   113  		select {
   114  		case <-finalizerCalled:
   115  			// passing
   116  		case <-time.After(time.Second):
   117  			if record.shouldCallFinal {
   118  				assert.Fail("the finalizer was not called")
   119  			}
   120  		}
   121  
   122  		mockExecutor.AssertExpectations(t)
   123  	}
   124  }
   125  
   126  func TestBasicNew(t *testing.T) {
   127  	const expectedName = "TestBasicNew"
   128  
   129  	var (
   130  		assert   = assert.New(t)
   131  		require  = require.New(t)
   132  		testData = []struct {
   133  			description        string
   134  			address            string
   135  			handler            *mockHandler
   136  			certFile           []string
   137  			keyFile            []string
   138  			clientCACertFile   string
   139  			minTLSVersion      uint16
   140  			maxTLSVersion      uint16
   141  			logConnectionState bool
   142  			expectTLS          bool
   143  			expectmTLS         bool
   144  			nilServer          bool
   145  		}{
   146  			{
   147  				description:        "No address",
   148  				address:            "",
   149  				handler:            nil,
   150  				logConnectionState: false,
   151  				nilServer:          true,
   152  			},
   153  			{
   154  				description:        "Nil handler",
   155  				address:            ":443",
   156  				handler:            nil,
   157  				logConnectionState: true,
   158  			},
   159  
   160  			{
   161  				description:        "Invalid cert file",
   162  				address:            ":443",
   163  				handler:            new(mockHandler),
   164  				logConnectionState: true,
   165  				certFile:           []string{"cert.pem", "missing-pair.pem"},
   166  				keyFile:            []string{"key.pem"},
   167  				nilServer:          true,
   168  			},
   169  
   170  			{
   171  				description:        "Invalid key file",
   172  				address:            ":443",
   173  				handler:            new(mockHandler),
   174  				logConnectionState: true,
   175  				certFile:           []string{"cert.pem"},
   176  				keyFile:            []string{"key.pem", "missing-pair.pem"},
   177  				nilServer:          true,
   178  			},
   179  
   180  			{
   181  				description:        "Invalid client CA cert file",
   182  				address:            ":443",
   183  				handler:            new(mockHandler),
   184  				logConnectionState: true,
   185  				certFile:           []string{"cert.pem"},
   186  				keyFile:            []string{"key.pem"},
   187  				clientCACertFile:   "missing-file.pem",
   188  				nilServer:          true,
   189  			},
   190  
   191  			{
   192  				description:        "Invalid client CA cert file",
   193  				address:            ":443",
   194  				handler:            new(mockHandler),
   195  				logConnectionState: true,
   196  				certFile:           []string{"cert.pem"},
   197  				keyFile:            []string{"key.pem"},
   198  				clientCACertFile:   "missing-file.pem",
   199  				nilServer:          true,
   200  			},
   201  
   202  			{
   203  				description:        "TLS enabled",
   204  				address:            ":443",
   205  				handler:            new(mockHandler),
   206  				logConnectionState: true,
   207  				certFile:           []string{"cert.pem"},
   208  				keyFile:            []string{"key.pem"},
   209  				minTLSVersion:      tls.VersionTLS11,
   210  				maxTLSVersion:      tls.VersionTLS12,
   211  				expectTLS:          true,
   212  			},
   213  
   214  			{
   215  				description:        "mTLS enabled",
   216  				address:            ":443",
   217  				handler:            new(mockHandler),
   218  				logConnectionState: true,
   219  				certFile:           []string{"cert.pem"},
   220  				keyFile:            []string{"key.pem"},
   221  				clientCACertFile:   "client_ca.pem",
   222  				minTLSVersion:      tls.VersionTLS12,
   223  				maxTLSVersion:      tls.VersionTLS13,
   224  				expectTLS:          true,
   225  				expectmTLS:         true,
   226  			},
   227  		}
   228  	)
   229  
   230  	for _, record := range testData {
   231  		t.Run(record.description, func(t *testing.T) {
   232  			var (
   233  				verify, logger = newTestLogger()
   234  				basic          = Basic{
   235  					Name:               expectedName,
   236  					Address:            record.address,
   237  					LogConnectionState: record.logConnectionState,
   238  					CertificateFile:    record.certFile,
   239  					KeyFile:            record.keyFile,
   240  					ClientCACertFile:   record.clientCACertFile,
   241  					MaxVersion:         record.maxTLSVersion,
   242  					MinVersion:         record.minTLSVersion,
   243  					DisableKeepAlives:  true,
   244  				}
   245  			)
   246  
   247  			server := basic.New(logger, record.handler)
   248  
   249  			if !record.nilServer {
   250  				require.NotNil(server)
   251  				assert.Equal(record.address, server.Addr)
   252  				assert.Equal(record.handler, server.Handler)
   253  				assertErrorLog(assert, verify, expectedName, server.ErrorLog)
   254  
   255  				if record.logConnectionState {
   256  					assertConnState(assert, verify, server.ConnState)
   257  				} else {
   258  					assert.Nil(server.ConnState)
   259  				}
   260  
   261  				if record.expectTLS {
   262  					assert.NotZero(server.TLSConfig.MaxVersion)
   263  					assert.Equal(record.minTLSVersion, server.TLSConfig.MinVersion)
   264  					assert.Equal(record.maxTLSVersion, server.TLSConfig.MaxVersion)
   265  					assert.NotNil(server.TLSConfig.Certificates)
   266  					if record.expectmTLS {
   267  						assert.NotNil(server.TLSConfig.ClientCAs)
   268  						assert.Equal(tls.RequireAndVerifyClientCert, server.TLSConfig.ClientAuth)
   269  					}
   270  				} else {
   271  					assert.Nil(server.TLSConfig)
   272  				}
   273  			} else {
   274  				require.Nil(server)
   275  			}
   276  
   277  			if record.handler != nil {
   278  				record.handler.AssertExpectations(t)
   279  			}
   280  		})
   281  	}
   282  }
   283  
   284  func TestHealthNew(t *testing.T) {
   285  	const (
   286  		expectedName                      = "TestHealthNew"
   287  		expectedLogInterval time.Duration = 45 * time.Second
   288  	)
   289  
   290  	var (
   291  		assert  = assert.New(t)
   292  		require = require.New(t)
   293  
   294  		expectedHandlerType *http.ServeMux = nil
   295  
   296  		testData = []struct {
   297  			address            string
   298  			logConnectionState bool
   299  			options            []string
   300  		}{
   301  			{"", false, nil},
   302  			{"", false, []string{}},
   303  			{"", false, []string{"Value1"}},
   304  			{"", false, []string{"Value1", "Value2"}},
   305  
   306  			{"", true, nil},
   307  			{"", true, []string{}},
   308  			{"", true, []string{"Value1"}},
   309  			{"", true, []string{"Value1", "Value2"}},
   310  
   311  			{":901", false, nil},
   312  			{":1987", false, []string{}},
   313  			{":http", false, []string{"Value1"}},
   314  			{":https", false, []string{"Value1", "Value2"}},
   315  
   316  			{"locahost:9001", true, nil},
   317  			{":57899", true, []string{}},
   318  			{":ftp", true, []string{"Value1"}},
   319  			{":0", true, []string{"Value1", "Value2"}},
   320  		}
   321  	)
   322  
   323  	for _, record := range testData {
   324  		t.Logf("%#v", record)
   325  
   326  		var (
   327  			verify, logger = newTestLogger()
   328  			health         = Health{
   329  				Name:               expectedName,
   330  				Address:            record.address,
   331  				LogConnectionState: record.logConnectionState,
   332  				LogInterval:        expectedLogInterval,
   333  				Options:            record.options,
   334  			}
   335  
   336  			handler, server = health.New(logger, alice.New(), nil)
   337  		)
   338  
   339  		if len(record.address) > 0 {
   340  			require.NotNil(handler)
   341  			require.NotNil(server)
   342  			assert.Equal(record.address, server.Addr)
   343  			assert.IsType(expectedHandlerType, server.Handler)
   344  			assertErrorLog(assert, verify, expectedName, server.ErrorLog)
   345  
   346  			if record.logConnectionState {
   347  				assertConnState(assert, verify, server.ConnState)
   348  			} else {
   349  				assert.Nil(server.ConnState)
   350  			}
   351  		} else {
   352  			require.Nil(handler)
   353  			require.Nil(server)
   354  		}
   355  	}
   356  }
   357  
   358  func TestWebPANoPrimaryAddress(t *testing.T) {
   359  	var (
   360  		assert  = assert.New(t)
   361  		require = require.New(t)
   362  	)
   363  
   364  	r, err := xmetrics.NewRegistry(nil, Metrics)
   365  	require.NoError(err)
   366  	require.NotNil(r)
   367  
   368  	var (
   369  		handler = new(mockHandler)
   370  		webPA   = WebPA{}
   371  
   372  		_, logger               = newTestLogger()
   373  		monitor, runnable, done = webPA.Prepare(logger, nil, xmetrics.MustNewRegistry(nil), handler)
   374  	)
   375  
   376  	assert.Nil(monitor)
   377  	require.NotNil(runnable)
   378  	assert.NotNil(done)
   379  
   380  	var (
   381  		waitGroup = new(sync.WaitGroup)
   382  		shutdown  = make(chan struct{})
   383  	)
   384  
   385  	defer close(shutdown)
   386  	assert.Equal(ErrorNoPrimaryAddress, runnable.Run(waitGroup, shutdown))
   387  	waitGroup.Wait() // nothing should have incremented the wait group
   388  	handler.AssertExpectations(t)
   389  }
   390  
   391  func TestWebPA(t *testing.T) {
   392  	var (
   393  		assert  = assert.New(t)
   394  		require = require.New(t)
   395  		handler = new(mockHandler)
   396  	)
   397  
   398  	r, err := xmetrics.NewRegistry(nil, Metrics)
   399  	require.NoError(err)
   400  	require.NotNil(r)
   401  
   402  	var (
   403  		// synthesize a WebPA instance that will start everything,
   404  		// close to how it would be unmarshalled from Viper.
   405  		webPA = WebPA{
   406  			Primary: Basic{
   407  				Name:    "test",
   408  				Address: ":0",
   409  			},
   410  			Alternate: Basic{
   411  				Name:    "test.alternate",
   412  				Address: ":0",
   413  			},
   414  			Health: Health{
   415  				Name:        "test.health",
   416  				Address:     ":0",
   417  				LogInterval: 60 * time.Minute,
   418  				Options:     []string{"Option1", "Option2"},
   419  			},
   420  			Pprof: Basic{
   421  				Name:    "test.pprof",
   422  				Address: ":0",
   423  			},
   424  
   425  			Metric: Metric{
   426  				Name:    "test.metrics",
   427  				Address: ":0",
   428  			},
   429  		}
   430  
   431  		_, logger               = newTestLogger()
   432  		monitor, runnable, done = webPA.Prepare(logger, nil, xmetrics.MustNewRegistry(nil), handler)
   433  	)
   434  
   435  	assert.NotNil(monitor)
   436  	require.NotNil(runnable)
   437  	assert.NotNil(done)
   438  
   439  	var (
   440  		waitGroup = new(sync.WaitGroup)
   441  		shutdown  = make(chan struct{})
   442  	)
   443  
   444  	assert.Nil(runnable.Run(waitGroup, shutdown))
   445  	close(shutdown)
   446  	waitGroup.Wait() // the http.Server instances will still be running after this returns
   447  	handler.AssertExpectations(t)
   448  }