k8s.io/apiserver@v0.31.1/pkg/endpoints/responsewriter/wrapper_test.go (about)

     1  /*
     2  Copyright 2021 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package responsewriter
    18  
    19  import (
    20  	"bufio"
    21  	"net"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"net/url"
    25  	"testing"
    26  	"time"
    27  )
    28  
    29  func TestWithHTTP1(t *testing.T) {
    30  	var originalWant http.ResponseWriter
    31  	counterGot := &counter{}
    32  	chain := func(h http.Handler) http.Handler {
    33  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    34  			if originalWant == nil {
    35  				originalWant = w
    36  			}
    37  
    38  			assertCloseNotifierFlusherHijacker(t, true, w)
    39  
    40  			decorator := &fakeResponseWriterDecorator{
    41  				ResponseWriter: w,
    42  				counter:        counterGot,
    43  			}
    44  			wrapped := WrapForHTTP1Or2(decorator)
    45  
    46  			assertCloseNotifierFlusherHijacker(t, true, wrapped)
    47  
    48  			originalGot := GetOriginal(wrapped)
    49  			if originalWant != originalGot {
    50  				t.Errorf("Expected GetOriginal to return the original ResponseWriter object")
    51  				return
    52  			}
    53  
    54  			h.ServeHTTP(wrapped, r)
    55  		})
    56  	}
    57  
    58  	// wrap the original http.ResponseWriter multiple times
    59  	handler := chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    60  		// at this point, the original ResponseWriter object has been wrapped three times
    61  		// so each decorator is expected to tick the count by one for each method.
    62  		defer counterGot.assert(t, &counter{FlushInvoked: 3, CloseNotifyInvoked: 3, HijackInvoked: 3})
    63  
    64  		//nolint:staticcheck // SA1019
    65  		w.(http.CloseNotifier).CloseNotify()
    66  		w.(http.Flusher).Flush()
    67  
    68  		conn, _, err := w.(http.Hijacker).Hijack()
    69  		if err != nil {
    70  			t.Errorf("Expected Hijack to succeed, but got error: %v", err)
    71  			return
    72  		}
    73  		conn.Close()
    74  	}))
    75  	handler = chain(handler)
    76  	handler = chain(handler)
    77  
    78  	server := newServer(t, handler, false)
    79  	defer server.Close()
    80  
    81  	sendRequest(t, server)
    82  }
    83  
    84  func TestWithHTTP2(t *testing.T) {
    85  	var originalWant http.ResponseWriter
    86  	counterGot := &counter{}
    87  	chain := func(h http.Handler) http.Handler {
    88  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    89  			if originalWant == nil {
    90  				originalWant = w
    91  			}
    92  
    93  			assertCloseNotifierFlusherHijacker(t, false, w)
    94  
    95  			decorator := &fakeResponseWriterDecorator{
    96  				ResponseWriter: w,
    97  				counter:        counterGot,
    98  			}
    99  			wrapped := WrapForHTTP1Or2(decorator)
   100  
   101  			assertCloseNotifierFlusherHijacker(t, false, wrapped)
   102  
   103  			originalGot := GetOriginal(wrapped)
   104  			if originalWant != originalGot {
   105  				t.Errorf("Expected GetOriginal to return the original ResponseWriter object")
   106  				return
   107  			}
   108  
   109  			h.ServeHTTP(wrapped, r)
   110  		})
   111  	}
   112  
   113  	// wrap the original http.ResponseWriter multiple times
   114  	handler := chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   115  		// at this point, the original ResponseWriter object has been wrapped three times
   116  		// so each decorator is expected to tick the count by one for each method.
   117  		defer counterGot.assert(t, &counter{FlushInvoked: 3, CloseNotifyInvoked: 3, HijackInvoked: 0})
   118  
   119  		//nolint:staticcheck // SA1019
   120  		w.(http.CloseNotifier).CloseNotify()
   121  		w.(http.Flusher).Flush()
   122  
   123  	}))
   124  	handler = chain(handler)
   125  	handler = chain(handler)
   126  
   127  	server := newServer(t, handler, true)
   128  	defer server.Close()
   129  
   130  	sendRequest(t, server)
   131  }
   132  
   133  func TestGetOriginal(t *testing.T) {
   134  	tests := []struct {
   135  		name          string
   136  		wrap          func() (http.ResponseWriter, http.ResponseWriter)
   137  		panicExpected bool
   138  	}{
   139  		{
   140  			name: "not wrapped",
   141  			wrap: func() (http.ResponseWriter, http.ResponseWriter) {
   142  				original := &FakeResponseWriter{}
   143  				return original, original
   144  			},
   145  		},
   146  		{
   147  			name: "wrapped once",
   148  			wrap: func() (http.ResponseWriter, http.ResponseWriter) {
   149  				original := &FakeResponseWriter{}
   150  				return original, &fakeResponseWriterDecorator{
   151  					ResponseWriter: original,
   152  				}
   153  			},
   154  		},
   155  		{
   156  			name: "wrapped multiple times",
   157  			wrap: func() (http.ResponseWriter, http.ResponseWriter) {
   158  				original := &FakeResponseWriter{}
   159  				return original, &fakeResponseWriterDecorator{
   160  					ResponseWriter: &fakeResponseWriterDecorator{
   161  						ResponseWriter: &fakeResponseWriterDecorator{
   162  							ResponseWriter: original,
   163  						},
   164  					},
   165  				}
   166  			},
   167  		},
   168  		{
   169  			name: "wraps itself",
   170  			wrap: func() (http.ResponseWriter, http.ResponseWriter) {
   171  				faulty := &fakeResponseWriterDecorator{}
   172  				faulty.ResponseWriter = faulty
   173  				return faulty, &fakeResponseWriterDecorator{
   174  					ResponseWriter: faulty,
   175  				}
   176  			},
   177  			panicExpected: true,
   178  		},
   179  	}
   180  
   181  	for _, test := range tests {
   182  		t.Run(test.name, func(t *testing.T) {
   183  			originalExpected, wrapped := test.wrap()
   184  
   185  			func() {
   186  				defer func() {
   187  					err := recover()
   188  					switch {
   189  					case err != nil:
   190  						if !test.panicExpected {
   191  							t.Errorf("Expected no panic, but got: %v", err)
   192  						}
   193  					default:
   194  						if test.panicExpected {
   195  							t.Errorf("Expected a panic")
   196  						}
   197  					}
   198  				}()
   199  
   200  				originalGot := GetOriginal(wrapped)
   201  				if originalExpected != originalGot {
   202  					t.Errorf("Expected to get tehe original http.ResponseWriter object")
   203  				}
   204  			}()
   205  		})
   206  	}
   207  }
   208  
   209  func newServer(t *testing.T, h http.Handler, http2 bool) *httptest.Server {
   210  	server := httptest.NewUnstartedServer(h)
   211  	if http2 {
   212  		server.EnableHTTP2 = true
   213  		server.StartTLS()
   214  	} else {
   215  		server.Start()
   216  	}
   217  	_, err := url.Parse(server.URL)
   218  	if err != nil {
   219  		t.Fatalf("Expected the server to have a valid URL, but got: %s", server.URL)
   220  	}
   221  	return server
   222  }
   223  
   224  func sendRequest(t *testing.T, server *httptest.Server) {
   225  	req, err := http.NewRequest("GET", server.URL, nil)
   226  	if err != nil {
   227  		t.Fatalf("error creating request: %v", err)
   228  	}
   229  
   230  	client := server.Client()
   231  	client.Timeout = 30 * time.Second
   232  	_, err = client.Do(req)
   233  	if err != nil {
   234  		t.Fatalf("Unexpected non-nil err from client.Do: %v", err)
   235  	}
   236  }
   237  
   238  func assertCloseNotifierFlusherHijacker(t *testing.T, hijackableExpected bool, w http.ResponseWriter) {
   239  	// the http.ResponseWriter object for both http/1.x and http2
   240  	// implement http.Flusher and http.CloseNotifier
   241  	if _, ok := w.(http.Flusher); !ok {
   242  		t.Errorf("Expected the http.ResponseWriter object to implement http.Flusher")
   243  	}
   244  
   245  	//nolint:staticcheck // SA1019
   246  	if _, ok := w.(http.CloseNotifier); !ok {
   247  		t.Errorf("Expected the http.ResponseWriter object to implement http.CloseNotifier")
   248  	}
   249  
   250  	// http/1.x implements http.Hijacker, not http2
   251  	if _, ok := w.(http.Hijacker); ok != hijackableExpected {
   252  		t.Errorf("Unexpected http.Hijacker implementation, expected: %t, but got: %t", hijackableExpected, ok)
   253  	}
   254  }
   255  
   256  type counter struct {
   257  	FlushInvoked       int
   258  	HijackInvoked      int
   259  	CloseNotifyInvoked int
   260  }
   261  
   262  func (c *counter) assert(t *testing.T, expected *counter) {
   263  	if expected.FlushInvoked != c.FlushInvoked {
   264  		t.Errorf("Expected Flush() count to match, wanted: %d, but got: %d", expected.FlushInvoked, c.FlushInvoked)
   265  	}
   266  	if expected.CloseNotifyInvoked != c.CloseNotifyInvoked {
   267  		t.Errorf("Expected CloseNotify() count to match, wanted: %d, but got: %d", expected.CloseNotifyInvoked, c.CloseNotifyInvoked)
   268  	}
   269  	if expected.HijackInvoked != c.HijackInvoked {
   270  		t.Errorf("Expected Hijack() count to match, wanted: %d, but got: %d", expected.HijackInvoked, c.HijackInvoked)
   271  	}
   272  }
   273  
   274  type fakeResponseWriterDecorator struct {
   275  	http.ResponseWriter
   276  	counter *counter
   277  }
   278  
   279  func (fw *fakeResponseWriterDecorator) Unwrap() http.ResponseWriter { return fw.ResponseWriter }
   280  func (fw *fakeResponseWriterDecorator) Flush() {
   281  	if fw.counter != nil {
   282  		fw.counter.FlushInvoked++
   283  	}
   284  	fw.ResponseWriter.(http.Flusher).Flush()
   285  }
   286  func (fw *fakeResponseWriterDecorator) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   287  	if fw.counter != nil {
   288  		fw.counter.HijackInvoked++
   289  	}
   290  	return fw.ResponseWriter.(http.Hijacker).Hijack()
   291  }
   292  func (fw *fakeResponseWriterDecorator) CloseNotify() <-chan bool {
   293  	if fw.counter != nil {
   294  		fw.counter.CloseNotifyInvoked++
   295  	}
   296  	//nolint:staticcheck // SA1019
   297  	return fw.ResponseWriter.(http.CloseNotifier).CloseNotify()
   298  }