github.com/palantir/witchcraft-go-server/v2@v2.76.0/integration/server_test.go (about) 1 // Copyright (c) 2018 Palantir Technologies. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package integration 16 17 import ( 18 "bytes" 19 "context" 20 "crypto/tls" 21 "encoding/json" 22 "fmt" 23 "io/ioutil" 24 "net/http" 25 "os" 26 "path" 27 "testing" 28 "time" 29 30 "github.com/nmiyake/pkg/dirs" 31 "github.com/palantir/conjure-go-runtime/v2/conjure-go-contract/errors" 32 "github.com/palantir/pkg/httpserver" 33 "github.com/palantir/pkg/tlsconfig" 34 "github.com/palantir/witchcraft-go-health/conjure/witchcraft/api/health" 35 "github.com/palantir/witchcraft-go-server/v2/config" 36 "github.com/palantir/witchcraft-go-server/v2/status" 37 "github.com/palantir/witchcraft-go-server/v2/witchcraft" 38 "github.com/stretchr/testify/assert" 39 "github.com/stretchr/testify/require" 40 ) 41 42 // TestServerStarts ensures that a Witchcraft server starts and is able to serve requests. 43 // It also verifies the service log output of starting two routers (application and management). 44 func TestServerStarts(t *testing.T) { 45 logOutputBuffer := &bytes.Buffer{} 46 server, _, _, _, cleanup := createAndRunTestServer(t, nil, logOutputBuffer) 47 defer func() { 48 _ = server.Close() 49 }() 50 defer cleanup() 51 52 // verify service log output 53 msgs := getLogFileMessages(t, logOutputBuffer.Bytes()) 54 assert.Equal(t, []string{"Listening to https", "Listening to https"}, msgs) 55 } 56 57 // TestServerShutdown verifies the behavior when shutting down a Witchcraft server. There are two variants, graceful and abrupt. 58 // Graceful: server.Shutdown, which allows in-flight requests to complete. We test this using a route which sleeps one second. 59 // Abrupt: server.Close, which terminates the server immediately and sends EOF on active connections. 60 func TestServerShutdown(t *testing.T) { 61 runTest := func(t *testing.T, graceful bool) { 62 logOutputBuffer := &bytes.Buffer{} 63 calledC := make(chan bool, 1) 64 doneC := make(chan bool, 1) 65 initFn := func(ctx context.Context, info witchcraft.InitInfo) (func(), error) { 66 return nil, info.Router.Get("/wait", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 67 calledC <- true 68 // just wait for 1 second to hold the connection open. 69 time.Sleep(1 * time.Second) 70 doneC <- true 71 })) 72 } 73 server, port, _, serverErr, cleanup := createAndRunTestServer(t, initFn, logOutputBuffer) 74 defer func() { 75 _ = server.Close() 76 }() 77 defer cleanup() 78 79 go func() { 80 resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/wait", port)) 81 if graceful { 82 if err != nil { 83 panic(fmt.Errorf("graceful shutdown request failed: %v", err)) 84 } 85 if resp == nil { 86 panic(fmt.Errorf("returned response was nil")) 87 } 88 if resp.StatusCode != 200 { 89 panic(fmt.Errorf("server did not respond with 200 OK. Got %s", resp.Status)) 90 } 91 } else { 92 if err == nil { 93 panic(fmt.Errorf("request allowed to finish successfully")) 94 } 95 } 96 }() 97 98 timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), time.Second) 99 defer timeoutCancel() 100 101 select { 102 case <-calledC: 103 case <-timeoutCtx.Done(): 104 require.NoError(t, timeoutCtx.Err(), "timed out waiting for called") 105 } 106 107 select { 108 case done := <-doneC: 109 require.False(t, done, "Connection was already closed!") 110 default: 111 } 112 113 if graceful { 114 // Gracefully shut down server. This should block until the handler has completed. 115 require.NoError(t, server.Shutdown(context.Background())) 116 } else { 117 // Abruptly close server. This will send EOF on open connections. 118 require.NoError(t, server.Close()) 119 } 120 121 var done bool 122 select { 123 case done = <-doneC: 124 default: 125 } 126 127 if graceful { 128 require.True(t, done, "Handler didn't execute the whole way!") 129 130 // verify service log output 131 msgs := getLogFileMessages(t, logOutputBuffer.Bytes()) 132 assert.Equal(t, []string{"Listening to https", "Listening to https", "Shutting down server", "example was closed", "example-management was closed"}, msgs) 133 } else { 134 require.False(t, done, "Handler allowed to execute the whole way") 135 } 136 137 select { 138 case err := <-serverErr: 139 require.NoError(t, err) 140 default: 141 } 142 } 143 144 t.Run("graceful shutdown", func(t *testing.T) { 145 runTest(t, true) 146 }) 147 t.Run("abrupt shutdown", func(t *testing.T) { 148 runTest(t, false) 149 }) 150 } 151 152 // TestEmptyPathHandler verifies that a route registered at the default path ("/") is served correctly. 153 func TestEmptyPathHandler(t *testing.T) { 154 logOutputBuffer := &bytes.Buffer{} 155 var called bool 156 server, port, _, serverErr, cleanup := createAndRunTestServer(t, func(ctx context.Context, info witchcraft.InitInfo) (deferFn func(), rErr error) { 157 return nil, info.Router.Get("/", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 158 called = true 159 })) 160 }, logOutputBuffer) 161 defer func() { 162 _ = server.Close() 163 }() 164 defer cleanup() 165 166 resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/%s", port, basePath)) 167 require.NoError(t, err) 168 require.NotNil(t, resp) 169 require.Equal(t, "200 OK", resp.Status) 170 assert.True(t, called, "called boolean was not set to true (http handler did not execute)") 171 172 // verify service log output 173 msgs := getLogFileMessages(t, logOutputBuffer.Bytes()) 174 assert.Equal(t, []string{"Listening to https", "Listening to https"}, msgs) 175 176 select { 177 case err := <-serverErr: 178 require.NoError(t, err) 179 default: 180 } 181 } 182 183 // TestManagementRoutes verifies the behavior of /liveness, /readiness, and /health endpoints in their default configuration. 184 // There are two variants - one with a dedicated management port/router and one using the application port/router. 185 func TestManagementRoutes(t *testing.T) { 186 runTests := func(t *testing.T, mgmtPort int) { 187 t.Run("Liveness", func(t *testing.T) { 188 resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.LivenessEndpoint)) 189 require.NoError(t, err) 190 require.NotNil(t, resp) 191 require.Equal(t, "200 OK", resp.Status) 192 }) 193 t.Run("Readiness", func(t *testing.T) { 194 resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.ReadinessEndpoint)) 195 require.NoError(t, err) 196 require.NotNil(t, resp) 197 require.Equal(t, "200 OK", resp.Status) 198 }) 199 t.Run("Health", func(t *testing.T) { 200 resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.HealthEndpoint)) 201 require.NoError(t, err) 202 require.NotNil(t, resp) 203 require.Equal(t, "200 OK", resp.Status) 204 body, err := ioutil.ReadAll(resp.Body) 205 require.NoError(t, err) 206 require.NoError(t, resp.Body.Close()) 207 var healthResp health.HealthStatus 208 require.NoError(t, json.Unmarshal(body, &healthResp)) 209 require.Equal(t, health.HealthStatus{Checks: map[health.CheckType]health.HealthCheckResult{ 210 "CONFIG_RELOAD": {Type: "CONFIG_RELOAD", State: health.New_HealthState(health.HealthState_HEALTHY), Params: map[string]interface{}{}}, 211 "SERVER_STATUS": {Type: "SERVER_STATUS", State: health.New_HealthState(health.HealthState_HEALTHY), Params: map[string]interface{}{}}, 212 }}, healthResp) 213 }) 214 } 215 216 t.Run("dedicated port", func(t *testing.T) { 217 server, _, managementPort, serverErr, cleanup := createAndRunTestServer(t, nil, ioutil.Discard) 218 defer func() { 219 _ = server.Close() 220 }() 221 defer cleanup() 222 223 runTests(t, managementPort) 224 225 select { 226 case err := <-serverErr: 227 require.NoError(t, err) 228 default: 229 } 230 }) 231 232 t.Run("same port", func(t *testing.T) { 233 port, err := httpserver.AvailablePort() 234 require.NoError(t, err) 235 server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, nil, ioutil.Discard, createTestServer) 236 237 defer func() { 238 _ = server.Close() 239 }() 240 defer cleanup() 241 242 runTests(t, port) 243 244 select { 245 case err := <-serverErr: 246 require.NoError(t, err) 247 default: 248 } 249 }) 250 } 251 252 // TestClientTLS verifies that a Witchcraft server configured to require client TLS authentication enforces that config. 253 func TestClientTLS(t *testing.T) { 254 testDir, cleanup, err := dirs.TempDir("", "") 255 require.NoError(t, err) 256 defer cleanup() 257 258 wd, err := os.Getwd() 259 require.NoError(t, err) 260 defer func() { 261 err := os.Chdir(wd) 262 require.NoError(t, err) 263 }() 264 265 err = os.Chdir(testDir) 266 require.NoError(t, err) 267 268 port, err := httpserver.AvailablePort() 269 require.NoError(t, err) 270 managementPort, err := httpserver.AvailablePort() 271 require.NoError(t, err) 272 273 server := witchcraft.NewServer(). 274 WithClientAuth(tls.RequireAndVerifyClientCert). 275 WithECVKeyProvider(witchcraft.ECVKeyNoOp()). 276 WithRuntimeConfig(struct{}{}). 277 WithInstallConfig(config.Install{ 278 ProductName: productName, 279 UseConsoleLog: true, 280 Server: config.Server{ 281 Address: "localhost", 282 Port: port, 283 ManagementPort: managementPort, 284 ContextPath: basePath, 285 CertFile: path.Join(wd, "testdata/server-cert.pem"), 286 KeyFile: path.Join(wd, "testdata/server-key.pem"), 287 ClientCAFiles: []string{path.Join(wd, "testdata/ca-cert.pem")}, 288 }, 289 }) 290 291 serverChan := make(chan error) 292 go func() { 293 serverChan <- server.Start() 294 }() 295 296 select { 297 case err := <-serverChan: 298 require.NoError(t, err) 299 default: 300 } 301 302 // Management port should not require client certs. 303 ready := <-waitForTestServerReady(managementPort, path.Join(basePath, status.LivenessEndpoint), 5*time.Second) 304 if !ready { 305 errMsg := "timed out waiting for server to start. This could mean management server incorrectly requested client certs." 306 select { 307 case err := <-serverChan: 308 errMsg = fmt.Sprintf("%s: %+v", errMsg, err) 309 } 310 require.Fail(t, errMsg) 311 } 312 require.True(t, ready) 313 314 defer func() { 315 require.NoError(t, server.Close()) 316 }() 317 318 // Assert regular client receives error 319 _, err = testServerClient().Get(fmt.Sprintf("https://localhost:%d/example", port)) 320 require.Error(t, err, "Client allowed to make request without cert") 321 assert.ErrorContains(t, err, "tls", "expected error to contain 'tls' in its string representation") 322 323 // Assert client w/ certs does not receive error 324 tlsConf, err := tlsconfig.NewClientConfig(tlsconfig.ClientRootCAFiles(path.Join(wd, "testdata/ca-cert.pem")), tlsconfig.ClientKeyPairFiles(path.Join(wd, "testdata/client-cert.pem"), path.Join(wd, "testdata/client-key.pem"))) 325 require.NoError(t, err) 326 _, err = (&http.Client{Transport: &http.Transport{TLSClientConfig: tlsConf}}).Get(fmt.Sprintf("https://localhost:%d/example", port)) 327 require.NoError(t, err) 328 } 329 330 func TestDefaultNotFoundHandler(t *testing.T) { 331 logOutputBuffer := &bytes.Buffer{} 332 server, port, _, serverErr, cleanup := createAndRunTestServer(t, func(ctx context.Context, info witchcraft.InitInfo) (deferFn func(), rErr error) { 333 return nil, info.Router.Get("/foo", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 334 rw.WriteHeader(200) 335 })) 336 }, logOutputBuffer) 337 defer func() { 338 _ = server.Close() 339 }() 340 defer cleanup() 341 342 const testTraceID = "1000000000000001" 343 req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://localhost:%d/TestDefaultNotFoundHandler", port), nil) 344 require.NoError(t, err) 345 req.Header.Set("X-B3-TraceId", testTraceID) 346 347 resp, err := testServerClient().Do(req) 348 require.NoError(t, err) 349 require.NotNil(t, resp) 350 351 assert.Equal(t, "404 Not Found", resp.Status) 352 353 body, err := ioutil.ReadAll(resp.Body) 354 require.NoError(t, err) 355 cerr, err := errors.UnmarshalError(body) 356 if assert.NoError(t, err) { 357 assert.Equal(t, errors.NotFound, cerr.Code()) 358 } 359 360 t.Run("request log", func(t *testing.T) { 361 // find request log for 404, assert trace ID matches request 362 reqlogs := getLogMessagesOfType(t, "request.2", logOutputBuffer.Bytes()) 363 var notFoundReqLogs []map[string]interface{} 364 for _, reqlog := range reqlogs { 365 if reqlog["traceId"] == testTraceID { 366 notFoundReqLogs = append(notFoundReqLogs, reqlog) 367 } 368 } 369 if assert.Len(t, notFoundReqLogs, 1, "expected exactly one request log with trace id") { 370 reqlog := notFoundReqLogs[0] 371 assert.Equal(t, 404.0, reqlog["status"]) 372 assert.Equal(t, "POST", reqlog["method"]) 373 assert.Equal(t, "/*", reqlog["path"]) 374 } 375 }) 376 377 t.Run("service log", func(t *testing.T) { 378 // find service log for 404, assert trace ID matches request 379 svclogs := getLogMessagesOfType(t, "service.1", logOutputBuffer.Bytes()) 380 var notFoundSvcLogs []map[string]interface{} 381 for _, svclog := range svclogs { 382 if svclog["traceId"] == testTraceID { 383 notFoundSvcLogs = append(notFoundSvcLogs, svclog) 384 } 385 } 386 if assert.Len(t, notFoundSvcLogs, 1, "expected exactly one service log with trace id") { 387 svclog := notFoundSvcLogs[0] 388 assert.Equal(t, "INFO", svclog["level"]) 389 assert.Equal(t, "Error handling request", svclog["message"]) 390 assert.Equal(t, map[string]interface{}{ 391 "errorInstanceId": cerr.InstanceID().String(), 392 "errorName": "Default:NotFound", 393 }, svclog["params"]) 394 } 395 }) 396 397 select { 398 case err := <-serverErr: 399 require.NoError(t, err) 400 default: 401 } 402 }