k8s.io/apimachinery@v0.29.2/pkg/util/httpstream/wsstream/conn_test.go (about) 1 /* 2 Copyright 2015 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package wsstream 18 19 import ( 20 "encoding/base64" 21 "io" 22 "net/http" 23 "net/http/httptest" 24 "reflect" 25 "sync" 26 "testing" 27 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 "golang.org/x/net/websocket" 31 ) 32 33 func newServer(handler http.Handler) (*httptest.Server, string) { 34 server := httptest.NewServer(handler) 35 serverAddr := server.Listener.Addr().String() 36 return server, serverAddr 37 } 38 39 func TestRawConn(t *testing.T) { 40 channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel} 41 conn := NewConn(NewDefaultChannelProtocols(channels)) 42 43 s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 44 conn.Open(w, req) 45 })) 46 defer s.Close() 47 48 client, err := websocket.Dial("ws://"+addr, "", "http://localhost/") 49 if err != nil { 50 t.Fatal(err) 51 } 52 defer client.Close() 53 54 <-conn.ready 55 wg := sync.WaitGroup{} 56 57 // verify we can read a client write 58 wg.Add(1) 59 go func() { 60 defer wg.Done() 61 data, err := io.ReadAll(conn.channels[0]) 62 if err != nil { 63 t.Error(err) 64 return 65 } 66 if !reflect.DeepEqual(data, []byte("client")) { 67 t.Errorf("unexpected server read: %v", data) 68 } 69 }() 70 71 if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 { 72 t.Fatalf("%d: %v", n, err) 73 } 74 75 // verify we can read a server write 76 wg.Add(1) 77 go func() { 78 defer wg.Done() 79 if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 { 80 t.Errorf("%d: %v", n, err) 81 } 82 }() 83 84 data := make([]byte, 1024) 85 if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil { 86 t.Fatalf("%d: %v", n, err) 87 } 88 if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) { 89 t.Errorf("unexpected client read: %v", data[:7]) 90 } 91 92 // verify that an ignore channel is empty in both directions. 93 if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil { 94 t.Errorf("writes should be ignored") 95 } 96 data = make([]byte, 1024) 97 if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF { 98 t.Errorf("reads should be ignored") 99 } 100 101 // verify that a write to a Read channel doesn't block 102 if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil { 103 t.Errorf("writes should be ignored") 104 } 105 106 // verify that a read from a Write channel doesn't block 107 data = make([]byte, 1024) 108 if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF { 109 t.Errorf("reads should be ignored") 110 } 111 112 // verify that a client write to a Write channel doesn't block (is dropped) 113 if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 { 114 t.Fatalf("%d: %v", n, err) 115 } 116 117 client.Close() 118 wg.Wait() 119 } 120 121 func TestBase64Conn(t *testing.T) { 122 conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel})) 123 s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 124 conn.Open(w, req) 125 })) 126 defer s.Close() 127 128 config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") 129 if err != nil { 130 t.Fatal(err) 131 } 132 config.Protocol = []string{"base64.channel.k8s.io"} 133 client, err := websocket.DialConfig(config) 134 if err != nil { 135 t.Fatal(err) 136 } 137 defer client.Close() 138 139 <-conn.ready 140 wg := sync.WaitGroup{} 141 wg.Add(1) 142 go func() { 143 defer wg.Done() 144 data, err := io.ReadAll(conn.channels[0]) 145 if err != nil { 146 t.Error(err) 147 return 148 } 149 if !reflect.DeepEqual(data, []byte("client")) { 150 t.Errorf("unexpected server read: %s", string(data)) 151 } 152 }() 153 154 clientData := base64.StdEncoding.EncodeToString([]byte("client")) 155 if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 { 156 t.Fatalf("%d: %v", n, err) 157 } 158 159 wg.Add(1) 160 go func() { 161 defer wg.Done() 162 if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 { 163 t.Errorf("%d: %v", n, err) 164 } 165 }() 166 167 data := make([]byte, 1024) 168 if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil { 169 t.Fatalf("%d: %v", n, err) 170 } 171 expect := []byte(base64.StdEncoding.EncodeToString([]byte("server"))) 172 173 if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) { 174 t.Errorf("unexpected client read: %v", data[:9]) 175 } 176 177 client.Close() 178 wg.Wait() 179 } 180 181 type versionTest struct { 182 supported map[string]bool // protocol -> binary 183 requested []string 184 error bool 185 expected string 186 } 187 188 func versionTests() []versionTest { 189 const ( 190 binary = true 191 base64 = false 192 ) 193 return []versionTest{ 194 { 195 supported: nil, 196 requested: []string{"raw"}, 197 error: true, 198 }, 199 { 200 supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, 201 requested: nil, 202 expected: "", 203 }, 204 { 205 supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, 206 requested: []string{"v1.raw"}, 207 error: true, 208 }, 209 { 210 supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, 211 requested: []string{"v1.raw", "v1.base64"}, 212 error: true, 213 }, { 214 supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, 215 requested: []string{"v1.raw", "raw"}, 216 expected: "raw", 217 }, 218 { 219 supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, 220 requested: []string{"v1.raw"}, 221 expected: "v1.raw", 222 }, 223 { 224 supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, 225 requested: []string{"v2.base64"}, 226 expected: "v2.base64", 227 }, 228 } 229 } 230 231 func TestVersionedConn(t *testing.T) { 232 for i, test := range versionTests() { 233 func() { 234 supportedProtocols := map[string]ChannelProtocolConfig{} 235 for p, binary := range test.supported { 236 supportedProtocols[p] = ChannelProtocolConfig{ 237 Binary: binary, 238 Channels: []ChannelType{ReadWriteChannel}, 239 } 240 } 241 conn := NewConn(supportedProtocols) 242 // note that it's not enough to wait for conn.ready to avoid a race here. Hence, 243 // we use a channel. 244 selectedProtocol := make(chan string) 245 s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 246 p, _, _ := conn.Open(w, req) 247 selectedProtocol <- p 248 })) 249 defer s.Close() 250 251 config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") 252 if err != nil { 253 t.Fatal(err) 254 } 255 config.Protocol = test.requested 256 client, err := websocket.DialConfig(config) 257 if err != nil { 258 if !test.error { 259 t.Fatalf("test %d: didn't expect error: %v", i, err) 260 } else { 261 return 262 } 263 } 264 defer client.Close() 265 if test.error && err == nil { 266 t.Fatalf("test %d: expected an error", i) 267 } 268 269 <-conn.ready 270 if got, expected := <-selectedProtocol, test.expected; got != expected { 271 t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) 272 } 273 }() 274 } 275 } 276 277 func TestIsWebSocketRequestWithStreamCloseProtocol(t *testing.T) { 278 tests := map[string]struct { 279 headers map[string]string 280 expected bool 281 }{ 282 "No headers returns false": { 283 headers: map[string]string{}, 284 expected: false, 285 }, 286 "Only connection upgrade header is false": { 287 headers: map[string]string{ 288 "Connection": "upgrade", 289 }, 290 expected: false, 291 }, 292 "Only websocket upgrade header is false": { 293 headers: map[string]string{ 294 "Upgrade": "websocket", 295 }, 296 expected: false, 297 }, 298 "Only websocket and connection upgrade headers is false": { 299 headers: map[string]string{ 300 "Connection": "upgrade", 301 "Upgrade": "websocket", 302 }, 303 expected: false, 304 }, 305 "Missing connection/upgrade header is false": { 306 headers: map[string]string{ 307 "Upgrade": "websocket", 308 WebSocketProtocolHeader: "v5.channel.k8s.io", 309 }, 310 expected: false, 311 }, 312 "Websocket connection upgrade headers with v5 protocol is true": { 313 headers: map[string]string{ 314 "Connection": "upgrade", 315 "Upgrade": "websocket", 316 WebSocketProtocolHeader: "v5.channel.k8s.io", 317 }, 318 expected: true, 319 }, 320 "Websocket connection upgrade headers with wrong case v5 protocol is false": { 321 headers: map[string]string{ 322 "Connection": "upgrade", 323 "Upgrade": "websocket", 324 WebSocketProtocolHeader: "v5.CHANNEL.k8s.io", // header value is case-sensitive 325 }, 326 expected: false, 327 }, 328 "Websocket connection upgrade headers with v4 protocol is false": { 329 headers: map[string]string{ 330 "Connection": "upgrade", 331 "Upgrade": "websocket", 332 WebSocketProtocolHeader: "v4.channel.k8s.io", 333 }, 334 expected: false, 335 }, 336 "Websocket connection upgrade headers with multiple protocols but missing v5 is false": { 337 headers: map[string]string{ 338 "Connection": "upgrade", 339 "Upgrade": "websocket", 340 WebSocketProtocolHeader: "v4.channel.k8s.io,v3.channel.k8s.io,v2.channel.k8s.io", 341 }, 342 expected: false, 343 }, 344 "Websocket connection upgrade headers with multiple protocols including v5 and spaces is true": { 345 headers: map[string]string{ 346 "Connection": "upgrade", 347 "Upgrade": "websocket", 348 WebSocketProtocolHeader: "v5.channel.k8s.io, v4.channel.k8s.io", 349 }, 350 expected: true, 351 }, 352 "Websocket connection upgrade headers with multiple protocols out of order including v5 and spaces is true": { 353 headers: map[string]string{ 354 "Connection": "upgrade", 355 "Upgrade": "websocket", 356 WebSocketProtocolHeader: "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io", 357 }, 358 expected: true, 359 }, 360 361 "Websocket connection upgrade headers key is case-insensitive": { 362 headers: map[string]string{ 363 "Connection": "upgrade", 364 "Upgrade": "websocket", 365 "sec-websocket-protocol": "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io", 366 }, 367 expected: true, 368 }, 369 } 370 371 for name, test := range tests { 372 req, err := http.NewRequest("GET", "http://www.example.com/", nil) 373 require.NoError(t, err) 374 for key, value := range test.headers { 375 req.Header.Add(key, value) 376 } 377 actual := IsWebSocketRequestWithStreamCloseProtocol(req) 378 assert.Equal(t, test.expected, actual, "%s: expected (%t), got (%t)", name, test.expected, actual) 379 } 380 } 381 382 func TestProtocolSupportsStreamClose(t *testing.T) { 383 tests := map[string]struct { 384 protocol string 385 expected bool 386 }{ 387 "empty protocol returns false": { 388 protocol: "", 389 expected: false, 390 }, 391 "not binary protocol returns false": { 392 protocol: "base64.channel.k8s.io", 393 expected: false, 394 }, 395 "V1 protocol returns false": { 396 protocol: "channel.k8s.io", 397 expected: false, 398 }, 399 "V4 protocol returns false": { 400 protocol: "v4.channel.k8s.io", 401 expected: false, 402 }, 403 "V5 protocol returns true": { 404 protocol: "v5.channel.k8s.io", 405 expected: true, 406 }, 407 "V5 protocol wrong case returns false": { 408 protocol: "V5.channel.K8S.io", 409 expected: false, 410 }, 411 } 412 413 for name, test := range tests { 414 actual := protocolSupportsStreamClose(test.protocol) 415 assert.Equal(t, test.expected, actual, 416 "%s: expected (%t), got (%t)", name, test.expected, actual) 417 } 418 }