k8s.io/apiserver@v0.31.1/pkg/endpoints/responsewriter/wrapper_test.go (about) 1 /* 2 Copyright 2021 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package responsewriter 18 19 import ( 20 "bufio" 21 "net" 22 "net/http" 23 "net/http/httptest" 24 "net/url" 25 "testing" 26 "time" 27 ) 28 29 func TestWithHTTP1(t *testing.T) { 30 var originalWant http.ResponseWriter 31 counterGot := &counter{} 32 chain := func(h http.Handler) http.Handler { 33 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 if originalWant == nil { 35 originalWant = w 36 } 37 38 assertCloseNotifierFlusherHijacker(t, true, w) 39 40 decorator := &fakeResponseWriterDecorator{ 41 ResponseWriter: w, 42 counter: counterGot, 43 } 44 wrapped := WrapForHTTP1Or2(decorator) 45 46 assertCloseNotifierFlusherHijacker(t, true, wrapped) 47 48 originalGot := GetOriginal(wrapped) 49 if originalWant != originalGot { 50 t.Errorf("Expected GetOriginal to return the original ResponseWriter object") 51 return 52 } 53 54 h.ServeHTTP(wrapped, r) 55 }) 56 } 57 58 // wrap the original http.ResponseWriter multiple times 59 handler := chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 60 // at this point, the original ResponseWriter object has been wrapped three times 61 // so each decorator is expected to tick the count by one for each method. 62 defer counterGot.assert(t, &counter{FlushInvoked: 3, CloseNotifyInvoked: 3, HijackInvoked: 3}) 63 64 //nolint:staticcheck // SA1019 65 w.(http.CloseNotifier).CloseNotify() 66 w.(http.Flusher).Flush() 67 68 conn, _, err := w.(http.Hijacker).Hijack() 69 if err != nil { 70 t.Errorf("Expected Hijack to succeed, but got error: %v", err) 71 return 72 } 73 conn.Close() 74 })) 75 handler = chain(handler) 76 handler = chain(handler) 77 78 server := newServer(t, handler, false) 79 defer server.Close() 80 81 sendRequest(t, server) 82 } 83 84 func TestWithHTTP2(t *testing.T) { 85 var originalWant http.ResponseWriter 86 counterGot := &counter{} 87 chain := func(h http.Handler) http.Handler { 88 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 89 if originalWant == nil { 90 originalWant = w 91 } 92 93 assertCloseNotifierFlusherHijacker(t, false, w) 94 95 decorator := &fakeResponseWriterDecorator{ 96 ResponseWriter: w, 97 counter: counterGot, 98 } 99 wrapped := WrapForHTTP1Or2(decorator) 100 101 assertCloseNotifierFlusherHijacker(t, false, wrapped) 102 103 originalGot := GetOriginal(wrapped) 104 if originalWant != originalGot { 105 t.Errorf("Expected GetOriginal to return the original ResponseWriter object") 106 return 107 } 108 109 h.ServeHTTP(wrapped, r) 110 }) 111 } 112 113 // wrap the original http.ResponseWriter multiple times 114 handler := chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 // at this point, the original ResponseWriter object has been wrapped three times 116 // so each decorator is expected to tick the count by one for each method. 117 defer counterGot.assert(t, &counter{FlushInvoked: 3, CloseNotifyInvoked: 3, HijackInvoked: 0}) 118 119 //nolint:staticcheck // SA1019 120 w.(http.CloseNotifier).CloseNotify() 121 w.(http.Flusher).Flush() 122 123 })) 124 handler = chain(handler) 125 handler = chain(handler) 126 127 server := newServer(t, handler, true) 128 defer server.Close() 129 130 sendRequest(t, server) 131 } 132 133 func TestGetOriginal(t *testing.T) { 134 tests := []struct { 135 name string 136 wrap func() (http.ResponseWriter, http.ResponseWriter) 137 panicExpected bool 138 }{ 139 { 140 name: "not wrapped", 141 wrap: func() (http.ResponseWriter, http.ResponseWriter) { 142 original := &FakeResponseWriter{} 143 return original, original 144 }, 145 }, 146 { 147 name: "wrapped once", 148 wrap: func() (http.ResponseWriter, http.ResponseWriter) { 149 original := &FakeResponseWriter{} 150 return original, &fakeResponseWriterDecorator{ 151 ResponseWriter: original, 152 } 153 }, 154 }, 155 { 156 name: "wrapped multiple times", 157 wrap: func() (http.ResponseWriter, http.ResponseWriter) { 158 original := &FakeResponseWriter{} 159 return original, &fakeResponseWriterDecorator{ 160 ResponseWriter: &fakeResponseWriterDecorator{ 161 ResponseWriter: &fakeResponseWriterDecorator{ 162 ResponseWriter: original, 163 }, 164 }, 165 } 166 }, 167 }, 168 { 169 name: "wraps itself", 170 wrap: func() (http.ResponseWriter, http.ResponseWriter) { 171 faulty := &fakeResponseWriterDecorator{} 172 faulty.ResponseWriter = faulty 173 return faulty, &fakeResponseWriterDecorator{ 174 ResponseWriter: faulty, 175 } 176 }, 177 panicExpected: true, 178 }, 179 } 180 181 for _, test := range tests { 182 t.Run(test.name, func(t *testing.T) { 183 originalExpected, wrapped := test.wrap() 184 185 func() { 186 defer func() { 187 err := recover() 188 switch { 189 case err != nil: 190 if !test.panicExpected { 191 t.Errorf("Expected no panic, but got: %v", err) 192 } 193 default: 194 if test.panicExpected { 195 t.Errorf("Expected a panic") 196 } 197 } 198 }() 199 200 originalGot := GetOriginal(wrapped) 201 if originalExpected != originalGot { 202 t.Errorf("Expected to get tehe original http.ResponseWriter object") 203 } 204 }() 205 }) 206 } 207 } 208 209 func newServer(t *testing.T, h http.Handler, http2 bool) *httptest.Server { 210 server := httptest.NewUnstartedServer(h) 211 if http2 { 212 server.EnableHTTP2 = true 213 server.StartTLS() 214 } else { 215 server.Start() 216 } 217 _, err := url.Parse(server.URL) 218 if err != nil { 219 t.Fatalf("Expected the server to have a valid URL, but got: %s", server.URL) 220 } 221 return server 222 } 223 224 func sendRequest(t *testing.T, server *httptest.Server) { 225 req, err := http.NewRequest("GET", server.URL, nil) 226 if err != nil { 227 t.Fatalf("error creating request: %v", err) 228 } 229 230 client := server.Client() 231 client.Timeout = 30 * time.Second 232 _, err = client.Do(req) 233 if err != nil { 234 t.Fatalf("Unexpected non-nil err from client.Do: %v", err) 235 } 236 } 237 238 func assertCloseNotifierFlusherHijacker(t *testing.T, hijackableExpected bool, w http.ResponseWriter) { 239 // the http.ResponseWriter object for both http/1.x and http2 240 // implement http.Flusher and http.CloseNotifier 241 if _, ok := w.(http.Flusher); !ok { 242 t.Errorf("Expected the http.ResponseWriter object to implement http.Flusher") 243 } 244 245 //nolint:staticcheck // SA1019 246 if _, ok := w.(http.CloseNotifier); !ok { 247 t.Errorf("Expected the http.ResponseWriter object to implement http.CloseNotifier") 248 } 249 250 // http/1.x implements http.Hijacker, not http2 251 if _, ok := w.(http.Hijacker); ok != hijackableExpected { 252 t.Errorf("Unexpected http.Hijacker implementation, expected: %t, but got: %t", hijackableExpected, ok) 253 } 254 } 255 256 type counter struct { 257 FlushInvoked int 258 HijackInvoked int 259 CloseNotifyInvoked int 260 } 261 262 func (c *counter) assert(t *testing.T, expected *counter) { 263 if expected.FlushInvoked != c.FlushInvoked { 264 t.Errorf("Expected Flush() count to match, wanted: %d, but got: %d", expected.FlushInvoked, c.FlushInvoked) 265 } 266 if expected.CloseNotifyInvoked != c.CloseNotifyInvoked { 267 t.Errorf("Expected CloseNotify() count to match, wanted: %d, but got: %d", expected.CloseNotifyInvoked, c.CloseNotifyInvoked) 268 } 269 if expected.HijackInvoked != c.HijackInvoked { 270 t.Errorf("Expected Hijack() count to match, wanted: %d, but got: %d", expected.HijackInvoked, c.HijackInvoked) 271 } 272 } 273 274 type fakeResponseWriterDecorator struct { 275 http.ResponseWriter 276 counter *counter 277 } 278 279 func (fw *fakeResponseWriterDecorator) Unwrap() http.ResponseWriter { return fw.ResponseWriter } 280 func (fw *fakeResponseWriterDecorator) Flush() { 281 if fw.counter != nil { 282 fw.counter.FlushInvoked++ 283 } 284 fw.ResponseWriter.(http.Flusher).Flush() 285 } 286 func (fw *fakeResponseWriterDecorator) Hijack() (net.Conn, *bufio.ReadWriter, error) { 287 if fw.counter != nil { 288 fw.counter.HijackInvoked++ 289 } 290 return fw.ResponseWriter.(http.Hijacker).Hijack() 291 } 292 func (fw *fakeResponseWriterDecorator) CloseNotify() <-chan bool { 293 if fw.counter != nil { 294 fw.counter.CloseNotifyInvoked++ 295 } 296 //nolint:staticcheck // SA1019 297 return fw.ResponseWriter.(http.CloseNotifier).CloseNotify() 298 }