github.com/palantir/witchcraft-go-server/v2@v2.76.0/integration/server_test.go (about)

     1  // Copyright (c) 2018 Palantir Technologies. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package integration
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/tls"
    21  	"encoding/json"
    22  	"fmt"
    23  	"io/ioutil"
    24  	"net/http"
    25  	"os"
    26  	"path"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/nmiyake/pkg/dirs"
    31  	"github.com/palantir/conjure-go-runtime/v2/conjure-go-contract/errors"
    32  	"github.com/palantir/pkg/httpserver"
    33  	"github.com/palantir/pkg/tlsconfig"
    34  	"github.com/palantir/witchcraft-go-health/conjure/witchcraft/api/health"
    35  	"github.com/palantir/witchcraft-go-server/v2/config"
    36  	"github.com/palantir/witchcraft-go-server/v2/status"
    37  	"github.com/palantir/witchcraft-go-server/v2/witchcraft"
    38  	"github.com/stretchr/testify/assert"
    39  	"github.com/stretchr/testify/require"
    40  )
    41  
    42  // TestServerStarts ensures that a Witchcraft server starts and is able to serve requests.
    43  // It also verifies the service log output of starting two routers (application and management).
    44  func TestServerStarts(t *testing.T) {
    45  	logOutputBuffer := &bytes.Buffer{}
    46  	server, _, _, _, cleanup := createAndRunTestServer(t, nil, logOutputBuffer)
    47  	defer func() {
    48  		_ = server.Close()
    49  	}()
    50  	defer cleanup()
    51  
    52  	// verify service log output
    53  	msgs := getLogFileMessages(t, logOutputBuffer.Bytes())
    54  	assert.Equal(t, []string{"Listening to https", "Listening to https"}, msgs)
    55  }
    56  
    57  // TestServerShutdown verifies the behavior when shutting down a Witchcraft server. There are two variants, graceful and abrupt.
    58  // Graceful: server.Shutdown, which allows in-flight requests to complete. We test this using a route which sleeps one second.
    59  // Abrupt: server.Close, which terminates the server immediately and sends EOF on active connections.
    60  func TestServerShutdown(t *testing.T) {
    61  	runTest := func(t *testing.T, graceful bool) {
    62  		logOutputBuffer := &bytes.Buffer{}
    63  		calledC := make(chan bool, 1)
    64  		doneC := make(chan bool, 1)
    65  		initFn := func(ctx context.Context, info witchcraft.InitInfo) (func(), error) {
    66  			return nil, info.Router.Get("/wait", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    67  				calledC <- true
    68  				// just wait for 1 second to hold the connection open.
    69  				time.Sleep(1 * time.Second)
    70  				doneC <- true
    71  			}))
    72  		}
    73  		server, port, _, serverErr, cleanup := createAndRunTestServer(t, initFn, logOutputBuffer)
    74  		defer func() {
    75  			_ = server.Close()
    76  		}()
    77  		defer cleanup()
    78  
    79  		go func() {
    80  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/wait", port))
    81  			if graceful {
    82  				if err != nil {
    83  					panic(fmt.Errorf("graceful shutdown request failed: %v", err))
    84  				}
    85  				if resp == nil {
    86  					panic(fmt.Errorf("returned response was nil"))
    87  				}
    88  				if resp.StatusCode != 200 {
    89  					panic(fmt.Errorf("server did not respond with 200 OK. Got %s", resp.Status))
    90  				}
    91  			} else {
    92  				if err == nil {
    93  					panic(fmt.Errorf("request allowed to finish successfully"))
    94  				}
    95  			}
    96  		}()
    97  
    98  		timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), time.Second)
    99  		defer timeoutCancel()
   100  
   101  		select {
   102  		case <-calledC:
   103  		case <-timeoutCtx.Done():
   104  			require.NoError(t, timeoutCtx.Err(), "timed out waiting for called")
   105  		}
   106  
   107  		select {
   108  		case done := <-doneC:
   109  			require.False(t, done, "Connection was already closed!")
   110  		default:
   111  		}
   112  
   113  		if graceful {
   114  			// Gracefully shut down server. This should block until the handler has completed.
   115  			require.NoError(t, server.Shutdown(context.Background()))
   116  		} else {
   117  			// Abruptly close server. This will send EOF on open connections.
   118  			require.NoError(t, server.Close())
   119  		}
   120  
   121  		var done bool
   122  		select {
   123  		case done = <-doneC:
   124  		default:
   125  		}
   126  
   127  		if graceful {
   128  			require.True(t, done, "Handler didn't execute the whole way!")
   129  
   130  			// verify service log output
   131  			msgs := getLogFileMessages(t, logOutputBuffer.Bytes())
   132  			assert.Equal(t, []string{"Listening to https", "Listening to https", "Shutting down server", "example was closed", "example-management was closed"}, msgs)
   133  		} else {
   134  			require.False(t, done, "Handler allowed to execute the whole way")
   135  		}
   136  
   137  		select {
   138  		case err := <-serverErr:
   139  			require.NoError(t, err)
   140  		default:
   141  		}
   142  	}
   143  
   144  	t.Run("graceful shutdown", func(t *testing.T) {
   145  		runTest(t, true)
   146  	})
   147  	t.Run("abrupt shutdown", func(t *testing.T) {
   148  		runTest(t, false)
   149  	})
   150  }
   151  
   152  // TestEmptyPathHandler verifies that a route registered at the default path ("/") is served correctly.
   153  func TestEmptyPathHandler(t *testing.T) {
   154  	logOutputBuffer := &bytes.Buffer{}
   155  	var called bool
   156  	server, port, _, serverErr, cleanup := createAndRunTestServer(t, func(ctx context.Context, info witchcraft.InitInfo) (deferFn func(), rErr error) {
   157  		return nil, info.Router.Get("/", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   158  			called = true
   159  		}))
   160  	}, logOutputBuffer)
   161  	defer func() {
   162  		_ = server.Close()
   163  	}()
   164  	defer cleanup()
   165  
   166  	resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/%s", port, basePath))
   167  	require.NoError(t, err)
   168  	require.NotNil(t, resp)
   169  	require.Equal(t, "200 OK", resp.Status)
   170  	assert.True(t, called, "called boolean was not set to true (http handler did not execute)")
   171  
   172  	// verify service log output
   173  	msgs := getLogFileMessages(t, logOutputBuffer.Bytes())
   174  	assert.Equal(t, []string{"Listening to https", "Listening to https"}, msgs)
   175  
   176  	select {
   177  	case err := <-serverErr:
   178  		require.NoError(t, err)
   179  	default:
   180  	}
   181  }
   182  
   183  // TestManagementRoutes verifies the behavior of /liveness, /readiness, and /health endpoints in their default configuration.
   184  // There are two variants - one with a dedicated management port/router and one using the application port/router.
   185  func TestManagementRoutes(t *testing.T) {
   186  	runTests := func(t *testing.T, mgmtPort int) {
   187  		t.Run("Liveness", func(t *testing.T) {
   188  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.LivenessEndpoint))
   189  			require.NoError(t, err)
   190  			require.NotNil(t, resp)
   191  			require.Equal(t, "200 OK", resp.Status)
   192  		})
   193  		t.Run("Readiness", func(t *testing.T) {
   194  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.ReadinessEndpoint))
   195  			require.NoError(t, err)
   196  			require.NotNil(t, resp)
   197  			require.Equal(t, "200 OK", resp.Status)
   198  		})
   199  		t.Run("Health", func(t *testing.T) {
   200  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.HealthEndpoint))
   201  			require.NoError(t, err)
   202  			require.NotNil(t, resp)
   203  			require.Equal(t, "200 OK", resp.Status)
   204  			body, err := ioutil.ReadAll(resp.Body)
   205  			require.NoError(t, err)
   206  			require.NoError(t, resp.Body.Close())
   207  			var healthResp health.HealthStatus
   208  			require.NoError(t, json.Unmarshal(body, &healthResp))
   209  			require.Equal(t, health.HealthStatus{Checks: map[health.CheckType]health.HealthCheckResult{
   210  				"CONFIG_RELOAD": {Type: "CONFIG_RELOAD", State: health.New_HealthState(health.HealthState_HEALTHY), Params: map[string]interface{}{}},
   211  				"SERVER_STATUS": {Type: "SERVER_STATUS", State: health.New_HealthState(health.HealthState_HEALTHY), Params: map[string]interface{}{}},
   212  			}}, healthResp)
   213  		})
   214  	}
   215  
   216  	t.Run("dedicated port", func(t *testing.T) {
   217  		server, _, managementPort, serverErr, cleanup := createAndRunTestServer(t, nil, ioutil.Discard)
   218  		defer func() {
   219  			_ = server.Close()
   220  		}()
   221  		defer cleanup()
   222  
   223  		runTests(t, managementPort)
   224  
   225  		select {
   226  		case err := <-serverErr:
   227  			require.NoError(t, err)
   228  		default:
   229  		}
   230  	})
   231  
   232  	t.Run("same port", func(t *testing.T) {
   233  		port, err := httpserver.AvailablePort()
   234  		require.NoError(t, err)
   235  		server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, nil, ioutil.Discard, createTestServer)
   236  
   237  		defer func() {
   238  			_ = server.Close()
   239  		}()
   240  		defer cleanup()
   241  
   242  		runTests(t, port)
   243  
   244  		select {
   245  		case err := <-serverErr:
   246  			require.NoError(t, err)
   247  		default:
   248  		}
   249  	})
   250  }
   251  
   252  // TestClientTLS verifies that a Witchcraft server configured to require client TLS authentication enforces that config.
   253  func TestClientTLS(t *testing.T) {
   254  	testDir, cleanup, err := dirs.TempDir("", "")
   255  	require.NoError(t, err)
   256  	defer cleanup()
   257  
   258  	wd, err := os.Getwd()
   259  	require.NoError(t, err)
   260  	defer func() {
   261  		err := os.Chdir(wd)
   262  		require.NoError(t, err)
   263  	}()
   264  
   265  	err = os.Chdir(testDir)
   266  	require.NoError(t, err)
   267  
   268  	port, err := httpserver.AvailablePort()
   269  	require.NoError(t, err)
   270  	managementPort, err := httpserver.AvailablePort()
   271  	require.NoError(t, err)
   272  
   273  	server := witchcraft.NewServer().
   274  		WithClientAuth(tls.RequireAndVerifyClientCert).
   275  		WithECVKeyProvider(witchcraft.ECVKeyNoOp()).
   276  		WithRuntimeConfig(struct{}{}).
   277  		WithInstallConfig(config.Install{
   278  			ProductName:   productName,
   279  			UseConsoleLog: true,
   280  			Server: config.Server{
   281  				Address:        "localhost",
   282  				Port:           port,
   283  				ManagementPort: managementPort,
   284  				ContextPath:    basePath,
   285  				CertFile:       path.Join(wd, "testdata/server-cert.pem"),
   286  				KeyFile:        path.Join(wd, "testdata/server-key.pem"),
   287  				ClientCAFiles:  []string{path.Join(wd, "testdata/ca-cert.pem")},
   288  			},
   289  		})
   290  
   291  	serverChan := make(chan error)
   292  	go func() {
   293  		serverChan <- server.Start()
   294  	}()
   295  
   296  	select {
   297  	case err := <-serverChan:
   298  		require.NoError(t, err)
   299  	default:
   300  	}
   301  
   302  	// Management port should not require client certs.
   303  	ready := <-waitForTestServerReady(managementPort, path.Join(basePath, status.LivenessEndpoint), 5*time.Second)
   304  	if !ready {
   305  		errMsg := "timed out waiting for server to start. This could mean management server incorrectly requested client certs."
   306  		select {
   307  		case err := <-serverChan:
   308  			errMsg = fmt.Sprintf("%s: %+v", errMsg, err)
   309  		}
   310  		require.Fail(t, errMsg)
   311  	}
   312  	require.True(t, ready)
   313  
   314  	defer func() {
   315  		require.NoError(t, server.Close())
   316  	}()
   317  
   318  	// Assert regular client receives error
   319  	_, err = testServerClient().Get(fmt.Sprintf("https://localhost:%d/example", port))
   320  	require.Error(t, err, "Client allowed to make request without cert")
   321  	assert.ErrorContains(t, err, "tls", "expected error to contain 'tls' in its string representation")
   322  
   323  	// Assert client w/ certs does not receive error
   324  	tlsConf, err := tlsconfig.NewClientConfig(tlsconfig.ClientRootCAFiles(path.Join(wd, "testdata/ca-cert.pem")), tlsconfig.ClientKeyPairFiles(path.Join(wd, "testdata/client-cert.pem"), path.Join(wd, "testdata/client-key.pem")))
   325  	require.NoError(t, err)
   326  	_, err = (&http.Client{Transport: &http.Transport{TLSClientConfig: tlsConf}}).Get(fmt.Sprintf("https://localhost:%d/example", port))
   327  	require.NoError(t, err)
   328  }
   329  
   330  func TestDefaultNotFoundHandler(t *testing.T) {
   331  	logOutputBuffer := &bytes.Buffer{}
   332  	server, port, _, serverErr, cleanup := createAndRunTestServer(t, func(ctx context.Context, info witchcraft.InitInfo) (deferFn func(), rErr error) {
   333  		return nil, info.Router.Get("/foo", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   334  			rw.WriteHeader(200)
   335  		}))
   336  	}, logOutputBuffer)
   337  	defer func() {
   338  		_ = server.Close()
   339  	}()
   340  	defer cleanup()
   341  
   342  	const testTraceID = "1000000000000001"
   343  	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://localhost:%d/TestDefaultNotFoundHandler", port), nil)
   344  	require.NoError(t, err)
   345  	req.Header.Set("X-B3-TraceId", testTraceID)
   346  
   347  	resp, err := testServerClient().Do(req)
   348  	require.NoError(t, err)
   349  	require.NotNil(t, resp)
   350  
   351  	assert.Equal(t, "404 Not Found", resp.Status)
   352  
   353  	body, err := ioutil.ReadAll(resp.Body)
   354  	require.NoError(t, err)
   355  	cerr, err := errors.UnmarshalError(body)
   356  	if assert.NoError(t, err) {
   357  		assert.Equal(t, errors.NotFound, cerr.Code())
   358  	}
   359  
   360  	t.Run("request log", func(t *testing.T) {
   361  		// find request log for 404, assert trace ID matches request
   362  		reqlogs := getLogMessagesOfType(t, "request.2", logOutputBuffer.Bytes())
   363  		var notFoundReqLogs []map[string]interface{}
   364  		for _, reqlog := range reqlogs {
   365  			if reqlog["traceId"] == testTraceID {
   366  				notFoundReqLogs = append(notFoundReqLogs, reqlog)
   367  			}
   368  		}
   369  		if assert.Len(t, notFoundReqLogs, 1, "expected exactly one request log with trace id") {
   370  			reqlog := notFoundReqLogs[0]
   371  			assert.Equal(t, 404.0, reqlog["status"])
   372  			assert.Equal(t, "POST", reqlog["method"])
   373  			assert.Equal(t, "/*", reqlog["path"])
   374  		}
   375  	})
   376  
   377  	t.Run("service log", func(t *testing.T) {
   378  		// find service log for 404, assert trace ID matches request
   379  		svclogs := getLogMessagesOfType(t, "service.1", logOutputBuffer.Bytes())
   380  		var notFoundSvcLogs []map[string]interface{}
   381  		for _, svclog := range svclogs {
   382  			if svclog["traceId"] == testTraceID {
   383  				notFoundSvcLogs = append(notFoundSvcLogs, svclog)
   384  			}
   385  		}
   386  		if assert.Len(t, notFoundSvcLogs, 1, "expected exactly one service log with trace id") {
   387  			svclog := notFoundSvcLogs[0]
   388  			assert.Equal(t, "INFO", svclog["level"])
   389  			assert.Equal(t, "Error handling request", svclog["message"])
   390  			assert.Equal(t, map[string]interface{}{
   391  				"errorInstanceId": cerr.InstanceID().String(),
   392  				"errorName":       "Default:NotFound",
   393  			}, svclog["params"])
   394  		}
   395  	})
   396  
   397  	select {
   398  	case err := <-serverErr:
   399  		require.NoError(t, err)
   400  	default:
   401  	}
   402  }