k8s.io/apiserver@v0.31.1/pkg/endpoints/filters/request_deadline_test.go (about)

     1  /*
     2  Copyright 2020 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package filters
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"reflect"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  
    32  	apierrors "k8s.io/apimachinery/pkg/api/errors"
    33  	"k8s.io/apimachinery/pkg/runtime"
    34  	"k8s.io/apimachinery/pkg/runtime/serializer"
    35  	auditinternal "k8s.io/apiserver/pkg/apis/audit"
    36  	"k8s.io/apiserver/pkg/audit"
    37  	"k8s.io/apiserver/pkg/audit/policy"
    38  	"k8s.io/apiserver/pkg/endpoints/request"
    39  	testingclock "k8s.io/utils/clock/testing"
    40  )
    41  
    42  func TestParseTimeout(t *testing.T) {
    43  	tests := []struct {
    44  		name            string
    45  		url             string
    46  		expected        bool
    47  		timeoutExpected time.Duration
    48  		message         string
    49  	}{
    50  		{
    51  			name: "the user does not specify a timeout",
    52  			url:  "/api/v1/namespaces?timeout=",
    53  		},
    54  		{
    55  			name:            "the user specifies a valid timeout",
    56  			url:             "/api/v1/namespaces?timeout=10s",
    57  			expected:        true,
    58  			timeoutExpected: 10 * time.Second,
    59  		},
    60  		{
    61  			name:     "the user specifies a timeout of 0s",
    62  			url:      "/api/v1/namespaces?timeout=0s",
    63  			expected: true,
    64  		},
    65  		{
    66  			name:    "the user specifies an invalid timeout",
    67  			url:     "/api/v1/namespaces?timeout=foo",
    68  			message: invalidTimeoutInURL,
    69  		},
    70  	}
    71  
    72  	for _, test := range tests {
    73  		t.Run(test.name, func(t *testing.T) {
    74  			request, err := http.NewRequest(http.MethodGet, test.url, nil)
    75  			if err != nil {
    76  				t.Fatalf("failed to create new http request - %v", err)
    77  			}
    78  
    79  			timeoutGot, ok, err := parseTimeout(request)
    80  
    81  			if test.expected != ok {
    82  				t.Errorf("expected: %t, but got: %t", test.expected, ok)
    83  			}
    84  			if test.timeoutExpected != timeoutGot {
    85  				t.Errorf("expected timeout: %s, but got: %s", test.timeoutExpected, timeoutGot)
    86  			}
    87  
    88  			errMessageGot := message(err)
    89  			if !strings.Contains(errMessageGot, test.message) {
    90  				t.Errorf("expected error message to contain: %s, but got: %s", test.message, errMessageGot)
    91  			}
    92  		})
    93  	}
    94  }
    95  
    96  func TestWithRequestDeadline(t *testing.T) {
    97  	const requestTimeoutMaximum = 60 * time.Second
    98  	tests := []struct {
    99  		name                     string
   100  		requestURL               string
   101  		longRunning              bool
   102  		hasDeadlineExpected      bool
   103  		deadlineExpected         time.Duration
   104  		handlerCallCountExpected int
   105  		statusCodeExpected       int
   106  	}{
   107  		{
   108  			name:                     "the user specifies a valid request timeout",
   109  			requestURL:               "/api/v1/namespaces?timeout=15s",
   110  			longRunning:              false,
   111  			handlerCallCountExpected: 1,
   112  			hasDeadlineExpected:      true,
   113  			deadlineExpected:         14 * time.Second, // to account for the delay in verification
   114  			statusCodeExpected:       http.StatusOK,
   115  		},
   116  		{
   117  			name:                     "the user specifies a valid request timeout",
   118  			requestURL:               "/api/v1/namespaces?timeout=15s",
   119  			longRunning:              false,
   120  			handlerCallCountExpected: 1,
   121  			hasDeadlineExpected:      true,
   122  			deadlineExpected:         14 * time.Second, // to account for the delay in verification
   123  			statusCodeExpected:       http.StatusOK,
   124  		},
   125  		{
   126  			name:                     "the specified timeout is 0s, default deadline is expected to be set",
   127  			requestURL:               "/api/v1/namespaces?timeout=0s",
   128  			longRunning:              false,
   129  			handlerCallCountExpected: 1,
   130  			hasDeadlineExpected:      true,
   131  			deadlineExpected:         requestTimeoutMaximum - time.Second, // to account for the delay in verification
   132  			statusCodeExpected:       http.StatusOK,
   133  		},
   134  		{
   135  			name:                     "the user does not specify any request timeout, default deadline is expected to be set",
   136  			requestURL:               "/api/v1/namespaces?timeout=",
   137  			longRunning:              false,
   138  			handlerCallCountExpected: 1,
   139  			hasDeadlineExpected:      true,
   140  			deadlineExpected:         requestTimeoutMaximum - time.Second, // to account for the delay in verification
   141  			statusCodeExpected:       http.StatusOK,
   142  		},
   143  		{
   144  			name:                     "the request is long running, no deadline is expected to be set",
   145  			requestURL:               "/api/v1/namespaces?timeout=10s",
   146  			longRunning:              true,
   147  			hasDeadlineExpected:      false,
   148  			handlerCallCountExpected: 1,
   149  			statusCodeExpected:       http.StatusOK,
   150  		},
   151  		{
   152  			name:               "the timeout specified is malformed, the request is aborted with HTTP 400",
   153  			requestURL:         "/api/v1/namespaces?timeout=foo",
   154  			longRunning:        false,
   155  			statusCodeExpected: http.StatusBadRequest,
   156  		},
   157  		{
   158  			name:                     "the timeout specified exceeds the maximum deadline allowed, the default deadline is used",
   159  			requestURL:               fmt.Sprintf("/api/v1/namespaces?timeout=%s", requestTimeoutMaximum+time.Second),
   160  			longRunning:              false,
   161  			statusCodeExpected:       http.StatusOK,
   162  			handlerCallCountExpected: 1,
   163  			hasDeadlineExpected:      true,
   164  			deadlineExpected:         requestTimeoutMaximum - time.Second, // to account for the delay in verification
   165  		},
   166  	}
   167  
   168  	for _, test := range tests {
   169  		t.Run(test.name, func(t *testing.T) {
   170  			var (
   171  				callCount      int
   172  				hasDeadlineGot bool
   173  				deadlineGot    time.Duration
   174  			)
   175  			handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
   176  				callCount++
   177  				deadlineGot, hasDeadlineGot = deadline(req)
   178  			})
   179  
   180  			fakeSink := &fakeAuditSink{}
   181  			fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil)
   182  			withDeadline := WithRequestDeadline(handler, fakeSink, fakeRuleEvaluator,
   183  				func(_ *http.Request, _ *request.RequestInfo) bool { return test.longRunning },
   184  				newSerializer(), requestTimeoutMaximum)
   185  			withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
   186  
   187  			testRequest := newRequest(t, test.requestURL)
   188  
   189  			// make sure a default request does not have any deadline set
   190  			remaning, ok := deadline(testRequest)
   191  			if ok {
   192  				t.Fatalf("test setup failed, expected the new HTTP request context to have no deadline but got: %s", remaning)
   193  			}
   194  
   195  			w := httptest.NewRecorder()
   196  			withDeadline.ServeHTTP(w, testRequest)
   197  
   198  			if test.handlerCallCountExpected != callCount {
   199  				t.Errorf("expected the request handler to be invoked %d times, but was actually invoked %d times", test.handlerCallCountExpected, callCount)
   200  			}
   201  
   202  			if test.hasDeadlineExpected != hasDeadlineGot {
   203  				t.Errorf("expected the request context to have deadline set: %t but got: %t", test.hasDeadlineExpected, hasDeadlineGot)
   204  			}
   205  
   206  			deadlineGot = deadlineGot.Truncate(time.Second)
   207  			if test.deadlineExpected != deadlineGot {
   208  				t.Errorf("expected a request context with a deadline of %s but got: %s", test.deadlineExpected, deadlineGot)
   209  			}
   210  
   211  			statusCodeGot := w.Result().StatusCode
   212  			if test.statusCodeExpected != statusCodeGot {
   213  				t.Errorf("expected status code %d but got: %d", test.statusCodeExpected, statusCodeGot)
   214  			}
   215  		})
   216  	}
   217  }
   218  
   219  func TestWithRequestDeadlineWithClock(t *testing.T) {
   220  	var (
   221  		hasDeadlineGot bool
   222  		deadlineGot    time.Duration
   223  	)
   224  	handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
   225  		deadlineGot, hasDeadlineGot = deadline(req)
   226  	})
   227  
   228  	// if the deadline filter uses the clock instead of using the request started timestamp from the context
   229  	// then we will see a request deadline of about a minute.
   230  	receivedTimestampExpected := time.Now().Add(time.Minute)
   231  	fakeClock := testingclock.NewFakeClock(receivedTimestampExpected)
   232  
   233  	fakeSink := &fakeAuditSink{}
   234  	fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil)
   235  	withDeadline := withRequestDeadline(handler, fakeSink, fakeRuleEvaluator,
   236  		func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), time.Minute, fakeClock)
   237  	withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
   238  
   239  	testRequest := newRequest(t, "/api/v1/namespaces?timeout=1s")
   240  	// the request has arrived just now.
   241  	testRequest = testRequest.WithContext(request.WithReceivedTimestamp(testRequest.Context(), time.Now()))
   242  
   243  	w := httptest.NewRecorder()
   244  	withDeadline.ServeHTTP(w, testRequest)
   245  
   246  	if !hasDeadlineGot {
   247  		t.Error("expected the request context to have deadline set")
   248  	}
   249  
   250  	// we expect a deadline <= 1s since the filter should use the request started timestamp from the context.
   251  	if deadlineGot > time.Second {
   252  		t.Errorf("expected a request context with a deadline <= %s, but got: %s", time.Second, deadlineGot)
   253  	}
   254  }
   255  
   256  func TestWithRequestDeadlineWithInvalidTimeoutIsAudited(t *testing.T) {
   257  	var handlerInvoked bool
   258  	handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
   259  		handlerInvoked = true
   260  	})
   261  
   262  	fakeSink := &fakeAuditSink{}
   263  	fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil)
   264  	withDeadline := WithRequestDeadline(handler, fakeSink, fakeRuleEvaluator,
   265  		func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), time.Minute)
   266  	withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
   267  
   268  	testRequest := newRequest(t, "/api/v1/namespaces?timeout=foo")
   269  	w := httptest.NewRecorder()
   270  	withDeadline.ServeHTTP(w, testRequest)
   271  
   272  	if handlerInvoked {
   273  		t.Error("expected the request to fail and the handler to be skipped")
   274  	}
   275  
   276  	statusCodeGot := w.Result().StatusCode
   277  	if statusCodeGot != http.StatusBadRequest {
   278  		t.Errorf("expected status code %d, but got: %d", http.StatusBadRequest, statusCodeGot)
   279  	}
   280  	// verify that the audit event from the request context is written to the audit sink.
   281  	if len(fakeSink.events) != 1 {
   282  		t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events))
   283  	}
   284  }
   285  
   286  func TestWithRequestDeadlineWithPanic(t *testing.T) {
   287  	var (
   288  		panicErrGot interface{}
   289  		ctxGot      context.Context
   290  	)
   291  
   292  	panicErrExpected := errors.New("apiserver panic'd")
   293  	handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
   294  		ctxGot = req.Context()
   295  		panic(panicErrExpected)
   296  	})
   297  
   298  	fakeSink := &fakeAuditSink{}
   299  	fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil)
   300  	withDeadline := WithRequestDeadline(handler, fakeSink, fakeRuleEvaluator,
   301  		func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), 1*time.Minute)
   302  	withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
   303  	withPanicRecovery := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   304  		defer func() {
   305  			panicErrGot = recover()
   306  		}()
   307  		withDeadline.ServeHTTP(w, req)
   308  	})
   309  
   310  	testRequest := newRequest(t, "/api/v1/namespaces?timeout=1s")
   311  	w := httptest.NewRecorder()
   312  	withPanicRecovery.ServeHTTP(w, testRequest)
   313  
   314  	if panicErrExpected != panicErrGot {
   315  		t.Errorf("expected panic error: %#v, but got: %#v", panicErrExpected, panicErrGot)
   316  	}
   317  	if ctxGot.Err() != context.Canceled {
   318  		t.Error("expected the request context to be canceled on handler panic")
   319  	}
   320  }
   321  
   322  func TestWithRequestDeadlineWithRequestTimesOut(t *testing.T) {
   323  	timeout := 100 * time.Millisecond
   324  	var errGot error
   325  	handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
   326  		ctx := req.Context()
   327  		select {
   328  		case <-time.After(timeout + time.Second):
   329  			errGot = fmt.Errorf("expected the request context to have timed out in %s", timeout)
   330  		case <-ctx.Done():
   331  			errGot = ctx.Err()
   332  		}
   333  	})
   334  
   335  	fakeSink := &fakeAuditSink{}
   336  	fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil)
   337  	withDeadline := WithRequestDeadline(handler, fakeSink, fakeRuleEvaluator,
   338  		func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), 1*time.Minute)
   339  	withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
   340  
   341  	testRequest := newRequest(t, fmt.Sprintf("/api/v1/namespaces?timeout=%s", timeout))
   342  	w := httptest.NewRecorder()
   343  	withDeadline.ServeHTTP(w, testRequest)
   344  
   345  	if errGot != context.DeadlineExceeded {
   346  		t.Errorf("expected error: %#v, but got: %#v", context.DeadlineExceeded, errGot)
   347  	}
   348  }
   349  
   350  func TestWithFailedRequestAudit(t *testing.T) {
   351  	tests := []struct {
   352  		name                          string
   353  		statusErr                     *apierrors.StatusError
   354  		errorHandlerCallCountExpected int
   355  		statusCodeExpected            int
   356  		auditExpected                 bool
   357  	}{
   358  		{
   359  			name:                          "bad request, the error handler is invoked and the request is audited",
   360  			statusErr:                     apierrors.NewBadRequest("error serving request"),
   361  			errorHandlerCallCountExpected: 1,
   362  			statusCodeExpected:            http.StatusBadRequest,
   363  			auditExpected:                 true,
   364  		},
   365  	}
   366  
   367  	for _, test := range tests {
   368  		t.Run(test.name, func(t *testing.T) {
   369  			var (
   370  				errorHandlerCallCountGot int
   371  				rwGot                    http.ResponseWriter
   372  				requestGot               *http.Request
   373  			)
   374  
   375  			errorHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   376  				http.Error(rw, "error serving request", http.StatusBadRequest)
   377  
   378  				errorHandlerCallCountGot++
   379  				requestGot = req
   380  				rwGot = rw
   381  			})
   382  
   383  			fakeSink := &fakeAuditSink{}
   384  			fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil)
   385  
   386  			withAudit := withFailedRequestAudit(errorHandler, test.statusErr, fakeSink, fakeRuleEvaluator)
   387  
   388  			w := httptest.NewRecorder()
   389  			testRequest := newRequest(t, "/apis/v1/namespaces/default/pods")
   390  			info := request.RequestInfo{}
   391  			testRequest = testRequest.WithContext(request.WithRequestInfo(testRequest.Context(), &info))
   392  
   393  			withAudit.ServeHTTP(w, testRequest)
   394  
   395  			if test.errorHandlerCallCountExpected != errorHandlerCallCountGot {
   396  				t.Errorf("expected the testRequest handler to be invoked %d times, but was actually invoked %d times", test.errorHandlerCallCountExpected, errorHandlerCallCountGot)
   397  			}
   398  
   399  			statusCodeGot := w.Result().StatusCode
   400  			if test.statusCodeExpected != statusCodeGot {
   401  				t.Errorf("expected status code %d, but got: %d", test.statusCodeExpected, statusCodeGot)
   402  			}
   403  
   404  			if test.auditExpected {
   405  				// verify that the right http.ResponseWriter is passed to the error handler
   406  				_, ok := rwGot.(*auditResponseWriter)
   407  				if !ok {
   408  					t.Errorf("expected an http.ResponseWriter of type: %T but got: %T", &auditResponseWriter{}, rwGot)
   409  				}
   410  
   411  				auditEventGot := audit.AuditEventFrom(requestGot.Context())
   412  				if auditEventGot == nil {
   413  					t.Fatal("expected an audit event object but got nil")
   414  				}
   415  				if auditEventGot.Stage != auditinternal.StageResponseStarted {
   416  					t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditEventGot.Stage)
   417  				}
   418  				if auditEventGot.ResponseStatus == nil {
   419  					t.Fatal("expected a ResponseStatus field of the audit event object, but got nil")
   420  				}
   421  				if test.statusCodeExpected != int(auditEventGot.ResponseStatus.Code) {
   422  					t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditEventGot.ResponseStatus.Code)
   423  				}
   424  				if test.statusErr.Error() != auditEventGot.ResponseStatus.Message {
   425  					t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditEventGot.ResponseStatus.Message)
   426  				}
   427  
   428  				// verify that the audit event from the request context is written to the audit sink.
   429  				if len(fakeSink.events) != 1 {
   430  					t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events))
   431  				}
   432  				auditEventFromSink := fakeSink.events[0]
   433  				if !reflect.DeepEqual(auditEventGot, auditEventFromSink) {
   434  					t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", cmp.Diff(auditEventGot, auditEventFromSink))
   435  				}
   436  			}
   437  		})
   438  	}
   439  }
   440  
   441  func newRequest(t *testing.T, requestURL string) *http.Request {
   442  	req, err := http.NewRequest(http.MethodGet, requestURL, nil)
   443  	if err != nil {
   444  		t.Fatalf("failed to create new http request - %v", err)
   445  	}
   446  	ctx := audit.WithAuditContext(req.Context())
   447  	return req.WithContext(ctx)
   448  }
   449  
   450  func message(err error) string {
   451  	if err != nil {
   452  		return err.Error()
   453  	}
   454  
   455  	return ""
   456  }
   457  
   458  func newSerializer() runtime.NegotiatedSerializer {
   459  	scheme := runtime.NewScheme()
   460  	return serializer.NewCodecFactory(scheme).WithoutConversion()
   461  }
   462  
   463  type fakeRequestResolver struct{}
   464  
   465  func (r fakeRequestResolver) NewRequestInfo(req *http.Request) (*request.RequestInfo, error) {
   466  	return &request.RequestInfo{}, nil
   467  }
   468  
   469  func deadline(r *http.Request) (time.Duration, bool) {
   470  	if deadline, ok := r.Context().Deadline(); ok {
   471  		remaining := time.Until(deadline)
   472  		return remaining, ok
   473  	}
   474  
   475  	return 0, false
   476  }