k8s.io/apiserver@v0.31.1/pkg/server/filters/goaway_test.go (about)

     1  /*
     2  Copyright 2020 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package filters
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/tls"
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"math/rand"
    27  	"net"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"reflect"
    31  	"sync"
    32  	"testing"
    33  	"time"
    34  
    35  	"golang.org/x/net/http2"
    36  )
    37  
    38  func TestProbabilisticGoawayDecider(t *testing.T) {
    39  	cases := []struct {
    40  		name         string
    41  		chance       float64
    42  		nextFn       func(chance float64) func() float64
    43  		expectGOAWAY bool
    44  	}{
    45  		{
    46  			name:   "always not GOAWAY",
    47  			chance: 0,
    48  			nextFn: func(chance float64) func() float64 {
    49  				return rand.Float64
    50  			},
    51  			expectGOAWAY: false,
    52  		},
    53  		{
    54  			name:   "always GOAWAY",
    55  			chance: 1,
    56  			nextFn: func(chance float64) func() float64 {
    57  				return rand.Float64
    58  			},
    59  			expectGOAWAY: true,
    60  		},
    61  		{
    62  			name:   "hit GOAWAY",
    63  			chance: rand.Float64() + 0.01,
    64  			nextFn: func(chance float64) func() float64 {
    65  				return func() float64 {
    66  					return chance - 0.001
    67  				}
    68  			},
    69  			expectGOAWAY: true,
    70  		},
    71  		{
    72  			name:   "does not hit GOAWAY",
    73  			chance: rand.Float64() + 0.01,
    74  			nextFn: func(chance float64) func() float64 {
    75  				return func() float64 {
    76  					return chance + 0.001
    77  				}
    78  			},
    79  			expectGOAWAY: false,
    80  		},
    81  	}
    82  
    83  	for _, tc := range cases {
    84  		t.Run(tc.name, func(t *testing.T) {
    85  			d := probabilisticGoawayDecider{chance: tc.chance, next: tc.nextFn(tc.chance)}
    86  			result := d.Goaway(nil)
    87  			if result != tc.expectGOAWAY {
    88  				t.Errorf("expect GOAWAY: %v, got: %v", tc.expectGOAWAY, result)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  const (
    95  	urlGet             = "/get"
    96  	urlPost            = "/post"
    97  	urlWatch           = "/watch"
    98  	urlGetWithGoaway   = "/get-with-goaway"
    99  	urlPostWithGoaway  = "/post-with-goaway"
   100  	urlWatchWithGoaway = "/watch-with-goaway"
   101  )
   102  
   103  var (
   104  	// responseBody is the response body which test GOAWAY server sent for each request,
   105  	// for watch request, test GOAWAY server push 1 byte in every second.
   106  	responseBody = []byte("hello")
   107  
   108  	// requestPostBody is the request body which client must send to test GOAWAY server for POST method,
   109  	// otherwise, test GOAWAY server will respond 400 HTTP status code.
   110  	requestPostBody = responseBody
   111  )
   112  
   113  // newTestGOAWAYServer return a test GOAWAY server instance.
   114  func newTestGOAWAYServer() (*httptest.Server, error) {
   115  	watchHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   116  		timer := time.NewTicker(time.Second)
   117  		defer timer.Stop()
   118  
   119  		w.Header().Set("Transfer-Encoding", "chunked")
   120  		w.WriteHeader(200)
   121  
   122  		flusher, _ := w.(http.Flusher)
   123  		flusher.Flush()
   124  
   125  		count := 0
   126  		for {
   127  			<-timer.C
   128  			n, err := w.Write(responseBody[count : count+1])
   129  			if err != nil {
   130  				return
   131  			}
   132  			flusher.Flush()
   133  			count += n
   134  			if count == len(responseBody) {
   135  				return
   136  			}
   137  		}
   138  	})
   139  	getHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   140  		w.WriteHeader(http.StatusOK)
   141  		w.Write(responseBody)
   142  		return
   143  	})
   144  	postHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   145  		reqBody, err := ioutil.ReadAll(r.Body)
   146  		if err != nil {
   147  			http.Error(w, err.Error(), http.StatusInternalServerError)
   148  			return
   149  		}
   150  		if !reflect.DeepEqual(requestPostBody, reqBody) {
   151  			http.Error(w, fmt.Sprintf("expect request body: %s, got: %s", requestPostBody, reqBody), http.StatusBadRequest)
   152  			return
   153  		}
   154  
   155  		w.WriteHeader(http.StatusOK)
   156  		w.Write(responseBody)
   157  		return
   158  	})
   159  
   160  	mux := http.NewServeMux()
   161  	mux.Handle(urlGet, WithProbabilisticGoaway(getHandler, 0))
   162  	mux.Handle(urlPost, WithProbabilisticGoaway(postHandler, 0))
   163  	mux.Handle(urlWatch, WithProbabilisticGoaway(watchHandler, 0))
   164  	mux.Handle(urlGetWithGoaway, WithProbabilisticGoaway(getHandler, 1))
   165  	mux.Handle(urlPostWithGoaway, WithProbabilisticGoaway(postHandler, 1))
   166  	mux.Handle(urlWatchWithGoaway, WithProbabilisticGoaway(watchHandler, 1))
   167  
   168  	s := httptest.NewUnstartedServer(mux)
   169  
   170  	http2Options := &http2.Server{}
   171  
   172  	if err := http2.ConfigureServer(s.Config, http2Options); err != nil {
   173  		return nil, fmt.Errorf("failed to configure test server to be HTTP2 server, err: %v", err)
   174  	}
   175  
   176  	s.TLS = s.Config.TLSConfig
   177  
   178  	return s, nil
   179  }
   180  
   181  // watchResponse wraps watch response with data which server send and an error may occur.
   182  type watchResponse struct {
   183  	// body is the response data which test GOAWAY server sent to client
   184  	body []byte
   185  	// err will be set to be a non-nil value if watch request is not end with EOF nor http2.GoAwayError
   186  	err error
   187  }
   188  
   189  // requestGOAWAYServer request test GOAWAY server using specified method and data according to the given url.
   190  // A non-nil channel will be returned if the request is watch, and a watchResponse can be got from the channel when watch done.
   191  func requestGOAWAYServer(client *http.Client, serverBaseURL, url string) (<-chan watchResponse, error) {
   192  	method := http.MethodGet
   193  	var reqBody io.Reader
   194  
   195  	if url == urlPost || url == urlPostWithGoaway {
   196  		method = http.MethodPost
   197  		reqBody = bytes.NewReader(requestPostBody)
   198  	}
   199  
   200  	req, err := http.NewRequest(method, serverBaseURL+url, reqBody)
   201  	if err != nil {
   202  		return nil, fmt.Errorf("unexpect new request error: %v", err)
   203  	}
   204  	resp, err := client.Do(req)
   205  	if err != nil {
   206  		return nil, fmt.Errorf("failed request test server, err: %v", err)
   207  	}
   208  
   209  	if resp.StatusCode != http.StatusOK {
   210  		defer resp.Body.Close()
   211  		body, err := ioutil.ReadAll(resp.Body)
   212  		if err != nil {
   213  			return nil, fmt.Errorf("failed to read response body and status code is %d, error: %v", resp.StatusCode, err)
   214  		}
   215  
   216  		return nil, fmt.Errorf("expect response status code: %d, but got: %d. response body: %s", http.StatusOK, resp.StatusCode, body)
   217  	}
   218  
   219  	// encounter watch bytes received, does not expect to be broken
   220  	if url == urlWatch || url == urlWatchWithGoaway {
   221  		ch := make(chan watchResponse)
   222  		go func() {
   223  			defer resp.Body.Close()
   224  
   225  			body := make([]byte, 0)
   226  			buffer := make([]byte, 1)
   227  			for {
   228  				n, err := resp.Body.Read(buffer)
   229  				if err != nil {
   230  					// urlWatch will receive io.EOF,
   231  					// urlWatchWithGoaway will receive http2.GoAwayError
   232  					if err == io.EOF {
   233  						err = nil
   234  					} else if _, ok := err.(http2.GoAwayError); ok {
   235  						err = nil
   236  					}
   237  
   238  					ch <- watchResponse{
   239  						body: body,
   240  						err:  err,
   241  					}
   242  					return
   243  				}
   244  				body = append(body, buffer[0:n]...)
   245  			}
   246  		}()
   247  		return ch, nil
   248  	}
   249  
   250  	defer resp.Body.Close()
   251  	body, err := ioutil.ReadAll(resp.Body)
   252  	if err != nil {
   253  		return nil, fmt.Errorf("failed to read response body, error: %v", err)
   254  	}
   255  
   256  	if !reflect.DeepEqual(responseBody, body) {
   257  		return nil, fmt.Errorf("expect response body: %s, got: %s", string(responseBody), string(body))
   258  	}
   259  
   260  	return nil, nil
   261  }
   262  
   263  // TestClientReceivedGOAWAY tests the in-flight watch requests will not be affected and new requests use a new
   264  // connection after client received GOAWAY.
   265  func TestClientReceivedGOAWAY(t *testing.T) {
   266  	s, err := newTestGOAWAYServer()
   267  	if err != nil {
   268  		t.Fatalf("failed to set-up test GOAWAY http server, err: %v", err)
   269  	}
   270  
   271  	s.StartTLS()
   272  	defer s.Close()
   273  
   274  	cases := []struct {
   275  		name string
   276  		reqs []string
   277  		// expectConnections always equals to GOAWAY requests(urlGoaway or urlWatchWithGoaway) + 1
   278  		expectConnections int
   279  	}{
   280  		{
   281  			name:              "all normal requests use only one connection",
   282  			reqs:              []string{urlGet, urlPost, urlGet},
   283  			expectConnections: 1,
   284  		},
   285  		{
   286  			name:              "got GOAWAY after set-up watch",
   287  			reqs:              []string{urlPost, urlWatch, urlGetWithGoaway, urlGet, urlPost},
   288  			expectConnections: 2,
   289  		},
   290  		{
   291  			name:              "got GOAWAY after set-up watch, and set-up a new watch",
   292  			reqs:              []string{urlGet, urlWatch, urlGetWithGoaway, urlWatch, urlGet, urlPost},
   293  			expectConnections: 2,
   294  		},
   295  		{
   296  			name:              "got 2 GOAWAY after set-up watch",
   297  			reqs:              []string{urlPost, urlWatch, urlGetWithGoaway, urlGetWithGoaway, urlGet, urlPost},
   298  			expectConnections: 3,
   299  		},
   300  		{
   301  			name:              "combine with watch-with-goaway",
   302  			reqs:              []string{urlGet, urlWatchWithGoaway, urlGet, urlWatch, urlGetWithGoaway, urlGet, urlPost},
   303  			expectConnections: 3,
   304  		},
   305  	}
   306  
   307  	for _, tc := range cases {
   308  		t.Run(tc.name, func(t *testing.T) {
   309  			// localAddr indicates how many TCP connection set up
   310  			localAddr := make([]string, 0)
   311  
   312  			// create the http client
   313  			dialFn := func(network, addr string, cfg *tls.Config) (conn net.Conn, err error) {
   314  				conn, err = tls.Dial(network, addr, cfg)
   315  				if err != nil {
   316  					t.Fatalf("unexpect connection err: %v", err)
   317  				}
   318  
   319  				localAddr = append(localAddr, conn.LocalAddr().String())
   320  				return
   321  			}
   322  			tlsConfig := &tls.Config{
   323  				InsecureSkipVerify: true,
   324  				NextProtos:         []string{http2.NextProtoTLS},
   325  			}
   326  			tr := &http.Transport{
   327  				TLSHandshakeTimeout: 10 * time.Second,
   328  				TLSClientConfig:     tlsConfig,
   329  				// Disable connection pooling to avoid additional connections
   330  				// that cause the test to flake
   331  				MaxIdleConnsPerHost: -1,
   332  				DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   333  					return dialFn(network, addr, tlsConfig)
   334  				},
   335  			}
   336  			if err := http2.ConfigureTransport(tr); err != nil {
   337  				t.Fatalf("failed to configure http transport, err: %v", err)
   338  			}
   339  
   340  			client := &http.Client{
   341  				Transport: tr,
   342  			}
   343  
   344  			watchChs := make([]<-chan watchResponse, 0)
   345  			for _, url := range tc.reqs {
   346  				w, err := requestGOAWAYServer(client, s.URL, url)
   347  				if err != nil {
   348  					t.Fatalf("failed to request server, err: %v", err)
   349  				}
   350  				if w != nil {
   351  					watchChs = append(watchChs, w)
   352  				}
   353  			}
   354  
   355  			// check TCP connection count
   356  			if tc.expectConnections != len(localAddr) {
   357  				t.Fatalf("expect TCP connection: %d, actual: %d", tc.expectConnections, len(localAddr))
   358  			}
   359  
   360  			// check if watch request is broken by GOAWAY frame
   361  			watchTimeout := time.NewTimer(time.Second * 10)
   362  			defer watchTimeout.Stop()
   363  			for _, watchCh := range watchChs {
   364  				select {
   365  				case watchResp := <-watchCh:
   366  					if watchResp.err != nil {
   367  						t.Fatalf("watch response got an unexepct error: %v", watchResp.err)
   368  					}
   369  					if !reflect.DeepEqual(responseBody, watchResp.body) {
   370  						t.Fatalf("in-flight watch was broken by GOAWAY frame, expect response body: %s, got: %s", responseBody, watchResp.body)
   371  					}
   372  				case <-watchTimeout.C:
   373  					t.Error("watch receive timeout")
   374  				}
   375  			}
   376  		})
   377  	}
   378  }
   379  
   380  // TestGOAWAYHTTP1Requests tests GOAWAY filter will not affect HTTP1.1 requests.
   381  func TestGOAWAYHTTP1Requests(t *testing.T) {
   382  	s := httptest.NewUnstartedServer(WithProbabilisticGoaway(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   383  		w.WriteHeader(http.StatusOK)
   384  		w.Write([]byte("hello"))
   385  	}), 1))
   386  
   387  	http2Options := &http2.Server{}
   388  
   389  	if err := http2.ConfigureServer(s.Config, http2Options); err != nil {
   390  		t.Fatalf("failed to configure test server to be HTTP2 server, err: %v", err)
   391  	}
   392  
   393  	s.TLS = s.Config.TLSConfig
   394  	s.StartTLS()
   395  	defer s.Close()
   396  
   397  	tlsConfig := &tls.Config{
   398  		InsecureSkipVerify: true,
   399  		NextProtos:         []string{"http/1.1"},
   400  	}
   401  
   402  	client := http.Client{
   403  		Transport: &http.Transport{
   404  			TLSClientConfig: tlsConfig,
   405  		},
   406  	}
   407  
   408  	resp, err := client.Get(s.URL)
   409  	if err != nil {
   410  		t.Fatalf("failed to request the server, err: %v", err)
   411  	}
   412  
   413  	if v := resp.Header.Get("Connection"); v != "" {
   414  		t.Errorf("expect response HTTP header Connection to be empty, but got: %s", v)
   415  	}
   416  }
   417  
   418  // TestGOAWAYConcurrency tests GOAWAY frame will not affect concurrency requests in a single http client instance.
   419  func TestGOAWAYConcurrency(t *testing.T) {
   420  	s, err := newTestGOAWAYServer()
   421  	if err != nil {
   422  		t.Fatalf("failed to set-up test GOAWAY http server, err: %v", err)
   423  	}
   424  
   425  	s.StartTLS()
   426  	defer s.Close()
   427  
   428  	// create the http client
   429  	tlsConfig := &tls.Config{
   430  		InsecureSkipVerify: true,
   431  		NextProtos:         []string{http2.NextProtoTLS},
   432  	}
   433  	tr := &http.Transport{
   434  		TLSHandshakeTimeout: 10 * time.Second,
   435  		TLSClientConfig:     tlsConfig,
   436  		MaxIdleConnsPerHost: 25,
   437  	}
   438  	if err := http2.ConfigureTransport(tr); err != nil {
   439  		t.Fatalf("failed to configure http transport, err: %v", err)
   440  	}
   441  
   442  	client := &http.Client{
   443  		Transport: tr,
   444  	}
   445  	if err != nil {
   446  		t.Fatalf("failed to set-up client, err: %v", err)
   447  	}
   448  
   449  	const (
   450  		requestCount = 300
   451  		workers      = 10
   452  	)
   453  
   454  	expectWatchers := 0
   455  
   456  	urlsForTest := []string{urlGet, urlPost, urlWatch, urlGetWithGoaway, urlPostWithGoaway, urlWatchWithGoaway}
   457  	urls := make(chan string, requestCount)
   458  	for i := 0; i < requestCount; i++ {
   459  		index := rand.Intn(len(urlsForTest))
   460  		url := urlsForTest[index]
   461  
   462  		if url == urlWatch || url == urlWatchWithGoaway {
   463  			expectWatchers++
   464  		}
   465  
   466  		urls <- url
   467  	}
   468  	close(urls)
   469  
   470  	wg := &sync.WaitGroup{}
   471  	wg.Add(workers)
   472  
   473  	watchers := make(chan (<-chan watchResponse), expectWatchers)
   474  	for i := 0; i < workers; i++ {
   475  		go func() {
   476  			defer wg.Done()
   477  
   478  			for {
   479  				url, ok := <-urls
   480  				if !ok {
   481  					return
   482  				}
   483  
   484  				w, err := requestGOAWAYServer(client, s.URL, url)
   485  				if err != nil {
   486  					t.Errorf("failed to request %q, err: %v", url, err)
   487  				}
   488  
   489  				if w != nil {
   490  					watchers <- w
   491  				}
   492  			}
   493  		}()
   494  	}
   495  
   496  	wg.Wait()
   497  
   498  	// check if watch request is broken by GOAWAY frame
   499  	watchTimeout := time.NewTimer(time.Second * 10)
   500  	defer watchTimeout.Stop()
   501  	for i := 0; i < expectWatchers; i++ {
   502  		var watcher <-chan watchResponse
   503  
   504  		select {
   505  		case watcher = <-watchers:
   506  		default:
   507  			t.Fatalf("expect watcher count: %d, but got: %d", expectWatchers, i)
   508  		}
   509  
   510  		select {
   511  		case watchResp := <-watcher:
   512  			if watchResp.err != nil {
   513  				t.Fatalf("watch response got an unexepct error: %v", watchResp.err)
   514  			}
   515  			if !reflect.DeepEqual(responseBody, watchResp.body) {
   516  				t.Fatalf("in-flight watch was broken by GOAWAY frame, expect response body: %s, got: %s", responseBody, watchResp.body)
   517  			}
   518  		case <-watchTimeout.C:
   519  			t.Error("watch receive timeout")
   520  		}
   521  	}
   522  }