k8s.io/apiserver@v0.31.1/pkg/server/filters/goaway_test.go (about) 1 /* 2 Copyright 2020 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package filters 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "fmt" 24 "io" 25 "io/ioutil" 26 "math/rand" 27 "net" 28 "net/http" 29 "net/http/httptest" 30 "reflect" 31 "sync" 32 "testing" 33 "time" 34 35 "golang.org/x/net/http2" 36 ) 37 38 func TestProbabilisticGoawayDecider(t *testing.T) { 39 cases := []struct { 40 name string 41 chance float64 42 nextFn func(chance float64) func() float64 43 expectGOAWAY bool 44 }{ 45 { 46 name: "always not GOAWAY", 47 chance: 0, 48 nextFn: func(chance float64) func() float64 { 49 return rand.Float64 50 }, 51 expectGOAWAY: false, 52 }, 53 { 54 name: "always GOAWAY", 55 chance: 1, 56 nextFn: func(chance float64) func() float64 { 57 return rand.Float64 58 }, 59 expectGOAWAY: true, 60 }, 61 { 62 name: "hit GOAWAY", 63 chance: rand.Float64() + 0.01, 64 nextFn: func(chance float64) func() float64 { 65 return func() float64 { 66 return chance - 0.001 67 } 68 }, 69 expectGOAWAY: true, 70 }, 71 { 72 name: "does not hit GOAWAY", 73 chance: rand.Float64() + 0.01, 74 nextFn: func(chance float64) func() float64 { 75 return func() float64 { 76 return chance + 0.001 77 } 78 }, 79 expectGOAWAY: false, 80 }, 81 } 82 83 for _, tc := range cases { 84 t.Run(tc.name, func(t *testing.T) { 85 d := probabilisticGoawayDecider{chance: tc.chance, next: tc.nextFn(tc.chance)} 86 result := d.Goaway(nil) 87 if result != tc.expectGOAWAY { 88 t.Errorf("expect GOAWAY: %v, got: %v", tc.expectGOAWAY, result) 89 } 90 }) 91 } 92 } 93 94 const ( 95 urlGet = "/get" 96 urlPost = "/post" 97 urlWatch = "/watch" 98 urlGetWithGoaway = "/get-with-goaway" 99 urlPostWithGoaway = "/post-with-goaway" 100 urlWatchWithGoaway = "/watch-with-goaway" 101 ) 102 103 var ( 104 // responseBody is the response body which test GOAWAY server sent for each request, 105 // for watch request, test GOAWAY server push 1 byte in every second. 106 responseBody = []byte("hello") 107 108 // requestPostBody is the request body which client must send to test GOAWAY server for POST method, 109 // otherwise, test GOAWAY server will respond 400 HTTP status code. 110 requestPostBody = responseBody 111 ) 112 113 // newTestGOAWAYServer return a test GOAWAY server instance. 114 func newTestGOAWAYServer() (*httptest.Server, error) { 115 watchHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 timer := time.NewTicker(time.Second) 117 defer timer.Stop() 118 119 w.Header().Set("Transfer-Encoding", "chunked") 120 w.WriteHeader(200) 121 122 flusher, _ := w.(http.Flusher) 123 flusher.Flush() 124 125 count := 0 126 for { 127 <-timer.C 128 n, err := w.Write(responseBody[count : count+1]) 129 if err != nil { 130 return 131 } 132 flusher.Flush() 133 count += n 134 if count == len(responseBody) { 135 return 136 } 137 } 138 }) 139 getHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 140 w.WriteHeader(http.StatusOK) 141 w.Write(responseBody) 142 return 143 }) 144 postHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 145 reqBody, err := ioutil.ReadAll(r.Body) 146 if err != nil { 147 http.Error(w, err.Error(), http.StatusInternalServerError) 148 return 149 } 150 if !reflect.DeepEqual(requestPostBody, reqBody) { 151 http.Error(w, fmt.Sprintf("expect request body: %s, got: %s", requestPostBody, reqBody), http.StatusBadRequest) 152 return 153 } 154 155 w.WriteHeader(http.StatusOK) 156 w.Write(responseBody) 157 return 158 }) 159 160 mux := http.NewServeMux() 161 mux.Handle(urlGet, WithProbabilisticGoaway(getHandler, 0)) 162 mux.Handle(urlPost, WithProbabilisticGoaway(postHandler, 0)) 163 mux.Handle(urlWatch, WithProbabilisticGoaway(watchHandler, 0)) 164 mux.Handle(urlGetWithGoaway, WithProbabilisticGoaway(getHandler, 1)) 165 mux.Handle(urlPostWithGoaway, WithProbabilisticGoaway(postHandler, 1)) 166 mux.Handle(urlWatchWithGoaway, WithProbabilisticGoaway(watchHandler, 1)) 167 168 s := httptest.NewUnstartedServer(mux) 169 170 http2Options := &http2.Server{} 171 172 if err := http2.ConfigureServer(s.Config, http2Options); err != nil { 173 return nil, fmt.Errorf("failed to configure test server to be HTTP2 server, err: %v", err) 174 } 175 176 s.TLS = s.Config.TLSConfig 177 178 return s, nil 179 } 180 181 // watchResponse wraps watch response with data which server send and an error may occur. 182 type watchResponse struct { 183 // body is the response data which test GOAWAY server sent to client 184 body []byte 185 // err will be set to be a non-nil value if watch request is not end with EOF nor http2.GoAwayError 186 err error 187 } 188 189 // requestGOAWAYServer request test GOAWAY server using specified method and data according to the given url. 190 // A non-nil channel will be returned if the request is watch, and a watchResponse can be got from the channel when watch done. 191 func requestGOAWAYServer(client *http.Client, serverBaseURL, url string) (<-chan watchResponse, error) { 192 method := http.MethodGet 193 var reqBody io.Reader 194 195 if url == urlPost || url == urlPostWithGoaway { 196 method = http.MethodPost 197 reqBody = bytes.NewReader(requestPostBody) 198 } 199 200 req, err := http.NewRequest(method, serverBaseURL+url, reqBody) 201 if err != nil { 202 return nil, fmt.Errorf("unexpect new request error: %v", err) 203 } 204 resp, err := client.Do(req) 205 if err != nil { 206 return nil, fmt.Errorf("failed request test server, err: %v", err) 207 } 208 209 if resp.StatusCode != http.StatusOK { 210 defer resp.Body.Close() 211 body, err := ioutil.ReadAll(resp.Body) 212 if err != nil { 213 return nil, fmt.Errorf("failed to read response body and status code is %d, error: %v", resp.StatusCode, err) 214 } 215 216 return nil, fmt.Errorf("expect response status code: %d, but got: %d. response body: %s", http.StatusOK, resp.StatusCode, body) 217 } 218 219 // encounter watch bytes received, does not expect to be broken 220 if url == urlWatch || url == urlWatchWithGoaway { 221 ch := make(chan watchResponse) 222 go func() { 223 defer resp.Body.Close() 224 225 body := make([]byte, 0) 226 buffer := make([]byte, 1) 227 for { 228 n, err := resp.Body.Read(buffer) 229 if err != nil { 230 // urlWatch will receive io.EOF, 231 // urlWatchWithGoaway will receive http2.GoAwayError 232 if err == io.EOF { 233 err = nil 234 } else if _, ok := err.(http2.GoAwayError); ok { 235 err = nil 236 } 237 238 ch <- watchResponse{ 239 body: body, 240 err: err, 241 } 242 return 243 } 244 body = append(body, buffer[0:n]...) 245 } 246 }() 247 return ch, nil 248 } 249 250 defer resp.Body.Close() 251 body, err := ioutil.ReadAll(resp.Body) 252 if err != nil { 253 return nil, fmt.Errorf("failed to read response body, error: %v", err) 254 } 255 256 if !reflect.DeepEqual(responseBody, body) { 257 return nil, fmt.Errorf("expect response body: %s, got: %s", string(responseBody), string(body)) 258 } 259 260 return nil, nil 261 } 262 263 // TestClientReceivedGOAWAY tests the in-flight watch requests will not be affected and new requests use a new 264 // connection after client received GOAWAY. 265 func TestClientReceivedGOAWAY(t *testing.T) { 266 s, err := newTestGOAWAYServer() 267 if err != nil { 268 t.Fatalf("failed to set-up test GOAWAY http server, err: %v", err) 269 } 270 271 s.StartTLS() 272 defer s.Close() 273 274 cases := []struct { 275 name string 276 reqs []string 277 // expectConnections always equals to GOAWAY requests(urlGoaway or urlWatchWithGoaway) + 1 278 expectConnections int 279 }{ 280 { 281 name: "all normal requests use only one connection", 282 reqs: []string{urlGet, urlPost, urlGet}, 283 expectConnections: 1, 284 }, 285 { 286 name: "got GOAWAY after set-up watch", 287 reqs: []string{urlPost, urlWatch, urlGetWithGoaway, urlGet, urlPost}, 288 expectConnections: 2, 289 }, 290 { 291 name: "got GOAWAY after set-up watch, and set-up a new watch", 292 reqs: []string{urlGet, urlWatch, urlGetWithGoaway, urlWatch, urlGet, urlPost}, 293 expectConnections: 2, 294 }, 295 { 296 name: "got 2 GOAWAY after set-up watch", 297 reqs: []string{urlPost, urlWatch, urlGetWithGoaway, urlGetWithGoaway, urlGet, urlPost}, 298 expectConnections: 3, 299 }, 300 { 301 name: "combine with watch-with-goaway", 302 reqs: []string{urlGet, urlWatchWithGoaway, urlGet, urlWatch, urlGetWithGoaway, urlGet, urlPost}, 303 expectConnections: 3, 304 }, 305 } 306 307 for _, tc := range cases { 308 t.Run(tc.name, func(t *testing.T) { 309 // localAddr indicates how many TCP connection set up 310 localAddr := make([]string, 0) 311 312 // create the http client 313 dialFn := func(network, addr string, cfg *tls.Config) (conn net.Conn, err error) { 314 conn, err = tls.Dial(network, addr, cfg) 315 if err != nil { 316 t.Fatalf("unexpect connection err: %v", err) 317 } 318 319 localAddr = append(localAddr, conn.LocalAddr().String()) 320 return 321 } 322 tlsConfig := &tls.Config{ 323 InsecureSkipVerify: true, 324 NextProtos: []string{http2.NextProtoTLS}, 325 } 326 tr := &http.Transport{ 327 TLSHandshakeTimeout: 10 * time.Second, 328 TLSClientConfig: tlsConfig, 329 // Disable connection pooling to avoid additional connections 330 // that cause the test to flake 331 MaxIdleConnsPerHost: -1, 332 DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 333 return dialFn(network, addr, tlsConfig) 334 }, 335 } 336 if err := http2.ConfigureTransport(tr); err != nil { 337 t.Fatalf("failed to configure http transport, err: %v", err) 338 } 339 340 client := &http.Client{ 341 Transport: tr, 342 } 343 344 watchChs := make([]<-chan watchResponse, 0) 345 for _, url := range tc.reqs { 346 w, err := requestGOAWAYServer(client, s.URL, url) 347 if err != nil { 348 t.Fatalf("failed to request server, err: %v", err) 349 } 350 if w != nil { 351 watchChs = append(watchChs, w) 352 } 353 } 354 355 // check TCP connection count 356 if tc.expectConnections != len(localAddr) { 357 t.Fatalf("expect TCP connection: %d, actual: %d", tc.expectConnections, len(localAddr)) 358 } 359 360 // check if watch request is broken by GOAWAY frame 361 watchTimeout := time.NewTimer(time.Second * 10) 362 defer watchTimeout.Stop() 363 for _, watchCh := range watchChs { 364 select { 365 case watchResp := <-watchCh: 366 if watchResp.err != nil { 367 t.Fatalf("watch response got an unexepct error: %v", watchResp.err) 368 } 369 if !reflect.DeepEqual(responseBody, watchResp.body) { 370 t.Fatalf("in-flight watch was broken by GOAWAY frame, expect response body: %s, got: %s", responseBody, watchResp.body) 371 } 372 case <-watchTimeout.C: 373 t.Error("watch receive timeout") 374 } 375 } 376 }) 377 } 378 } 379 380 // TestGOAWAYHTTP1Requests tests GOAWAY filter will not affect HTTP1.1 requests. 381 func TestGOAWAYHTTP1Requests(t *testing.T) { 382 s := httptest.NewUnstartedServer(WithProbabilisticGoaway(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 383 w.WriteHeader(http.StatusOK) 384 w.Write([]byte("hello")) 385 }), 1)) 386 387 http2Options := &http2.Server{} 388 389 if err := http2.ConfigureServer(s.Config, http2Options); err != nil { 390 t.Fatalf("failed to configure test server to be HTTP2 server, err: %v", err) 391 } 392 393 s.TLS = s.Config.TLSConfig 394 s.StartTLS() 395 defer s.Close() 396 397 tlsConfig := &tls.Config{ 398 InsecureSkipVerify: true, 399 NextProtos: []string{"http/1.1"}, 400 } 401 402 client := http.Client{ 403 Transport: &http.Transport{ 404 TLSClientConfig: tlsConfig, 405 }, 406 } 407 408 resp, err := client.Get(s.URL) 409 if err != nil { 410 t.Fatalf("failed to request the server, err: %v", err) 411 } 412 413 if v := resp.Header.Get("Connection"); v != "" { 414 t.Errorf("expect response HTTP header Connection to be empty, but got: %s", v) 415 } 416 } 417 418 // TestGOAWAYConcurrency tests GOAWAY frame will not affect concurrency requests in a single http client instance. 419 func TestGOAWAYConcurrency(t *testing.T) { 420 s, err := newTestGOAWAYServer() 421 if err != nil { 422 t.Fatalf("failed to set-up test GOAWAY http server, err: %v", err) 423 } 424 425 s.StartTLS() 426 defer s.Close() 427 428 // create the http client 429 tlsConfig := &tls.Config{ 430 InsecureSkipVerify: true, 431 NextProtos: []string{http2.NextProtoTLS}, 432 } 433 tr := &http.Transport{ 434 TLSHandshakeTimeout: 10 * time.Second, 435 TLSClientConfig: tlsConfig, 436 MaxIdleConnsPerHost: 25, 437 } 438 if err := http2.ConfigureTransport(tr); err != nil { 439 t.Fatalf("failed to configure http transport, err: %v", err) 440 } 441 442 client := &http.Client{ 443 Transport: tr, 444 } 445 if err != nil { 446 t.Fatalf("failed to set-up client, err: %v", err) 447 } 448 449 const ( 450 requestCount = 300 451 workers = 10 452 ) 453 454 expectWatchers := 0 455 456 urlsForTest := []string{urlGet, urlPost, urlWatch, urlGetWithGoaway, urlPostWithGoaway, urlWatchWithGoaway} 457 urls := make(chan string, requestCount) 458 for i := 0; i < requestCount; i++ { 459 index := rand.Intn(len(urlsForTest)) 460 url := urlsForTest[index] 461 462 if url == urlWatch || url == urlWatchWithGoaway { 463 expectWatchers++ 464 } 465 466 urls <- url 467 } 468 close(urls) 469 470 wg := &sync.WaitGroup{} 471 wg.Add(workers) 472 473 watchers := make(chan (<-chan watchResponse), expectWatchers) 474 for i := 0; i < workers; i++ { 475 go func() { 476 defer wg.Done() 477 478 for { 479 url, ok := <-urls 480 if !ok { 481 return 482 } 483 484 w, err := requestGOAWAYServer(client, s.URL, url) 485 if err != nil { 486 t.Errorf("failed to request %q, err: %v", url, err) 487 } 488 489 if w != nil { 490 watchers <- w 491 } 492 } 493 }() 494 } 495 496 wg.Wait() 497 498 // check if watch request is broken by GOAWAY frame 499 watchTimeout := time.NewTimer(time.Second * 10) 500 defer watchTimeout.Stop() 501 for i := 0; i < expectWatchers; i++ { 502 var watcher <-chan watchResponse 503 504 select { 505 case watcher = <-watchers: 506 default: 507 t.Fatalf("expect watcher count: %d, but got: %d", expectWatchers, i) 508 } 509 510 select { 511 case watchResp := <-watcher: 512 if watchResp.err != nil { 513 t.Fatalf("watch response got an unexepct error: %v", watchResp.err) 514 } 515 if !reflect.DeepEqual(responseBody, watchResp.body) { 516 t.Fatalf("in-flight watch was broken by GOAWAY frame, expect response body: %s, got: %s", responseBody, watchResp.body) 517 } 518 case <-watchTimeout.C: 519 t.Error("watch receive timeout") 520 } 521 } 522 }