k8s.io/apiserver@v0.31.1/pkg/server/genericapiserver_graceful_termination_test.go (about)

     1  /*
     2  Copyright 2021 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package server
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"log"
    27  	"net"
    28  	"net/http"
    29  	"net/http/httptrace"
    30  	"os"
    31  	"reflect"
    32  	"sync"
    33  	"syscall"
    34  	"testing"
    35  	"time"
    36  
    37  	utilnet "k8s.io/apimachinery/pkg/util/net"
    38  	"k8s.io/apimachinery/pkg/util/wait"
    39  	auditinternal "k8s.io/apiserver/pkg/apis/audit"
    40  	"k8s.io/apiserver/pkg/audit"
    41  	"k8s.io/apiserver/pkg/authorization/authorizer"
    42  	apirequest "k8s.io/apiserver/pkg/endpoints/request"
    43  	"k8s.io/apiserver/pkg/server/dynamiccertificates"
    44  	"k8s.io/klog/v2"
    45  	"k8s.io/klog/v2/ktesting"
    46  
    47  	"github.com/google/go-cmp/cmp"
    48  	"golang.org/x/net/http2"
    49  )
    50  
    51  func TestMain(m *testing.M) {
    52  	klog.InitFlags(nil)
    53  	os.Exit(m.Run())
    54  }
    55  
    56  // doer sends a request to the server
    57  type doer func(client *http.Client, gci func(httptrace.GotConnInfo), path string, timeout time.Duration) result
    58  
    59  func (d doer) Do(client *http.Client, gci func(httptrace.GotConnInfo), path string, timeout time.Duration) result {
    60  	return d(client, gci, path, timeout)
    61  }
    62  
    63  type result struct {
    64  	err      error
    65  	response *http.Response
    66  }
    67  
    68  // wrap a lifecycleSignal so the test can inject its own callback
    69  type wrappedLifecycleSignal struct {
    70  	lifecycleSignal
    71  	before func(lifecycleSignal)
    72  	after  func(lifecycleSignal)
    73  }
    74  
    75  func (w *wrappedLifecycleSignal) Signal() {
    76  	if w.before != nil {
    77  		w.before(w.lifecycleSignal)
    78  	}
    79  	w.lifecycleSignal.Signal()
    80  	if w.after != nil {
    81  		w.after(w.lifecycleSignal)
    82  	}
    83  }
    84  
    85  func wrapLifecycleSignalsWithRecorder(t *testing.T, signals *lifecycleSignals, before func(lifecycleSignal)) {
    86  	// it's important to record the signal being fired on a 'before' callback
    87  	// to avoid flakes, since on the server the signaling of events are
    88  	// an asynchronous process.
    89  	signals.AfterShutdownDelayDuration = wrapLifecycleSignal(t, signals.AfterShutdownDelayDuration, before, nil)
    90  	signals.PreShutdownHooksStopped = wrapLifecycleSignal(t, signals.PreShutdownHooksStopped, before, nil)
    91  	signals.NotAcceptingNewRequest = wrapLifecycleSignal(t, signals.NotAcceptingNewRequest, before, nil)
    92  	signals.HTTPServerStoppedListening = wrapLifecycleSignal(t, signals.HTTPServerStoppedListening, before, nil)
    93  	signals.InFlightRequestsDrained = wrapLifecycleSignal(t, signals.InFlightRequestsDrained, before, nil)
    94  	signals.ShutdownInitiated = wrapLifecycleSignal(t, signals.ShutdownInitiated, before, nil)
    95  }
    96  
    97  func wrapLifecycleSignal(t *testing.T, delegated lifecycleSignal, before, after func(_ lifecycleSignal)) lifecycleSignal {
    98  	return &wrappedLifecycleSignal{
    99  		lifecycleSignal: delegated,
   100  		before:          before,
   101  		after:           after,
   102  	}
   103  }
   104  
   105  // the server may not wait enough time between firing two events for
   106  // the test to execute its steps, this allows us to intercept the
   107  // signal and execute verification steps inside the goroutine that
   108  // is executing the test.
   109  type signalInterceptingTestStep struct {
   110  	doneCh chan struct{}
   111  }
   112  
   113  func (ts signalInterceptingTestStep) done() <-chan struct{} {
   114  	return ts.doneCh
   115  }
   116  func (ts signalInterceptingTestStep) execute(fn func()) {
   117  	defer close(ts.doneCh)
   118  	fn()
   119  }
   120  func newSignalInterceptingTestStep() *signalInterceptingTestStep {
   121  	return &signalInterceptingTestStep{
   122  		doneCh: make(chan struct{}),
   123  	}
   124  }
   125  
   126  //	 This test exercises the graceful termination scenario
   127  //	 described in the following diagram
   128  //	   - every vertical line is an independent timeline
   129  //	   - the leftmost vertical line represents the go routine that
   130  //	     is executing GenericAPIServer.Run method
   131  //	   - (signal name) indicates that the given lifecycle signal has been fired
   132  //
   133  //	                                 stopCh
   134  //	                                   |
   135  //	             |--------------------------------------------|
   136  //	             |                                            |
   137  //		    call PreShutdownHooks                        (ShutdownInitiated)
   138  //	             |                                            |
   139  //	  (PreShutdownHooksStopped)                   Sleep(ShutdownDelayDuration)
   140  //	             |                                            |
   141  //	             |                                 (AfterShutdownDelayDuration)
   142  //	             |                                            |
   143  //	             |                                            |
   144  //	             |--------------------------------------------|
   145  //	             |                                            |
   146  //	             |                                 (NotAcceptingNewRequest)
   147  //	             |                                            |
   148  //	             |                       |-------------------------------------------------|
   149  //	             |                       |                                                 |
   150  //	             |             close(stopHttpServerCh)                         NonLongRunningRequestWaitGroup.Wait()
   151  //	             |                       |                                                 |
   152  //	             |            server.Shutdown(timeout=60s)                                 |
   153  //	             |                       |                                         WatchRequestWaitGroup.Wait()
   154  //	             |              stop listener (net/http)                                   |
   155  //	             |                       |                                                 |
   156  //	             |          |-------------------------------------|                        |
   157  //	             |          |                                     |                        |
   158  //	             |          |                      (HTTPServerStoppedListening)            |
   159  //	             |          |                                                              |
   160  //	             |    wait up to 60s                                                       |
   161  //	             |          |                                                  (InFlightRequestsDrained)
   162  //	             |          |
   163  //	             |          |
   164  //	             |	stoppedCh is closed
   165  //	             |
   166  //	             |
   167  //	   <-drainedCh.Signaled()
   168  //	             |
   169  //	  s.AuditBackend.Shutdown()
   170  //	             |
   171  //	     <-listenerStoppedCh
   172  //	             |
   173  //	        <-stoppedCh
   174  //	             |
   175  //	         return nil
   176  func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationDisabled(t *testing.T) {
   177  	fakeAudit := &fakeAudit{}
   178  	s := newGenericAPIServer(t, fakeAudit, false)
   179  	connReusingClient := newClient(false)
   180  	doer := setupDoer(t, s.SecureServingInfo)
   181  
   182  	// handler for a non long-running and a watch request that
   183  	// we want to keep in flight through to the end.
   184  	inflightNonLongRunning := setupInFlightNonLongRunningRequestHandler(s)
   185  	inflightWatch := setupInFlightWatchRequestHandler(s)
   186  
   187  	// API calls from the pre-shutdown hook(s) must succeed up to
   188  	// the point where the HTTP server is shut down.
   189  	preShutdownHook := setupPreShutdownHookHandler(t, s, doer, newClient(true))
   190  
   191  	signals := &s.lifecycleSignals
   192  	recorder := &signalRecorder{}
   193  	wrapLifecycleSignalsWithRecorder(t, signals, recorder.before)
   194  
   195  	// before the AfterShutdownDelayDuration signal is fired, we want
   196  	// the test to execute a verification step.
   197  	beforeShutdownDelayDurationStep := newSignalInterceptingTestStep()
   198  	signals.AfterShutdownDelayDuration = wrapLifecycleSignal(t, signals.AfterShutdownDelayDuration, func(_ lifecycleSignal) {
   199  		// wait for the test to execute verification steps before
   200  		// the server signals the next steps
   201  		<-beforeShutdownDelayDurationStep.done()
   202  	}, nil)
   203  
   204  	// start the API server
   205  	_, ctx := ktesting.NewTestContext(t)
   206  	stopCtx, stop := context.WithCancelCause(ctx)
   207  	defer stop(errors.New("test has completed"))
   208  	runCompletedCh := make(chan struct{})
   209  	go func() {
   210  		defer close(runCompletedCh)
   211  		if err := s.PrepareRun().RunWithContext(stopCtx); err != nil {
   212  			t.Errorf("unexpected error from RunWithContext: %v", err)
   213  		}
   214  	}()
   215  	waitForAPIServerStarted(t, doer)
   216  
   217  	// fire the non long-running and the watch request so it is
   218  	// in-flight on the server now, and we will unblock them
   219  	// after ShutdownDelayDuration elapses.
   220  	inflightNonLongRunning.launch(doer, connReusingClient)
   221  	waitForeverUntil(t, inflightNonLongRunning.startedCh, "in-flight non long-running request did not reach the server")
   222  	inflightWatch.launch(doer, connReusingClient)
   223  	waitForeverUntil(t, inflightWatch.startedCh, "in-flight watch request did not reach the server")
   224  
   225  	// /readyz should return OK
   226  	resultGot := doer.Do(newClient(true), func(httptrace.GotConnInfo) {}, "/readyz", time.Second)
   227  	if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   228  		t.Errorf("%s", err.Error())
   229  	}
   230  
   231  	// signal termination event: initiate a shutdown
   232  	stop(errors.New("shutting down"))
   233  	waitForeverUntilSignaled(t, signals.ShutdownInitiated)
   234  
   235  	// /readyz must return an error, but we need to give it some time
   236  	err := wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) {
   237  		resultGot := doer.Do(newClient(true), func(httptrace.GotConnInfo) {}, "/readyz", time.Second)
   238  		// wait until we have a non 200 response
   239  		if resultGot.response != nil && resultGot.response.StatusCode == http.StatusOK {
   240  			return false, nil
   241  		}
   242  
   243  		if err := assertResponseStatusCode(resultGot, http.StatusInternalServerError); err != nil {
   244  			return true, err
   245  		}
   246  		return true, nil
   247  	})
   248  	if err != nil {
   249  		t.Errorf("Expected /readyz to return 500 status code, but got: %v", err)
   250  	}
   251  
   252  	// before ShutdownDelayDuration elapses new request(s) should be served successfully.
   253  	beforeShutdownDelayDurationStep.execute(func() {
   254  		t.Log("Before ShutdownDelayDuration elapses new request(s) should be served")
   255  		resultGot := doer.Do(connReusingClient, shouldReuseConnection(t), "/echo?message=request-on-an-existing-connection-should-succeed", time.Second)
   256  		if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   257  			t.Errorf("%s", err.Error())
   258  		}
   259  		resultGot = doer.Do(newClient(true), shouldUseNewConnection(t), "/echo?message=request-on-a-new-tcp-connection-should-succeed", time.Second)
   260  		if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   261  			t.Errorf("%s", err.Error())
   262  		}
   263  	})
   264  
   265  	waitForeverUntilSignaled(t, signals.AfterShutdownDelayDuration)
   266  
   267  	// preshutdown hook has not completed yet, new incomng request should succeed
   268  	resultGot = doer.Do(newClient(true), shouldUseNewConnection(t), "/echo?message=request-on-a-new-tcp-connection-should-succeed", time.Second)
   269  	if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   270  		t.Errorf("%s", err.Error())
   271  	}
   272  
   273  	// let the preshutdown hook issue an API call now, and then
   274  	// let's wait for it to return the result.
   275  	close(preShutdownHook.blockedCh)
   276  	preShutdownHookResult := <-preShutdownHook.resultCh
   277  	waitForeverUntilSignaled(t, signals.PreShutdownHooksStopped)
   278  	if err := assertResponseStatusCode(preShutdownHookResult, http.StatusOK); err != nil {
   279  		t.Errorf("%s", err.Error())
   280  	}
   281  
   282  	waitForeverUntilSignaled(t, signals.PreShutdownHooksStopped)
   283  	// both AfterShutdownDelayDuration and PreShutdownHooksCompleted
   284  	// have been signaled, we should not be accepting new request
   285  	waitForeverUntilSignaled(t, signals.NotAcceptingNewRequest)
   286  	waitForeverUntilSignaled(t, signals.HTTPServerStoppedListening)
   287  
   288  	resultGot = doer.Do(newClient(true), shouldUseNewConnection(t), "/echo?message=request-on-a-new-tcp-connection-should-fail-with-503", time.Second)
   289  	if !utilnet.IsConnectionRefused(resultGot.err) {
   290  		t.Errorf("Expected error %v, but got: %v %v", syscall.ECONNREFUSED, resultGot.err, resultGot.response)
   291  	}
   292  
   293  	// even though Server.Serve() has returned, an existing connection on
   294  	// the server may take some time to be in "closing" state, the following
   295  	// poll eliminates any flake due to that delay.
   296  	if err := wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) {
   297  		result := doer.Do(connReusingClient, shouldReuseConnection(t), "/echo?message=waiting-for-the-existing-connection-to-reject-incoming-request", time.Second)
   298  		if result.response != nil {
   299  			t.Logf("Still waiting for the server to return error - response: %v", result.response)
   300  			return false, nil
   301  		}
   302  		return true, nil
   303  	}); err != nil {
   304  		t.Errorf("Expected no error, but got: %v", err)
   305  	}
   306  
   307  	// TODO: our original intention was for any incoming request to receive a 503
   308  	// via the WithWaitGroup filter, but, at this point, any incoming requests
   309  	// will get a 'connection refused' error since the net/http server has
   310  	// stopped listening.
   311  	resultGot = doer.Do(connReusingClient, shouldReuseConnection(t), "/echo?message=request-on-an-existing-connection-should-fail-with-error", time.Second)
   312  	if !utilnet.IsConnectionRefused(resultGot.err) {
   313  		t.Errorf("Expected error %v, but got: %v %v", syscall.ECONNREFUSED, resultGot.err, resultGot.response)
   314  	}
   315  
   316  	// the server has stopped listening but we still have a non long-running,
   317  	// and a watch request in flight, unblock both of these, and we expect
   318  	// the requests to return appropriate response to the caller.
   319  	inflightNonLongRunningResultGot := inflightNonLongRunning.unblockAndWaitForResult(t)
   320  	if err := assertResponseStatusCode(inflightNonLongRunningResultGot, http.StatusOK); err != nil {
   321  		t.Errorf("%s", err.Error())
   322  	}
   323  	if err := assertRequestAudited(inflightNonLongRunningResultGot, fakeAudit); err != nil {
   324  		t.Errorf("%s", err.Error())
   325  	}
   326  	inflightWatchResultGot := inflightWatch.unblockAndWaitForResult(t)
   327  	if err := assertResponseStatusCode(inflightWatchResultGot, http.StatusOK); err != nil {
   328  		t.Errorf("%s", err.Error())
   329  	}
   330  	if err := assertRequestAudited(inflightWatchResultGot, fakeAudit); err != nil {
   331  		t.Errorf("%s", err.Error())
   332  	}
   333  
   334  	// all requests in flight have drained
   335  	waitForeverUntilSignaled(t, signals.InFlightRequestsDrained)
   336  
   337  	t.Log("Waiting for the apiserver Run method to return")
   338  	waitForeverUntil(t, runCompletedCh, "the apiserver Run method did not return")
   339  
   340  	if !fakeAudit.shutdownCompleted() {
   341  		t.Errorf("Expected AuditBackend.Shutdown to be completed")
   342  	}
   343  
   344  	if err := recorder.verify([]string{
   345  		"ShutdownInitiated",
   346  		"AfterShutdownDelayDuration",
   347  		"PreShutdownHooksStopped",
   348  		"NotAcceptingNewRequest",
   349  		"HTTPServerStoppedListening",
   350  		"InFlightRequestsDrained",
   351  	}); err != nil {
   352  		t.Errorf("%s", err.Error())
   353  	}
   354  }
   355  
   356  // This test exercises the graceful termination scenario
   357  // described in the following diagram
   358  //
   359  //   - every vertical line is an independent timeline
   360  //
   361  //   - the leftmost vertical line represents the go routine that
   362  //     is executing GenericAPIServer.Run method
   363  //
   364  //   - (signal) indicates that the given lifecycle signal has been fired
   365  //
   366  //     stopCh
   367  //     |
   368  //     |--------------------------------------------|
   369  //     |                                            |
   370  //     call PreShutdownHooks                       (ShutdownInitiated)
   371  //     |                                            |
   372  //     (PreShutdownHooksCompleted)                  Sleep(ShutdownDelayDuration)
   373  //     |                                            |
   374  //     |                                 (AfterShutdownDelayDuration)
   375  //     |                                            |
   376  //     |                                            |
   377  //     |--------------------------------------------|
   378  //     |                                            |
   379  //     |                               (NotAcceptingNewRequest)
   380  //     |                                            |
   381  //     |                              NonLongRunningRequestWaitGroup.Wait()
   382  //     |                                            |
   383  //     |                                 WatchRequestWaitGroup.Wait()
   384  //     |                                            |
   385  //     |                                (InFlightRequestsDrained)
   386  //     |                                            |
   387  //     |                                            |
   388  //     |------------------------------------------------------------|
   389  //     |                                                            |
   390  //     <-drainedCh.Signaled()                                     close(stopHttpServerCh)
   391  //     |                                                            |
   392  //     s.AuditBackend.Shutdown()                                 server.Shutdown(timeout=2s)
   393  //     |                                                            |
   394  //     |                                                   stop listener (net/http)
   395  //     |                                                            |
   396  //     |                                         |-------------------------------------|
   397  //     |                                         |                                     |
   398  //     |                                   wait up to 2s                 (HTTPServerStoppedListening)
   399  //     <-listenerStoppedCh                                |
   400  //     |                                stoppedCh is closed
   401  //     <-stoppedCh
   402  //     |
   403  //     return nil
   404  func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationEnabled(t *testing.T) {
   405  	fakeAudit := &fakeAudit{}
   406  	s := newGenericAPIServer(t, fakeAudit, true)
   407  	connReusingClient := newClient(false)
   408  	doer := setupDoer(t, s.SecureServingInfo)
   409  
   410  	// handler for a non long-running and a watch request that
   411  	// we want to keep in flight through to the end.
   412  	inflightNonLongRunning := setupInFlightNonLongRunningRequestHandler(s)
   413  	inflightWatch := setupInFlightWatchRequestHandler(s)
   414  
   415  	// API calls from the pre-shutdown hook(s) must succeed up to
   416  	// the point where the HTTP server is shut down.
   417  	preShutdownHook := setupPreShutdownHookHandler(t, s, doer, newClient(true))
   418  
   419  	signals := &s.lifecycleSignals
   420  	recorder := &signalRecorder{}
   421  	wrapLifecycleSignalsWithRecorder(t, signals, recorder.before)
   422  
   423  	// before the AfterShutdownDelayDuration signal is fired, we want
   424  	// the test to execute a verification step.
   425  	beforeShutdownDelayDurationStep := newSignalInterceptingTestStep()
   426  	signals.AfterShutdownDelayDuration = wrapLifecycleSignal(t, signals.AfterShutdownDelayDuration, func(_ lifecycleSignal) {
   427  		// Before AfterShutdownDelayDuration event is signaled, the test
   428  		// will send request(s) to assert on expected behavior.
   429  		<-beforeShutdownDelayDurationStep.done()
   430  	}, nil)
   431  
   432  	// start the API server
   433  	_, ctx := ktesting.NewTestContext(t)
   434  	stopCtx, stop := context.WithCancelCause(ctx)
   435  	defer stop(errors.New("test has completed"))
   436  	runCompletedCh := make(chan struct{})
   437  	go func() {
   438  		defer close(runCompletedCh)
   439  		if err := s.PrepareRun().RunWithContext(stopCtx); err != nil {
   440  			t.Errorf("unexpected error from RunWithContext: %v", err)
   441  		}
   442  	}()
   443  	waitForAPIServerStarted(t, doer)
   444  
   445  	// fire the non long-running and the watch request so it is
   446  	// in-flight on the server now, and we will unblock them
   447  	// after ShutdownDelayDuration elapses.
   448  	inflightNonLongRunning.launch(doer, connReusingClient)
   449  	waitForeverUntil(t, inflightNonLongRunning.startedCh, "in-flight request did not reach the server")
   450  	inflightWatch.launch(doer, connReusingClient)
   451  	waitForeverUntil(t, inflightWatch.startedCh, "in-flight watch request did not reach the server")
   452  
   453  	// /readyz should return OK
   454  	resultGot := doer.Do(newClient(true), func(httptrace.GotConnInfo) {}, "/readyz", time.Second)
   455  	if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   456  		t.Errorf("%s", err.Error())
   457  	}
   458  
   459  	// signal termination event: initiate a shutdown
   460  	stop(errors.New("shutting down"))
   461  	waitForeverUntilSignaled(t, signals.ShutdownInitiated)
   462  
   463  	// /readyz must return an error, but we need to give it some time
   464  	err := wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) {
   465  		resultGot := doer.Do(newClient(true), func(httptrace.GotConnInfo) {}, "/readyz", time.Second)
   466  		// wait until we have a non 200 response
   467  		if resultGot.response != nil && resultGot.response.StatusCode == http.StatusOK {
   468  			return false, nil
   469  		}
   470  
   471  		if err := assertResponseStatusCode(resultGot, http.StatusInternalServerError); err != nil {
   472  			return true, err
   473  		}
   474  		return true, nil
   475  	})
   476  	if err != nil {
   477  		t.Errorf("Expected /readyz to return 500 status code, but got: %v", err)
   478  	}
   479  
   480  	// before ShutdownDelayDuration elapses new request(s) should be served successfully.
   481  	beforeShutdownDelayDurationStep.execute(func() {
   482  		t.Log("Before ShutdownDelayDuration elapses new request(s) should be served")
   483  		resultGot := doer.Do(connReusingClient, shouldReuseConnection(t), "/echo?message=request-on-an-existing-connection-should-succeed", time.Second)
   484  		if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   485  			t.Errorf("%s", err.Error())
   486  		}
   487  		resultGot = doer.Do(newClient(true), shouldUseNewConnection(t), "/echo?message=request-on-a-new-tcp-connection-should-succeed", time.Second)
   488  		if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   489  			t.Errorf("%s", err.Error())
   490  		}
   491  	})
   492  
   493  	waitForeverUntilSignaled(t, signals.AfterShutdownDelayDuration)
   494  
   495  	// preshutdown hook has not completed yet, new incomng request should succeed
   496  	resultGot = doer.Do(newClient(true), shouldUseNewConnection(t), "/echo?message=request-on-a-new-tcp-connection-should-succeed", time.Second)
   497  	if err := assertResponseStatusCode(resultGot, http.StatusOK); err != nil {
   498  		t.Errorf("%s", err.Error())
   499  	}
   500  
   501  	// let the preshutdown hook issue an API call now, and then let's wait
   502  	// for it to return the result, it should succeed.
   503  	close(preShutdownHook.blockedCh)
   504  	preShutdownHookResult := <-preShutdownHook.resultCh
   505  	waitForeverUntilSignaled(t, signals.PreShutdownHooksStopped)
   506  	if err := assertResponseStatusCode(preShutdownHookResult, http.StatusOK); err != nil {
   507  		t.Errorf("%s", err.Error())
   508  	}
   509  
   510  	waitForeverUntilSignaled(t, signals.NotAcceptingNewRequest)
   511  
   512  	// both AfterShutdownDelayDuration and PreShutdownHooksCompleted
   513  	// have been signaled, any incoming request should receive 429
   514  	resultGot = doer.Do(newClient(true), shouldUseNewConnection(t), "/echo?message=request-on-a-new-tcp-connection-should-fail-with-429", time.Second)
   515  	if err := requestMustFailWithRetryHeader(resultGot, http.StatusTooManyRequests); err != nil {
   516  		t.Errorf("%s", err.Error())
   517  	}
   518  	resultGot = doer.Do(connReusingClient, shouldReuseConnection(t), "/echo?message=request-on-an-existing-connection-should-fail-with-429", time.Second)
   519  	if err := requestMustFailWithRetryHeader(resultGot, http.StatusTooManyRequests); err != nil {
   520  		t.Errorf("%s", err.Error())
   521  	}
   522  
   523  	// we still have a non long-running, and a watch request in flight,
   524  	// unblock both of these, and we expect the requests
   525  	// to return appropriate response to the caller.
   526  	inflightNonLongRunningResultGot := inflightNonLongRunning.unblockAndWaitForResult(t)
   527  	if err := assertResponseStatusCode(inflightNonLongRunningResultGot, http.StatusOK); err != nil {
   528  		t.Errorf("%s", err.Error())
   529  	}
   530  	if err := assertRequestAudited(inflightNonLongRunningResultGot, fakeAudit); err != nil {
   531  		t.Errorf("%s", err.Error())
   532  	}
   533  	inflightWatchResultGot := inflightWatch.unblockAndWaitForResult(t)
   534  	if err := assertResponseStatusCode(inflightWatchResultGot, http.StatusOK); err != nil {
   535  		t.Errorf("%s", err.Error())
   536  	}
   537  	if err := assertRequestAudited(inflightWatchResultGot, fakeAudit); err != nil {
   538  		t.Errorf("%s", err.Error())
   539  	}
   540  
   541  	// all requests in flight have drained
   542  	waitForeverUntilSignaled(t, signals.InFlightRequestsDrained)
   543  	waitForeverUntilSignaled(t, signals.HTTPServerStoppedListening)
   544  
   545  	t.Log("Waiting for the apiserver Run method to return")
   546  	waitForeverUntil(t, runCompletedCh, "the apiserver Run method did not return")
   547  
   548  	if !fakeAudit.shutdownCompleted() {
   549  		t.Errorf("Expected AuditBackend.Shutdown to be completed")
   550  	}
   551  
   552  	if err := recorder.verify([]string{
   553  		"ShutdownInitiated",
   554  		"AfterShutdownDelayDuration",
   555  		"PreShutdownHooksStopped",
   556  		"NotAcceptingNewRequest",
   557  		"InFlightRequestsDrained",
   558  		"HTTPServerStoppedListening",
   559  	}); err != nil {
   560  		t.Errorf("%s", err.Error())
   561  	}
   562  }
   563  
   564  func TestMuxAndDiscoveryComplete(t *testing.T) {
   565  	// setup
   566  	testSignal1 := make(chan struct{})
   567  	testSignal2 := make(chan struct{})
   568  	s := newGenericAPIServer(t, &fakeAudit{}, true)
   569  	s.muxAndDiscoveryCompleteSignals["TestSignal1"] = testSignal1
   570  	s.muxAndDiscoveryCompleteSignals["TestSignal2"] = testSignal2
   571  	doer := setupDoer(t, s.SecureServingInfo)
   572  	isChanClosed := func(ch <-chan struct{}, delay time.Duration) bool {
   573  		time.Sleep(delay)
   574  		select {
   575  		case <-ch:
   576  			return true
   577  		default:
   578  			return false
   579  		}
   580  	}
   581  
   582  	// start the API server
   583  	_, ctx := ktesting.NewTestContext(t)
   584  	stopCtx, stop := context.WithCancelCause(ctx)
   585  	defer stop(errors.New("test has completed"))
   586  	runCompletedCh := make(chan struct{})
   587  	go func() {
   588  		defer close(runCompletedCh)
   589  		if err := s.PrepareRun().RunWithContext(stopCtx); err != nil {
   590  			t.Errorf("unexpected error from RunWithContext: %v", err)
   591  		}
   592  	}()
   593  	waitForAPIServerStarted(t, doer)
   594  
   595  	// act
   596  	if isChanClosed(s.lifecycleSignals.MuxAndDiscoveryComplete.Signaled(), 1*time.Second) {
   597  		t.Fatalf("%s is closed whereas the TestSignal is still open", s.lifecycleSignals.MuxAndDiscoveryComplete.Name())
   598  	}
   599  
   600  	close(testSignal1)
   601  	if isChanClosed(s.lifecycleSignals.MuxAndDiscoveryComplete.Signaled(), 1*time.Second) {
   602  		t.Fatalf("%s is closed whereas the TestSignal2 is still open", s.lifecycleSignals.MuxAndDiscoveryComplete.Name())
   603  	}
   604  
   605  	close(testSignal2)
   606  	if !isChanClosed(s.lifecycleSignals.MuxAndDiscoveryComplete.Signaled(), 1*time.Second) {
   607  		t.Fatalf("%s wasn't closed", s.lifecycleSignals.MuxAndDiscoveryComplete.Name())
   608  	}
   609  }
   610  
   611  func TestPreShutdownHooks(t *testing.T) {
   612  	tests := []struct {
   613  		name   string
   614  		server func() *GenericAPIServer
   615  	}{
   616  		{
   617  			name: "ShutdownSendRetryAfter is disabled",
   618  			server: func() *GenericAPIServer {
   619  				return newGenericAPIServer(t, &fakeAudit{}, false)
   620  			},
   621  		},
   622  		{
   623  			name: "ShutdownSendRetryAfter is enabled",
   624  			server: func() *GenericAPIServer {
   625  				return newGenericAPIServer(t, &fakeAudit{}, true)
   626  			},
   627  		},
   628  	}
   629  
   630  	for _, test := range tests {
   631  		t.Run(test.name, func(t *testing.T) {
   632  			_, ctx := ktesting.NewTestContext(t)
   633  			stopCtx, stop := context.WithCancelCause(ctx)
   634  			defer stop(errors.New("test has completed"))
   635  			s := test.server()
   636  			doer := setupDoer(t, s.SecureServingInfo)
   637  
   638  			// preshutdown hook should not block when sending to the error channel
   639  			preShutdownHookErrCh := make(chan error, 1)
   640  			err := s.AddPreShutdownHook("test-backend", func() error {
   641  				// this pre-shutdown hook waits for the shutdown duration to elapse,
   642  				// and then send a series of requests to the apiserver, and
   643  				// we expect these series of requests to be completed successfully
   644  				<-s.lifecycleSignals.AfterShutdownDelayDuration.Signaled()
   645  
   646  				// we send 5 requests, one every second
   647  				var err error
   648  				client := newClient(true)
   649  				for i := 0; i < 5; i++ {
   650  					r := doer.Do(client, func(httptrace.GotConnInfo) {}, fmt.Sprintf("/echo?message=attempt-%d", i), 1*time.Second)
   651  					err = r.err
   652  					if err == nil && r.response.StatusCode != http.StatusOK {
   653  						err = fmt.Errorf("did not get status code 200 - %#v", r.response)
   654  						break
   655  					}
   656  					time.Sleep(time.Second)
   657  				}
   658  				preShutdownHookErrCh <- err
   659  				return nil
   660  			})
   661  			if err != nil {
   662  				t.Fatalf("Failed to add pre-shutdown hook - %v", err)
   663  			}
   664  
   665  			// start the API server
   666  			runCompletedCh := make(chan struct{})
   667  			go func() {
   668  				defer close(runCompletedCh)
   669  				if err := s.PrepareRun().RunWithContext(stopCtx); err != nil {
   670  					t.Errorf("unexpected error from RunWithContext: %v", err)
   671  				}
   672  			}()
   673  			waitForAPIServerStarted(t, doer)
   674  
   675  			stop(errors.New("shutting down"))
   676  
   677  			waitForeverUntil(t, runCompletedCh, "the apiserver Run method did not return")
   678  
   679  			select {
   680  			case err := <-preShutdownHookErrCh:
   681  				if err != nil {
   682  					t.Errorf("PreSHutdown hook can not access the API server - %v", err)
   683  				}
   684  			case <-time.After(wait.ForeverTestTimeout):
   685  				t.Fatalf("pre-shutdown hook did not complete as expected")
   686  			}
   687  		})
   688  	}
   689  }
   690  
// signalRecorder accumulates, in call order, the names of the lifecycle
// signals passed to its before method; verify then compares that recorded
// order against an observed one. lock guards order, and both methods take it,
// so they may be called from different goroutines.
type signalRecorder struct {
	lock  sync.Mutex
	order []string
}
   695  
   696  func (r *signalRecorder) before(s lifecycleSignal) {
   697  	r.lock.Lock()
   698  	defer r.lock.Unlock()
   699  	r.order = append(r.order, s.Name())
   700  }
   701  
   702  func (r *signalRecorder) verify(got []string) error {
   703  	r.lock.Lock()
   704  	defer r.lock.Unlock()
   705  	want := r.order
   706  	if !reflect.DeepEqual(want, got) {
   707  		return fmt.Errorf("Expected order of termination event signal to match, diff: %s", cmp.Diff(want, got))
   708  	}
   709  	return nil
   710  }
   711  
// inFlightRequest describes a request that a test handler keeps in flight:
//   - startedCh is closed by the handler as soon as the request reaches it;
//   - blockedCh is closed by the test (unblockAndWaitForResult) to let the
//     handler proceed;
//   - resultCh receives the client-side result of the request sent by launch;
//   - url is the path (including any query) the request is sent to.
type inFlightRequest struct {
	blockedCh, startedCh chan struct{}
	resultCh             chan result
	url                  string
}
   717  
   718  func setupInFlightNonLongRunningRequestHandler(s *GenericAPIServer) *inFlightRequest {
   719  	inflight := &inFlightRequest{
   720  		blockedCh: make(chan struct{}),
   721  		startedCh: make(chan struct{}),
   722  		resultCh:  make(chan result),
   723  		url:       "/in-flight-non-long-running-request-as-designed",
   724  	}
   725  	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   726  		close(inflight.startedCh)
   727  		// this request handler blocks until we deliberately unblock it.
   728  		<-inflight.blockedCh
   729  		w.WriteHeader(http.StatusOK)
   730  	})
   731  	s.Handler.NonGoRestfulMux.Handle(inflight.url, handler)
   732  	return inflight
   733  }
   734  
   735  func setupInFlightWatchRequestHandler(s *GenericAPIServer) *inFlightRequest {
   736  	inflight := &inFlightRequest{
   737  		blockedCh: make(chan struct{}),
   738  		startedCh: make(chan struct{}),
   739  		resultCh:  make(chan result),
   740  		url:       "/apis/watches.group/v1/namespaces/foo/bar?watch=true",
   741  	}
   742  
   743  	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   744  		close(inflight.startedCh)
   745  		// this request handler blocks until we deliberately unblock it.
   746  		<-inflight.blockedCh
   747  
   748  		// this simulates a watch well enough for our test
   749  		signals := apirequest.ServerShutdownSignalFrom(req.Context())
   750  		if signals == nil {
   751  			w.WriteHeader(http.StatusInternalServerError)
   752  			return
   753  		}
   754  		<-signals.ShuttingDown()
   755  		w.WriteHeader(http.StatusOK)
   756  	})
   757  	s.Handler.NonGoRestfulMux.Handle("/apis/watches.group/v1/namespaces/foo/bar", handler)
   758  	return inflight
   759  }
   760  
   761  func (ifr *inFlightRequest) launch(doer doer, client *http.Client) {
   762  	go func() {
   763  		result := doer.Do(client, func(httptrace.GotConnInfo) {}, ifr.url, 0)
   764  		ifr.resultCh <- result
   765  	}()
   766  }
   767  
   768  func (ifr *inFlightRequest) unblockAndWaitForResult(t *testing.T) result {
   769  	close(ifr.blockedCh)
   770  
   771  	var resultGot result
   772  	select {
   773  	case resultGot = <-ifr.resultCh:
   774  		return resultGot
   775  	case <-time.After(wait.ForeverTestTimeout):
   776  		t.Fatal("Expected the server to send a response")
   777  	}
   778  	return resultGot
   779  }
   780  
// preShutdownHookHandler controls the pre-shutdown hook registered by
// setupPreShutdownHookHandler: the test closes blockedCh to let the hook send
// its request, and the hook delivers that request's result on resultCh.
type preShutdownHookHandler struct {
	blockedCh chan struct{}
	resultCh  chan result
}
   785  
   786  func setupPreShutdownHookHandler(t *testing.T, s *GenericAPIServer, doer doer, client *http.Client) *preShutdownHookHandler {
   787  	hook := &preShutdownHookHandler{
   788  		blockedCh: make(chan struct{}),
   789  		resultCh:  make(chan result),
   790  	}
   791  	if err := s.AddPreShutdownHook("test-preshutdown-hook", func() error {
   792  		// wait until the test commands this pre shutdown
   793  		// hook to invoke an API call.
   794  		<-hook.blockedCh
   795  
   796  		resultGot := doer.Do(client, func(httptrace.GotConnInfo) {}, "/echo?message=request-from-pre-shutdown-hook-should-succeed", time.Second)
   797  		hook.resultCh <- resultGot
   798  		return nil
   799  	}); err != nil {
   800  		t.Fatalf("Failed to register preshutdown hook - %v", err)
   801  	}
   802  
   803  	return hook
   804  }
   805  
// fakeAudit is an in-memory audit backend and audit policy rule evaluator
// for these tests.
//   - shutdownCh is created by Run and closed once Run's stop channel is
//     closed; Shutdown blocks on it. NOTE(review): calling Shutdown before
//     Run blocks forever on the nil channel.
//   - lock guards audits and completed.
//   - audits is the set of audit IDs seen by ProcessEvents.
//   - completed is set to true by Shutdown.
type fakeAudit struct {
	shutdownCh chan struct{}
	lock       sync.Mutex
	audits     map[string]struct{}
	completed  bool
}
   812  
   813  func (a *fakeAudit) Run(stopCh <-chan struct{}) error {
   814  	a.shutdownCh = make(chan struct{})
   815  	go func() {
   816  		defer close(a.shutdownCh)
   817  		<-stopCh
   818  	}()
   819  	return nil
   820  }
   821  
   822  func (a *fakeAudit) Shutdown() {
   823  	<-a.shutdownCh
   824  
   825  	a.lock.Lock()
   826  	defer a.lock.Unlock()
   827  	a.completed = true
   828  }
   829  
   830  func (a *fakeAudit) String() string {
   831  	return "fake-audit"
   832  }
   833  
   834  func (a *fakeAudit) shutdownCompleted() bool {
   835  	a.lock.Lock()
   836  	defer a.lock.Unlock()
   837  
   838  	return a.completed
   839  }
   840  
   841  func (a *fakeAudit) ProcessEvents(events ...*auditinternal.Event) bool {
   842  	a.lock.Lock()
   843  	defer a.lock.Unlock()
   844  	if len(a.audits) == 0 {
   845  		a.audits = map[string]struct{}{}
   846  	}
   847  	for _, event := range events {
   848  		a.audits[string(event.AuditID)] = struct{}{}
   849  	}
   850  
   851  	return true
   852  }
   853  
   854  func (a *fakeAudit) requestAudited(auditID string) bool {
   855  	a.lock.Lock()
   856  	defer a.lock.Unlock()
   857  	_, exists := a.audits[auditID]
   858  	return exists
   859  }
   860  
   861  func (a *fakeAudit) EvaluatePolicyRule(attrs authorizer.Attributes) audit.RequestAuditConfig {
   862  	return audit.RequestAuditConfig{
   863  		Level: auditinternal.LevelMetadata,
   864  	}
   865  }
   866  
   867  func assertRequestAudited(resultGot result, backend *fakeAudit) error {
   868  	resp := resultGot.response
   869  	if resp == nil {
   870  		return fmt.Errorf("Expected a response, but got nil")
   871  	}
   872  	auditIDGot := resp.Header.Get(auditinternal.HeaderAuditID)
   873  	if len(auditIDGot) == 0 {
   874  		return fmt.Errorf("Expected non-empty %q response header, but got: %v", auditinternal.HeaderAuditID, resp)
   875  	}
   876  	if !backend.requestAudited(auditIDGot) {
   877  		return fmt.Errorf("Expected the request to be audited: %q", auditIDGot)
   878  	}
   879  	return nil
   880  }
   881  
   882  func waitForeverUntilSignaled(t *testing.T, s lifecycleSignal) {
   883  	waitForeverUntil(t, s.Signaled(), fmt.Sprintf("Expected the server to signal %s event", s.Name()))
   884  }
   885  
   886  func waitForeverUntil(t *testing.T, ch <-chan struct{}, msg string) {
   887  	select {
   888  	case <-ch:
   889  	case <-time.After(wait.ForeverTestTimeout):
   890  		t.Fatalf("%s", msg)
   891  	}
   892  }
   893  
   894  func shouldReuseConnection(t *testing.T) func(httptrace.GotConnInfo) {
   895  	return func(ci httptrace.GotConnInfo) {
   896  		if !ci.Reused {
   897  			t.Errorf("Expected the request to use an existing TCP connection, but got: %+v", ci)
   898  		}
   899  	}
   900  }
   901  
   902  func shouldUseNewConnection(t *testing.T) func(httptrace.GotConnInfo) {
   903  	return func(ci httptrace.GotConnInfo) {
   904  		if ci.Reused {
   905  			t.Errorf("Expected the request to use a new TCP connection, but got: %+v", ci)
   906  		}
   907  	}
   908  }
   909  
   910  func assertResponseStatusCode(resultGot result, statusCodeExpected int) error {
   911  	if resultGot.err != nil {
   912  		return fmt.Errorf("Expected no error, but got: %v", resultGot.err)
   913  	}
   914  	if resultGot.response.StatusCode != statusCodeExpected {
   915  		return fmt.Errorf("Expected Status Code: %d, but got: %d", statusCodeExpected, resultGot.response.StatusCode)
   916  	}
   917  	return nil
   918  }
   919  
   920  func requestMustFailWithRetryHeader(resultGot result, statusCodedExpected int) error {
   921  	if resultGot.err != nil {
   922  		return fmt.Errorf("Expected no error, but got: %v", resultGot.err)
   923  	}
   924  	if statusCodedExpected != resultGot.response.StatusCode {
   925  		return fmt.Errorf("Expected Status Code: %d, but got: %d", statusCodedExpected, resultGot.response.StatusCode)
   926  	}
   927  	retryAfterGot := resultGot.response.Header.Get("Retry-After")
   928  	if retryAfterGot != "5" {
   929  		return fmt.Errorf("Expected Retry-After Response Header, but got: %v", resultGot.response)
   930  	}
   931  	return nil
   932  }
   933  
   934  func waitForAPIServerStarted(t *testing.T, doer doer) {
   935  	client := newClient(true)
   936  	i := 1
   937  	err := wait.PollImmediate(100*time.Millisecond, 5*time.Second, func() (done bool, err error) {
   938  		result := doer.Do(client, func(httptrace.GotConnInfo) {}, fmt.Sprintf("/echo?message=attempt-%d", i), time.Second)
   939  		i++
   940  
   941  		if result.err != nil {
   942  			t.Logf("Still waiting for the server to start - err: %v", err)
   943  			return false, nil
   944  		}
   945  		if result.response.StatusCode != http.StatusOK {
   946  			t.Logf("Still waiting for the server to start - expecting: %d, but got: %v", http.StatusOK, result.response)
   947  			return false, nil
   948  		}
   949  
   950  		t.Log("The API server has started")
   951  		return true, nil
   952  	})
   953  
   954  	if err != nil {
   955  		t.Fatalf("The server has failed to start - err: %v", err)
   956  	}
   957  }
   958  
   959  func setupDoer(t *testing.T, info *SecureServingInfo) doer {
   960  	_, port, err := info.HostPort()
   961  	if err != nil {
   962  		t.Fatalf("Expected host, port from SecureServingInfo, but got: %v", err)
   963  	}
   964  
   965  	return func(client *http.Client, callback func(httptrace.GotConnInfo), path string, timeout time.Duration) result {
   966  		url := fmt.Sprintf("https://%s:%d%s", "127.0.0.1", port, path)
   967  		t.Logf("Sending request - timeout: %s, url: %s", timeout, url)
   968  
   969  		req, err := http.NewRequest("GET", url, nil)
   970  		if err != nil {
   971  			return result{response: nil, err: err}
   972  		}
   973  
   974  		// setup request timeout
   975  		var ctx context.Context
   976  		if timeout > 0 {
   977  			var cancel context.CancelFunc
   978  			ctx, cancel = context.WithTimeout(req.Context(), timeout)
   979  			defer cancel()
   980  
   981  			req = req.WithContext(ctx)
   982  		}
   983  
   984  		// setup trace
   985  		trace := &httptrace.ClientTrace{
   986  			GotConn: func(connInfo httptrace.GotConnInfo) {
   987  				callback(connInfo)
   988  			},
   989  		}
   990  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   991  
   992  		response, err := client.Do(req)
   993  		// in this test, we don't depend on the body of the response, so we can
   994  		// close the Body here to ensure the underlying transport can be reused
   995  		if response != nil {
   996  			io.ReadAll(response.Body)
   997  			response.Body.Close()
   998  		}
   999  		return result{
  1000  			err:      err,
  1001  			response: response,
  1002  		}
  1003  	}
  1004  }
  1005  
  1006  func newClient(useNewConnection bool) *http.Client {
  1007  	clientCACertPool := x509.NewCertPool()
  1008  	clientCACertPool.AppendCertsFromPEM(backendCrt)
  1009  	tlsConfig := &tls.Config{
  1010  		RootCAs:    clientCACertPool,
  1011  		NextProtos: []string{http2.NextProtoTLS},
  1012  	}
  1013  
  1014  	tr := &http.Transport{
  1015  		TLSClientConfig:   tlsConfig,
  1016  		DisableKeepAlives: useNewConnection,
  1017  	}
  1018  	if err := http2.ConfigureTransport(tr); err != nil {
  1019  		log.Fatalf("Failed to configure HTTP2 transport: %v", err)
  1020  	}
  1021  	return &http.Client{
  1022  		Timeout:   0,
  1023  		Transport: tr,
  1024  	}
  1025  }
  1026  
  1027  func newGenericAPIServer(t *testing.T, fAudit *fakeAudit, keepListening bool) *GenericAPIServer {
  1028  	config, _ := setUp(t)
  1029  	config.ShutdownDelayDuration = 100 * time.Millisecond
  1030  	config.ShutdownSendRetryAfter = keepListening
  1031  	// we enable watch draining, any positive value will do that
  1032  	config.ShutdownWatchTerminationGracePeriod = 2 * time.Second
  1033  	config.AuditPolicyRuleEvaluator = fAudit
  1034  	config.AuditBackend = fAudit
  1035  
  1036  	s, err := config.Complete(nil).New("test", NewEmptyDelegate())
  1037  	if err != nil {
  1038  		t.Fatalf("Error in bringing up the server: %v", err)
  1039  	}
  1040  
  1041  	ln, err := net.Listen("tcp", "0.0.0.0:0")
  1042  	if err != nil {
  1043  		t.Fatalf("failed to listen on %v: %v", "0.0.0.0:0", err)
  1044  	}
  1045  	s.SecureServingInfo = &SecureServingInfo{}
  1046  	s.SecureServingInfo.Listener = &wrappedListener{ln, t}
  1047  
  1048  	cert, err := dynamiccertificates.NewStaticCertKeyContent("serving-cert", backendCrt, backendKey)
  1049  	if err != nil {
  1050  		t.Fatalf("failed to load cert - %v", err)
  1051  	}
  1052  	s.SecureServingInfo.Cert = cert
  1053  
  1054  	// we use this handler to send a test request to the server.
  1055  	s.Handler.NonGoRestfulMux.Handle("/echo", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  1056  		t.Logf("[server] received a request, proto: %s, url: %s", req.Proto, req.RequestURI)
  1057  
  1058  		w.Header().Add("echo", req.URL.Query().Get("message"))
  1059  		w.WriteHeader(http.StatusOK)
  1060  	}))
  1061  
  1062  	return s
  1063  }
  1064  
  1065  type wrappedListener struct {
  1066  	net.Listener
  1067  	t *testing.T
  1068  }
  1069  
  1070  func (ln wrappedListener) Accept() (net.Conn, error) {
  1071  	c, err := ln.Listener.Accept()
  1072  
  1073  	if tc, ok := c.(*net.TCPConn); ok {
  1074  		ln.t.Logf("[server] seen new connection: %#v", tc)
  1075  	}
  1076  	return c, err
  1077  }