github.com/xmidt-org/webpa-common@v1.11.9/server/webpa_test.go (about) 1 package server 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "net/http" 7 "sync" 8 "testing" 9 "time" 10 11 "github.com/justinas/alice" 12 "github.com/stretchr/testify/assert" 13 "github.com/stretchr/testify/mock" 14 "github.com/stretchr/testify/require" 15 "github.com/xmidt-org/webpa-common/xmetrics" 16 ) 17 18 func TestListenAndServeNonSecure(t *testing.T) { 19 var ( 20 simpleError = errors.New("expected") 21 testData = []struct { 22 certificateFile, keyFile string 23 expectedError error 24 shouldCallFinal bool 25 }{ 26 {"", "", http.ErrServerClosed, true}, 27 {"", "", simpleError, false}, 28 {"file.cert", "", http.ErrServerClosed, true}, 29 {"file.cert", "", simpleError, false}, 30 {"", "file.key", http.ErrServerClosed, true}, 31 {"", "file.key", simpleError, false}, 32 } 33 ) 34 35 for _, record := range testData { 36 t.Logf("%#v", record) 37 var ( 38 assert = assert.New(t) 39 40 _, logger = newTestLogger() 41 executorCalled = make(chan struct{}, 1) 42 mockExecutor = new(mockExecutor) 43 44 finalizerCalled = make(chan struct{}) 45 finalizer = func() { 46 close(finalizerCalled) 47 } 48 ) 49 50 mockExecutor.On("ListenAndServe"). 51 Return(record.expectedError). 52 Run(func(mock.Arguments) { executorCalled <- struct{}{} }) 53 54 ListenAndServe(logger, mockExecutor, finalizer) 55 select { 56 case <-executorCalled: 57 // passing 58 case <-time.After(time.Second): 59 assert.Fail("the executor was not called") 60 } 61 62 select { 63 case <-finalizerCalled: 64 // passing 65 case <-time.After(time.Second): 66 if record.shouldCallFinal { 67 assert.Fail("the finalizer was not called") 68 } 69 } 70 71 mockExecutor.AssertExpectations(t) 72 } 73 } 74 75 func TestListenAndServeSecure(t *testing.T) { 76 var ( 77 testData = []struct { 78 expectedError error 79 shouldCallFinal bool 80 }{ 81 {http.ErrServerClosed, true}, 82 {errors.New("expected"), false}, 83 } 84 ) 85 86 for _, record := range testData { 87 t.Logf("%#v", record) 88 var ( 89 assert = assert.New(t) 90 91 _, logger = newTestLogger() 92 executorCalled = make(chan struct{}, 1) 93 mockExecutor = new(mockExecutor) 94 95 finalizerCalled = make(chan struct{}) 96 finalizer = func() { 97 close(finalizerCalled) 98 } 99 ) 100 101 mockExecutor.On("ListenAndServe"). 102 Return(record.expectedError). 103 Run(func(mock.Arguments) { executorCalled <- struct{}{} }) 104 105 ListenAndServe(logger, mockExecutor, finalizer) 106 select { 107 case <-executorCalled: 108 // passing 109 case <-time.After(time.Second): 110 assert.Fail("the executor was not called") 111 } 112 113 select { 114 case <-finalizerCalled: 115 // passing 116 case <-time.After(time.Second): 117 if record.shouldCallFinal { 118 assert.Fail("the finalizer was not called") 119 } 120 } 121 122 mockExecutor.AssertExpectations(t) 123 } 124 } 125 126 func TestBasicNew(t *testing.T) { 127 const expectedName = "TestBasicNew" 128 129 var ( 130 assert = assert.New(t) 131 require = require.New(t) 132 testData = []struct { 133 description string 134 address string 135 handler *mockHandler 136 certFile []string 137 keyFile []string 138 clientCACertFile string 139 minTLSVersion uint16 140 maxTLSVersion uint16 141 logConnectionState bool 142 expectTLS bool 143 expectmTLS bool 144 nilServer bool 145 }{ 146 { 147 description: "No address", 148 address: "", 149 handler: nil, 150 logConnectionState: false, 151 nilServer: true, 152 }, 153 { 154 description: "Nil handler", 155 address: ":443", 156 handler: nil, 157 logConnectionState: true, 158 }, 159 160 { 161 description: "Invalid cert file", 162 address: ":443", 163 handler: new(mockHandler), 164 logConnectionState: true, 165 certFile: []string{"cert.pem", "missing-pair.pem"}, 166 keyFile: []string{"key.pem"}, 167 nilServer: true, 168 }, 169 170 { 171 description: "Invalid key file", 172 address: ":443", 173 handler: new(mockHandler), 174 logConnectionState: true, 175 certFile: []string{"cert.pem"}, 176 keyFile: []string{"key.pem", "missing-pair.pem"}, 177 nilServer: true, 178 }, 179 180 { 181 description: "Invalid client CA cert file", 182 address: ":443", 183 handler: new(mockHandler), 184 logConnectionState: true, 185 certFile: []string{"cert.pem"}, 186 keyFile: []string{"key.pem"}, 187 clientCACertFile: "missing-file.pem", 188 nilServer: true, 189 }, 190 191 { 192 description: "Invalid client CA cert file", 193 address: ":443", 194 handler: new(mockHandler), 195 logConnectionState: true, 196 certFile: []string{"cert.pem"}, 197 keyFile: []string{"key.pem"}, 198 clientCACertFile: "missing-file.pem", 199 nilServer: true, 200 }, 201 202 { 203 description: "TLS enabled", 204 address: ":443", 205 handler: new(mockHandler), 206 logConnectionState: true, 207 certFile: []string{"cert.pem"}, 208 keyFile: []string{"key.pem"}, 209 minTLSVersion: tls.VersionTLS11, 210 maxTLSVersion: tls.VersionTLS12, 211 expectTLS: true, 212 }, 213 214 { 215 description: "mTLS enabled", 216 address: ":443", 217 handler: new(mockHandler), 218 logConnectionState: true, 219 certFile: []string{"cert.pem"}, 220 keyFile: []string{"key.pem"}, 221 clientCACertFile: "client_ca.pem", 222 minTLSVersion: tls.VersionTLS12, 223 maxTLSVersion: tls.VersionTLS13, 224 expectTLS: true, 225 expectmTLS: true, 226 }, 227 } 228 ) 229 230 for _, record := range testData { 231 t.Run(record.description, func(t *testing.T) { 232 var ( 233 verify, logger = newTestLogger() 234 basic = Basic{ 235 Name: expectedName, 236 Address: record.address, 237 LogConnectionState: record.logConnectionState, 238 CertificateFile: record.certFile, 239 KeyFile: record.keyFile, 240 ClientCACertFile: record.clientCACertFile, 241 MaxVersion: record.maxTLSVersion, 242 MinVersion: record.minTLSVersion, 243 DisableKeepAlives: true, 244 } 245 ) 246 247 server := basic.New(logger, record.handler) 248 249 if !record.nilServer { 250 require.NotNil(server) 251 assert.Equal(record.address, server.Addr) 252 assert.Equal(record.handler, server.Handler) 253 assertErrorLog(assert, verify, expectedName, server.ErrorLog) 254 255 if record.logConnectionState { 256 assertConnState(assert, verify, server.ConnState) 257 } else { 258 assert.Nil(server.ConnState) 259 } 260 261 if record.expectTLS { 262 assert.NotZero(server.TLSConfig.MaxVersion) 263 assert.Equal(record.minTLSVersion, server.TLSConfig.MinVersion) 264 assert.Equal(record.maxTLSVersion, server.TLSConfig.MaxVersion) 265 assert.NotNil(server.TLSConfig.Certificates) 266 if record.expectmTLS { 267 assert.NotNil(server.TLSConfig.ClientCAs) 268 assert.Equal(tls.RequireAndVerifyClientCert, server.TLSConfig.ClientAuth) 269 } 270 } else { 271 assert.Nil(server.TLSConfig) 272 } 273 } else { 274 require.Nil(server) 275 } 276 277 if record.handler != nil { 278 record.handler.AssertExpectations(t) 279 } 280 }) 281 } 282 } 283 284 func TestHealthNew(t *testing.T) { 285 const ( 286 expectedName = "TestHealthNew" 287 expectedLogInterval time.Duration = 45 * time.Second 288 ) 289 290 var ( 291 assert = assert.New(t) 292 require = require.New(t) 293 294 expectedHandlerType *http.ServeMux = nil 295 296 testData = []struct { 297 address string 298 logConnectionState bool 299 options []string 300 }{ 301 {"", false, nil}, 302 {"", false, []string{}}, 303 {"", false, []string{"Value1"}}, 304 {"", false, []string{"Value1", "Value2"}}, 305 306 {"", true, nil}, 307 {"", true, []string{}}, 308 {"", true, []string{"Value1"}}, 309 {"", true, []string{"Value1", "Value2"}}, 310 311 {":901", false, nil}, 312 {":1987", false, []string{}}, 313 {":http", false, []string{"Value1"}}, 314 {":https", false, []string{"Value1", "Value2"}}, 315 316 {"locahost:9001", true, nil}, 317 {":57899", true, []string{}}, 318 {":ftp", true, []string{"Value1"}}, 319 {":0", true, []string{"Value1", "Value2"}}, 320 } 321 ) 322 323 for _, record := range testData { 324 t.Logf("%#v", record) 325 326 var ( 327 verify, logger = newTestLogger() 328 health = Health{ 329 Name: expectedName, 330 Address: record.address, 331 LogConnectionState: record.logConnectionState, 332 LogInterval: expectedLogInterval, 333 Options: record.options, 334 } 335 336 handler, server = health.New(logger, alice.New(), nil) 337 ) 338 339 if len(record.address) > 0 { 340 require.NotNil(handler) 341 require.NotNil(server) 342 assert.Equal(record.address, server.Addr) 343 assert.IsType(expectedHandlerType, server.Handler) 344 assertErrorLog(assert, verify, expectedName, server.ErrorLog) 345 346 if record.logConnectionState { 347 assertConnState(assert, verify, server.ConnState) 348 } else { 349 assert.Nil(server.ConnState) 350 } 351 } else { 352 require.Nil(handler) 353 require.Nil(server) 354 } 355 } 356 } 357 358 func TestWebPANoPrimaryAddress(t *testing.T) { 359 var ( 360 assert = assert.New(t) 361 require = require.New(t) 362 ) 363 364 r, err := xmetrics.NewRegistry(nil, Metrics) 365 require.NoError(err) 366 require.NotNil(r) 367 368 var ( 369 handler = new(mockHandler) 370 webPA = WebPA{} 371 372 _, logger = newTestLogger() 373 monitor, runnable, done = webPA.Prepare(logger, nil, xmetrics.MustNewRegistry(nil), handler) 374 ) 375 376 assert.Nil(monitor) 377 require.NotNil(runnable) 378 assert.NotNil(done) 379 380 var ( 381 waitGroup = new(sync.WaitGroup) 382 shutdown = make(chan struct{}) 383 ) 384 385 defer close(shutdown) 386 assert.Equal(ErrorNoPrimaryAddress, runnable.Run(waitGroup, shutdown)) 387 waitGroup.Wait() // nothing should have incremented the wait group 388 handler.AssertExpectations(t) 389 } 390 391 func TestWebPA(t *testing.T) { 392 var ( 393 assert = assert.New(t) 394 require = require.New(t) 395 handler = new(mockHandler) 396 ) 397 398 r, err := xmetrics.NewRegistry(nil, Metrics) 399 require.NoError(err) 400 require.NotNil(r) 401 402 var ( 403 // synthesize a WebPA instance that will start everything, 404 // close to how it would be unmarshalled from Viper. 405 webPA = WebPA{ 406 Primary: Basic{ 407 Name: "test", 408 Address: ":0", 409 }, 410 Alternate: Basic{ 411 Name: "test.alternate", 412 Address: ":0", 413 }, 414 Health: Health{ 415 Name: "test.health", 416 Address: ":0", 417 LogInterval: 60 * time.Minute, 418 Options: []string{"Option1", "Option2"}, 419 }, 420 Pprof: Basic{ 421 Name: "test.pprof", 422 Address: ":0", 423 }, 424 425 Metric: Metric{ 426 Name: "test.metrics", 427 Address: ":0", 428 }, 429 } 430 431 _, logger = newTestLogger() 432 monitor, runnable, done = webPA.Prepare(logger, nil, xmetrics.MustNewRegistry(nil), handler) 433 ) 434 435 assert.NotNil(monitor) 436 require.NotNil(runnable) 437 assert.NotNil(done) 438 439 var ( 440 waitGroup = new(sync.WaitGroup) 441 shutdown = make(chan struct{}) 442 ) 443 444 assert.Nil(runnable.Run(waitGroup, shutdown)) 445 close(shutdown) 446 waitGroup.Wait() // the http.Server instances will still be running after this returns 447 handler.AssertExpectations(t) 448 }