github.com/arthur-befumo/witchcraft-go-server@v1.12.0/integration/server_test.go (about)

     1  // Copyright (c) 2018 Palantir Technologies. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package integration
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/tls"
    21  	"encoding/json"
    22  	"fmt"
    23  	"io/ioutil"
    24  	"net/http"
    25  	"os"
    26  	"path"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/nmiyake/pkg/dirs"
    31  	"github.com/palantir/pkg/httpserver"
    32  	"github.com/palantir/pkg/tlsconfig"
    33  	"github.com/palantir/witchcraft-go-server/config"
    34  	"github.com/palantir/witchcraft-go-server/conjure/witchcraft/api/health"
    35  	"github.com/palantir/witchcraft-go-server/status"
    36  	"github.com/palantir/witchcraft-go-server/witchcraft"
    37  	"github.com/stretchr/testify/assert"
    38  	"github.com/stretchr/testify/require"
    39  )
    40  
    41  // TestServerStarts ensures that a Witchcraft server starts and is able to serve requests.
    42  // It also verifies the service log output of starting two routers (application and management).
    43  func TestServerStarts(t *testing.T) {
    44  	logOutputBuffer := &bytes.Buffer{}
    45  	server, _, _, _, cleanup := createAndRunTestServer(t, nil, logOutputBuffer)
    46  	defer func() {
    47  		_ = server.Close()
    48  	}()
    49  	defer cleanup()
    50  
    51  	// verify service log output
    52  	msgs := getLogFileMessages(t, logOutputBuffer.Bytes())
    53  	assert.Equal(t, []string{"Listening to https", "Listening to https"}, msgs)
    54  }
    55  
    56  // TestServerShutdown verifies the behavior when shutting down a Witchcraft server. There are two variants, graceful and abrupt.
    57  // Graceful: server.Shutdown, which allows in-flight requests to complete. We test this using a route which sleeps one second.
    58  // Abrupt: server.Close, which terminates the server immediately and sends EOF on active connections.
    59  func TestServerShutdown(t *testing.T) {
    60  	runTest := func(t *testing.T, graceful bool) {
    61  		logOutputBuffer := &bytes.Buffer{}
    62  		calledC := make(chan bool, 1)
    63  		doneC := make(chan bool, 1)
    64  		initFn := func(ctx context.Context, info witchcraft.InitInfo) (func(), error) {
    65  			return nil, info.Router.Get("/wait", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    66  				calledC <- true
    67  				// just wait for 1 second to hold the connection open.
    68  				time.Sleep(1 * time.Second)
    69  				doneC <- true
    70  			}))
    71  		}
    72  		server, port, _, serverErr, cleanup := createAndRunTestServer(t, initFn, logOutputBuffer)
    73  		defer func() {
    74  			_ = server.Close()
    75  		}()
    76  		defer cleanup()
    77  
    78  		go func() {
    79  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/wait", port))
    80  			if graceful {
    81  				if err != nil {
    82  					panic(fmt.Errorf("graceful shutdown request failed: %v", err))
    83  				}
    84  				if resp == nil {
    85  					panic(fmt.Errorf("returned response was nil"))
    86  				}
    87  				if resp.StatusCode != 200 {
    88  					panic(fmt.Errorf("server did not respond with 200 OK. Got %s", resp.Status))
    89  				}
    90  			} else {
    91  				if err == nil {
    92  					panic(fmt.Errorf("request allowed to finish successfully"))
    93  				}
    94  			}
    95  		}()
    96  
    97  		timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), time.Second)
    98  		defer timeoutCancel()
    99  
   100  		select {
   101  		case <-calledC:
   102  		case <-timeoutCtx.Done():
   103  			require.NoError(t, timeoutCtx.Err(), "timed out waiting for called")
   104  		}
   105  
   106  		select {
   107  		case done := <-doneC:
   108  			require.False(t, done, "Connection was already closed!")
   109  		default:
   110  		}
   111  
   112  		if graceful {
   113  			// Gracefully shut down server. This should block until the handler has completed.
   114  			require.NoError(t, server.Shutdown(context.Background()))
   115  		} else {
   116  			// Abruptly close server. This will send EOF on open connections.
   117  			require.NoError(t, server.Close())
   118  		}
   119  
   120  		var done bool
   121  		select {
   122  		case done = <-doneC:
   123  		default:
   124  		}
   125  
   126  		if graceful {
   127  			require.True(t, done, "Handler didn't execute the whole way!")
   128  
   129  			// verify service log output
   130  			msgs := getLogFileMessages(t, logOutputBuffer.Bytes())
   131  			assert.Equal(t, []string{"Listening to https", "Listening to https", "Shutting down server", "example was closed", "example-management was closed"}, msgs)
   132  		} else {
   133  			require.False(t, done, "Handler allowed to execute the whole way")
   134  		}
   135  
   136  		select {
   137  		case err := <-serverErr:
   138  			require.NoError(t, err)
   139  		default:
   140  		}
   141  	}
   142  
   143  	t.Run("graceful shutdown", func(t *testing.T) {
   144  		runTest(t, true)
   145  	})
   146  	t.Run("abrupt shutdown", func(t *testing.T) {
   147  		runTest(t, false)
   148  	})
   149  }
   150  
   151  // TestEmptyPathHandler verifies that a route registered at the default path ("/") is served correctly.
   152  func TestEmptyPathHandler(t *testing.T) {
   153  	logOutputBuffer := &bytes.Buffer{}
   154  	var called bool
   155  	server, port, _, serverErr, cleanup := createAndRunTestServer(t, func(ctx context.Context, info witchcraft.InitInfo) (deferFn func(), rErr error) {
   156  		return nil, info.Router.Get("/", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   157  			called = true
   158  		}))
   159  	}, logOutputBuffer)
   160  	defer func() {
   161  		_ = server.Close()
   162  	}()
   163  	defer cleanup()
   164  
   165  	resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/%s", port, basePath))
   166  	require.NoError(t, err)
   167  	require.NotNil(t, resp)
   168  	require.Equal(t, "200 OK", resp.Status)
   169  	assert.True(t, called, "called boolean was not set to true (http handler did not execute)")
   170  
   171  	// verify service log output
   172  	msgs := getLogFileMessages(t, logOutputBuffer.Bytes())
   173  	assert.Equal(t, []string{"Listening to https", "Listening to https"}, msgs)
   174  
   175  	select {
   176  	case err := <-serverErr:
   177  		require.NoError(t, err)
   178  	default:
   179  	}
   180  }
   181  
   182  // TestManagementRoutes verifies the behavior of /liveness, /readiness, and /health endpoints in their default configuration.
   183  // There are two variants - one with a dedicated management port/router and one using the application port/router.
   184  func TestManagementRoutes(t *testing.T) {
   185  	runTests := func(t *testing.T, mgmtPort int) {
   186  		t.Run("Liveness", func(t *testing.T) {
   187  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.LivenessEndpoint))
   188  			require.NoError(t, err)
   189  			require.NotNil(t, resp)
   190  			require.Equal(t, "200 OK", resp.Status)
   191  		})
   192  		t.Run("Readiness", func(t *testing.T) {
   193  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.ReadinessEndpoint))
   194  			require.NoError(t, err)
   195  			require.NotNil(t, resp)
   196  			require.Equal(t, "200 OK", resp.Status)
   197  		})
   198  		t.Run("Health", func(t *testing.T) {
   199  			resp, err := testServerClient().Get(fmt.Sprintf("https://localhost:%d/example/%s", mgmtPort, status.HealthEndpoint))
   200  			require.NoError(t, err)
   201  			require.NotNil(t, resp)
   202  			require.Equal(t, "200 OK", resp.Status)
   203  			body, err := ioutil.ReadAll(resp.Body)
   204  			require.NoError(t, err)
   205  			require.NoError(t, resp.Body.Close())
   206  			var healthResp health.HealthStatus
   207  			require.NoError(t, json.Unmarshal(body, &healthResp))
   208  			require.Equal(t, health.HealthStatus{Checks: map[health.CheckType]health.HealthCheckResult{
   209  				"CONFIG_RELOAD": {Type: "CONFIG_RELOAD", State: health.HealthStateHealthy, Params: map[string]interface{}{}},
   210  				"SERVER_STATUS": {Type: "SERVER_STATUS", State: health.HealthStateHealthy, Params: map[string]interface{}{}},
   211  			}}, healthResp)
   212  		})
   213  	}
   214  
   215  	t.Run("dedicated port", func(t *testing.T) {
   216  		server, _, managementPort, serverErr, cleanup := createAndRunTestServer(t, nil, ioutil.Discard)
   217  		defer func() {
   218  			_ = server.Close()
   219  		}()
   220  		defer cleanup()
   221  
   222  		runTests(t, managementPort)
   223  
   224  		select {
   225  		case err := <-serverErr:
   226  			require.NoError(t, err)
   227  		default:
   228  		}
   229  	})
   230  
   231  	t.Run("same port", func(t *testing.T) {
   232  		port, err := httpserver.AvailablePort()
   233  		require.NoError(t, err)
   234  		server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, nil, ioutil.Discard, createTestServer)
   235  
   236  		defer func() {
   237  			_ = server.Close()
   238  		}()
   239  		defer cleanup()
   240  
   241  		runTests(t, port)
   242  
   243  		select {
   244  		case err := <-serverErr:
   245  			require.NoError(t, err)
   246  		default:
   247  		}
   248  	})
   249  }
   250  
   251  // TestClientTLS verifies that a Witchcraft server configured to require client TLS authentication enforces that config.
   252  func TestClientTLS(t *testing.T) {
   253  	testDir, cleanup, err := dirs.TempDir("", "")
   254  	require.NoError(t, err)
   255  	defer cleanup()
   256  
   257  	wd, err := os.Getwd()
   258  	require.NoError(t, err)
   259  	defer func() {
   260  		err := os.Chdir(wd)
   261  		require.NoError(t, err)
   262  	}()
   263  
   264  	err = os.Chdir(testDir)
   265  	require.NoError(t, err)
   266  
   267  	port, err := httpserver.AvailablePort()
   268  	require.NoError(t, err)
   269  	managementPort, err := httpserver.AvailablePort()
   270  	require.NoError(t, err)
   271  
   272  	server := witchcraft.NewServer().
   273  		WithClientAuth(tls.RequireAndVerifyClientCert).
   274  		WithECVKeyProvider(witchcraft.ECVKeyNoOp()).
   275  		WithRuntimeConfig(struct{}{}).
   276  		WithInstallConfig(config.Install{
   277  			ProductName:   productName,
   278  			UseConsoleLog: true,
   279  			Server: config.Server{
   280  				Address:        "localhost",
   281  				Port:           port,
   282  				ManagementPort: managementPort,
   283  				ContextPath:    basePath,
   284  				CertFile:       path.Join(wd, "testdata/server-cert.pem"),
   285  				KeyFile:        path.Join(wd, "testdata/server-key.pem"),
   286  				ClientCAFiles:  []string{path.Join(wd, "testdata/ca-cert.pem")},
   287  			},
   288  		})
   289  
   290  	serverChan := make(chan error)
   291  	go func() {
   292  		serverChan <- server.Start()
   293  	}()
   294  
   295  	select {
   296  	case err := <-serverChan:
   297  		require.NoError(t, err)
   298  	default:
   299  	}
   300  
   301  	// Management port should not require client certs.
   302  	ready := <-waitForTestServerReady(managementPort, path.Join(basePath, status.LivenessEndpoint), 5*time.Second)
   303  	if !ready {
   304  		errMsg := "timed out waiting for server to start. This could mean management server incorrectly requested client certs."
   305  		select {
   306  		case err := <-serverChan:
   307  			errMsg = fmt.Sprintf("%s: %+v", errMsg, err)
   308  		}
   309  		require.Fail(t, errMsg)
   310  	}
   311  	require.True(t, ready)
   312  
   313  	defer func() {
   314  		require.NoError(t, server.Close())
   315  	}()
   316  
   317  	// Assert regular client receives error
   318  	_, err = testServerClient().Get(fmt.Sprintf("https://localhost:%d/example", port))
   319  	require.Error(t, err, "Client allowed to make request without cert")
   320  	assert.Contains(t, err.Error(), "tls: bad certificate")
   321  
   322  	// Assert client w/ certs does not receive error
   323  	tlsConf, err := tlsconfig.NewClientConfig(tlsconfig.ClientRootCAFiles(path.Join(wd, "testdata/ca-cert.pem")), tlsconfig.ClientKeyPairFiles(path.Join(wd, "testdata/client-cert.pem"), path.Join(wd, "testdata/client-key.pem")))
   324  	require.NoError(t, err)
   325  	_, err = (&http.Client{Transport: &http.Transport{TLSClientConfig: tlsConf}}).Get(fmt.Sprintf("https://localhost:%d/example", port))
   326  	require.NoError(t, err)
   327  }