k8s.io/client-go@v0.31.1/tools/remotecommand/websocket_test.go (about) 1 /* 2 Copyright 2023 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package remotecommand 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/rand" 23 "encoding/json" 24 "fmt" 25 "io" 26 "math" 27 mrand "math/rand" 28 "net/http" 29 "net/http/httptest" 30 "net/url" 31 "reflect" 32 "strings" 33 "sync" 34 "testing" 35 "time" 36 37 gwebsocket "github.com/gorilla/websocket" 38 39 v1 "k8s.io/api/core/v1" 40 apierrors "k8s.io/apimachinery/pkg/api/errors" 41 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 42 "k8s.io/apimachinery/pkg/util/httpstream" 43 "k8s.io/apimachinery/pkg/util/httpstream/wsstream" 44 "k8s.io/apimachinery/pkg/util/remotecommand" 45 "k8s.io/apimachinery/pkg/util/wait" 46 "k8s.io/client-go/rest" 47 clientcmdapi "k8s.io/client-go/tools/clientcmd/api" 48 ) 49 50 // TestWebSocketClient_LoopbackStdinToStdout returns random data sent on the STDIN channel 51 // back down the STDOUT channel. A subsequent comparison checks if the data 52 // sent on the STDIN channel is the same as the data returned on the STDOUT 53 // channel. This test can be run many times by the "stress" tool to check 54 // if there is any data which would cause problems with the WebSocket streams. 55 func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) { 56 // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. 57 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 58 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 59 if err != nil { 60 t.Fatalf("error on webSocketServerStreams: %v", err) 61 } 62 defer conns.conn.Close() 63 // Loopback the STDIN stream onto the STDOUT stream. 64 _, err = io.Copy(conns.stdoutStream, conns.stdinStream) 65 if err != nil { 66 t.Fatalf("error copying STDIN to STDOUT: %v", err) 67 } 68 })) 69 defer websocketServer.Close() 70 71 // Now create the WebSocket client (executor), and point it to the "websocketServer". 72 // Must add STDIN and STDOUT query params for the WebSocket client request. 73 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" 74 websocketLocation, err := url.Parse(websocketServer.URL) 75 if err != nil { 76 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 77 } 78 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 79 if err != nil { 80 t.Errorf("unexpected error creating websocket executor: %v", err) 81 } 82 // Generate random data, and set it up to stream on STDIN. The data will be 83 // returned on the STDOUT buffer. 84 randomSize := 1024 * 1024 85 randomData := make([]byte, randomSize) 86 if _, err := rand.Read(randomData); err != nil { 87 t.Errorf("unexpected error reading random data: %v", err) 88 } 89 var stdout bytes.Buffer 90 options := &StreamOptions{ 91 Stdin: bytes.NewReader(randomData), 92 Stdout: &stdout, 93 } 94 errorChan := make(chan error) 95 go func() { 96 // Start the streaming on the WebSocket "exec" client. 97 errorChan <- exec.StreamWithContext(context.Background(), *options) 98 }() 99 100 select { 101 case <-time.After(wait.ForeverTestTimeout): 102 t.Fatalf("expect stream to be closed after connection is closed.") 103 case err := <-errorChan: 104 if err != nil { 105 t.Errorf("unexpected error") 106 } 107 // Validate remote command v5 protocol was negotiated. 108 streamExec := exec.(*wsStreamExecutor) 109 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 110 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 111 } 112 } 113 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) 114 if err != nil { 115 t.Fatalf("error reading the stream: %v", err) 116 } 117 // Check the random data sent on STDIN was the same returned on STDOUT. 118 if !bytes.Equal(randomData, data) { 119 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) 120 } 121 } 122 123 // TestWebSocketClient_DifferentBufferSizes runs the previous loopback (STDIN -> STDOUT) test with different 124 // buffer sizes for reading from the opposite end of the websocket connection (in the websocket server). 125 func TestWebSocketClient_DifferentBufferSizes(t *testing.T) { 126 // 1k, 4k, 64k, and 128k buffer sizes for reading STDIN at websocket server endpoint. 127 // The standard buffer size for io.Copy is 32k. 128 bufferSizes := []int{1 * 1024, 4 * 1024, 64 * 1024, 128 * 1024} 129 for _, bufferSize := range bufferSizes { 130 // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. 131 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 132 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 133 if err != nil { 134 t.Fatalf("error on webSocketServerStreams: %v", err) 135 } 136 defer conns.conn.Close() 137 // Loopback the STDIN stream onto the STDOUT stream, using buffer with size. 138 buffer := make([]byte, bufferSize) 139 _, err = io.CopyBuffer(conns.stdoutStream, conns.stdinStream, buffer) 140 if err != nil { 141 t.Fatalf("error copying STDIN to STDOUT: %v", err) 142 } 143 })) 144 defer websocketServer.Close() 145 146 // Now create the WebSocket client (executor), and point it to the "websocketServer". 147 // Must add STDIN and STDOUT query params for the WebSocket client request. 148 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" 149 websocketLocation, err := url.Parse(websocketServer.URL) 150 if err != nil { 151 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 152 } 153 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 154 if err != nil { 155 t.Errorf("unexpected error creating websocket executor: %v", err) 156 } 157 // Generate random data, and set it up to stream on STDIN. The data will be 158 // returned on the STDOUT buffer. 159 randomSize := 1024 * 1024 160 randomData := make([]byte, randomSize) 161 if _, err := rand.Read(randomData); err != nil { 162 t.Errorf("unexpected error reading random data: %v", err) 163 } 164 var stdout bytes.Buffer 165 options := &StreamOptions{ 166 Stdin: bytes.NewReader(randomData), 167 Stdout: &stdout, 168 } 169 errorChan := make(chan error) 170 go func() { 171 // Start the streaming on the WebSocket "exec" client. 172 errorChan <- exec.StreamWithContext(context.Background(), *options) 173 }() 174 175 select { 176 case <-time.After(wait.ForeverTestTimeout): 177 t.Fatalf("expect stream to be closed after connection is closed.") 178 case err := <-errorChan: 179 if err != nil { 180 t.Errorf("unexpected error") 181 } 182 // Validate remote command v5 protocol was negotiated. 183 streamExec := exec.(*wsStreamExecutor) 184 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 185 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 186 } 187 } 188 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) 189 if err != nil { 190 t.Errorf("error reading the stream: %v", err) 191 return 192 } 193 // Check all the random data sent on STDIN was the same returned on STDOUT. 194 if !bytes.Equal(randomData, data) { 195 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) 196 } 197 } 198 } 199 200 // TestWebSocketClient_LoopbackStdinAsPipe uses a pipe to send random data on the STDIN 201 // channel, then closes the pipe. The fake server simply returns all STDIN data back 202 // onto the STDOUT channel, and the received data on STDOUT is validated against the 203 // random data initially sent. 204 func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) { 205 // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. 206 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 207 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 208 if err != nil { 209 t.Fatalf("error on webSocketServerStreams: %v", err) 210 } 211 defer conns.conn.Close() 212 // Loopback the STDIN stream onto the STDOUT stream. 213 _, err = io.Copy(conns.stdoutStream, conns.stdinStream) 214 if err != nil { 215 t.Fatalf("error copying STDIN to STDOUT: %v", err) 216 } 217 })) 218 defer websocketServer.Close() 219 220 // Now create the WebSocket client (executor), and point it to the "websocketServer". 221 // Must add STDIN and STDOUT query params for the WebSocket client request. 222 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" 223 websocketLocation, err := url.Parse(websocketServer.URL) 224 if err != nil { 225 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 226 } 227 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 228 if err != nil { 229 t.Errorf("unexpected error creating websocket executor: %v", err) 230 } 231 // Generate random data, and it will be written on the STDIN pipe. The same 232 // data will be returned on the STDOUT channel. 233 randomSize := 1024 * 1024 234 randomData := make([]byte, randomSize) 235 if _, err := rand.Read(randomData); err != nil { 236 t.Errorf("unexpected error reading random data: %v", err) 237 } 238 reader, writer := io.Pipe() 239 var stdout bytes.Buffer 240 options := &StreamOptions{ 241 Stdin: reader, 242 Stdout: &stdout, 243 } 244 errorChan := make(chan error) 245 go func() { 246 // Start the streaming on the WebSocket "exec" client. 247 errorChan <- exec.StreamWithContext(context.Background(), *options) 248 }() 249 // Write the random data onto the pipe connected to STDIN, then close the pipe. 250 _, err = writer.Write(randomData) 251 if err != nil { 252 t.Fatalf("unable to write random data to STDIN pipe: %v", err) 253 } 254 writer.Close() 255 256 select { 257 case <-time.After(wait.ForeverTestTimeout): 258 t.Fatalf("expect stream to be closed after connection is closed.") 259 case err := <-errorChan: 260 if err != nil { 261 t.Errorf("unexpected error") 262 } 263 // Validate remote command v5 protocol was negotiated. 264 streamExec := exec.(*wsStreamExecutor) 265 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 266 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 267 } 268 } 269 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) 270 if err != nil { 271 t.Errorf("error reading the stream: %v", err) 272 return 273 } 274 // Check the random data sent on STDIN was the same returned on STDOUT. 275 if !bytes.Equal(randomData, data) { 276 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) 277 } 278 } 279 280 // TestWebSocketClient_LoopbackStdinToStderr returns random data sent on the STDIN channel 281 // back down the STDERR channel. A subsequent comparison checks if the data 282 // sent on the STDIN channel is the same as the data returned on the STDERR 283 // channel. This test can be run many times by the "stress" tool to check 284 // if there is any data which would cause problems with the WebSocket streams. 285 func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) { 286 // Create fake WebSocket server. Copy received STDIN data back onto STDERR stream. 287 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 288 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 289 if err != nil { 290 t.Fatalf("error on webSocketServerStreams: %v", err) 291 } 292 defer conns.conn.Close() 293 // Loopback the STDIN stream onto the STDERR stream. 294 _, err = io.Copy(conns.stderrStream, conns.stdinStream) 295 if err != nil { 296 t.Fatalf("error copying STDIN to STDERR: %v", err) 297 } 298 })) 299 defer websocketServer.Close() 300 301 // Now create the WebSocket client (executor), and point it to the "websocketServer". 302 // Must add STDIN and STDERR query params for the WebSocket client request. 303 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true" 304 websocketLocation, err := url.Parse(websocketServer.URL) 305 if err != nil { 306 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 307 } 308 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 309 if err != nil { 310 t.Errorf("unexpected error creating websocket executor: %v", err) 311 } 312 // Generate random data, and set it up to stream on STDIN. The data will be 313 // returned on the STDERR buffer. 314 randomSize := 1024 * 1024 315 randomData := make([]byte, randomSize) 316 if _, err := rand.Read(randomData); err != nil { 317 t.Errorf("unexpected error reading random data: %v", err) 318 } 319 var stderr bytes.Buffer 320 options := &StreamOptions{ 321 Stdin: bytes.NewReader(randomData), 322 Stderr: &stderr, 323 } 324 errorChan := make(chan error) 325 go func() { 326 // Start the streaming on the WebSocket "exec" client. 327 errorChan <- exec.StreamWithContext(context.Background(), *options) 328 }() 329 330 select { 331 case <-time.After(wait.ForeverTestTimeout): 332 t.Fatalf("expect stream to be closed after connection is closed.") 333 case err := <-errorChan: 334 if err != nil { 335 t.Errorf("unexpected error") 336 } 337 // Validate remote command v5 protocol was negotiated. 338 streamExec := exec.(*wsStreamExecutor) 339 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 340 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 341 } 342 } 343 data, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) 344 if err != nil { 345 t.Errorf("error reading the stream: %v", err) 346 return 347 } 348 // Check the random data sent on STDIN was the same returned on STDERR. 349 if !bytes.Equal(randomData, data) { 350 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) 351 } 352 } 353 354 // TestWebSocketClient_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from 355 // the websocket connection at the same time. 356 func TestWebSocketClient_MultipleReadChannels(t *testing.T) { 357 // Create fake WebSocket server, which uses a TeeReader to copy the same data 358 // onto the STDOUT stream onto the STDERR stream as well. 359 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 360 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 361 if err != nil { 362 t.Fatalf("error on webSocketServerStreams: %v", err) 363 } 364 defer conns.conn.Close() 365 // TeeReader copies data read on STDIN onto STDERR. 366 stdinReader := io.TeeReader(conns.stdinStream, conns.stderrStream) 367 // Also copy STDIN to STDOUT. 368 _, err = io.Copy(conns.stdoutStream, stdinReader) 369 if err != nil { 370 t.Errorf("error copying STDIN to STDOUT: %v", err) 371 } 372 })) 373 defer websocketServer.Close() 374 // Now create the WebSocket client (executor), and point it to the "websocketServer". 375 // Must add stdin, stdout, and stderr query param for the WebSocket client request. 376 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + "&" + "stderr=true" 377 websocketLocation, err := url.Parse(websocketServer.URL) 378 if err != nil { 379 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 380 } 381 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 382 if err != nil { 383 t.Errorf("unexpected error creating websocket executor: %v", err) 384 } 385 // Generate 1MB of random data, and set it up to stream on STDIN. The data will be 386 // returned on the STDOUT and STDERR buffers. 387 randomSize := 1024 * 1024 388 randomData := make([]byte, randomSize) 389 if _, err := rand.Read(randomData); err != nil { 390 t.Errorf("unexpected error reading random data: %v", err) 391 } 392 var stdout, stderr bytes.Buffer 393 options := &StreamOptions{ 394 Stdin: bytes.NewReader(randomData), 395 Stdout: &stdout, 396 Stderr: &stderr, 397 } 398 errorChan := make(chan error) 399 go func() { 400 errorChan <- exec.StreamWithContext(context.Background(), *options) 401 }() 402 403 select { 404 case <-time.After(wait.ForeverTestTimeout): 405 t.Fatalf("expect stream to be closed after connection is closed.") 406 case err := <-errorChan: 407 if err != nil { 408 t.Errorf("unexpected error: %v", err) 409 } 410 // Validate remote command v5 protocol was negotiated. 411 streamExec := exec.(*wsStreamExecutor) 412 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 413 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 414 } 415 } 416 // Validate the data read from the STDOUT stream is the same as sent on the STDIN stream. 417 stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) 418 if err != nil { 419 t.Fatalf("error reading the stream: %v", err) 420 } 421 if !bytes.Equal(stdoutBytes, randomData) { 422 t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(randomData)) 423 } 424 // Validate the data read from the STDERR stream is the same as sent on the STDIN stream. 425 stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) 426 if err != nil { 427 t.Fatalf("error reading the stream: %v", err) 428 } 429 if !bytes.Equal(stderrBytes, randomData) { 430 t.Errorf("unexpected data received (%d) sent (%d)", len(stderrBytes), len(randomData)) 431 } 432 } 433 434 // Returns a random exit code in the range(1-127). 435 func randomExitCode() int { 436 errorCode := mrand.Intn(128) 437 if errorCode == 0 { 438 errorCode = 1 439 } 440 return errorCode 441 } 442 443 // TestWebSocketClient_ErrorStream tests the websocket error stream by hard-coding a 444 // structured non-zero exit code error from the websocket server to the websocket client. 445 func TestWebSocketClient_ErrorStream(t *testing.T) { 446 expectedExitCode := randomExitCode() 447 // Create fake WebSocket server. Returns structured exit code error on error stream. 448 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 449 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 450 if err != nil { 451 t.Fatalf("error on webSocketServerStreams: %v", err) 452 } 453 defer conns.conn.Close() 454 _, err = io.Copy(conns.stderrStream, conns.stdinStream) 455 if err != nil { 456 t.Fatalf("error copying STDIN to STDERR: %v", err) 457 } 458 // Force an non-zero exit code error returned on the error stream. 459 err = conns.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{ 460 Status: metav1.StatusFailure, 461 Reason: remotecommand.NonZeroExitCodeReason, 462 Details: &metav1.StatusDetails{ 463 Causes: []metav1.StatusCause{ 464 { 465 Type: remotecommand.ExitCodeCauseType, 466 Message: fmt.Sprintf("%d", expectedExitCode), 467 }, 468 }, 469 }, 470 }}) 471 if err != nil { 472 t.Fatalf("error writing status: %v", err) 473 } 474 })) 475 defer websocketServer.Close() 476 477 // Now create the WebSocket client (executor), and point it to the "websocketServer". 478 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true" 479 websocketLocation, err := url.Parse(websocketServer.URL) 480 if err != nil { 481 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 482 } 483 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 484 if err != nil { 485 t.Errorf("unexpected error creating websocket executor: %v", err) 486 } 487 randomData := make([]byte, 256) 488 if _, err := rand.Read(randomData); err != nil { 489 t.Errorf("unexpected error reading random data: %v", err) 490 } 491 var stderr bytes.Buffer 492 options := &StreamOptions{ 493 Stdin: bytes.NewReader(randomData), 494 Stderr: &stderr, 495 } 496 errorChan := make(chan error) 497 go func() { 498 // Start the streaming on the WebSocket "exec" client. 499 errorChan <- exec.StreamWithContext(context.Background(), *options) 500 }() 501 502 select { 503 case <-time.After(wait.ForeverTestTimeout): 504 t.Fatalf("expect stream to be closed after connection is closed.") 505 case err := <-errorChan: 506 // Validate remote command v5 protocol was negotiated. 507 streamExec := exec.(*wsStreamExecutor) 508 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 509 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 510 } 511 // Expect exit code error on error stream. 512 if err == nil { 513 t.Errorf("expected error, but received none") 514 } 515 expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode) 516 // Compare expected error with exit code to actual error. 517 if expectedError != err.Error() { 518 t.Errorf("expected error (%s), got (%s)", expectedError, err) 519 } 520 } 521 } 522 523 // fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of 524 // "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice. 525 type fakeTerminalSizeQueue struct { 526 maxSizes int 527 terminalSizes []TerminalSize 528 } 529 530 // newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing 531 // "max" number of random TerminalSizes created. 532 func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue { 533 return &fakeTerminalSizeQueue{ 534 maxSizes: max, 535 terminalSizes: make([]TerminalSize, 0, max), 536 } 537 } 538 539 // Next returns a pointer to the next random TerminalSize, or nil if we have 540 // already returned "maxSizes" TerminalSizes already. Stores the randomly 541 // created TerminalSize in "terminalSizes" field for later validation. 542 func (f *fakeTerminalSizeQueue) Next() *TerminalSize { 543 if len(f.terminalSizes) >= f.maxSizes { 544 return nil 545 } 546 size := randomTerminalSize() 547 f.terminalSizes = append(f.terminalSizes, size) 548 return &size 549 } 550 551 // randomTerminalSize returns a TerminalSize with random values in the 552 // range (0-65535) for the fields Width and Height. 553 func randomTerminalSize() TerminalSize { 554 randWidth := uint16(mrand.Intn(int(math.Pow(2, 16)))) 555 randHeight := uint16(mrand.Intn(int(math.Pow(2, 16)))) 556 return TerminalSize{ 557 Width: randWidth, 558 Height: randHeight, 559 } 560 } 561 562 // randReader implements the ReadCloser interface, and it continuously 563 // returns random data until it is closed. Stores number of random 564 // bytes generated and returned. 565 type randReader struct { 566 randBytes []byte 567 closed bool 568 lock sync.Mutex 569 } 570 571 // Read implements the Reader interface filling the passed buffer with 572 // random data, returning the number of bytes filled and an error 573 // if one occurs. Return 0 and EOF if the randReader has been closed. 574 func (r *randReader) Read(b []byte) (int, error) { 575 r.lock.Lock() 576 defer r.lock.Unlock() 577 if r.closed { 578 return 0, io.EOF 579 } 580 n, err := rand.Read(b) 581 c := bytes.Clone(b) 582 r.randBytes = append(r.randBytes, c...) 583 return n, err 584 } 585 586 // Close implements the Closer interface, setting the close field true. 587 // Further calls to Read() after Close() will return 0, EOF. Returns 588 // nil error. 589 func (r *randReader) Close() (err error) { 590 r.lock.Lock() 591 defer r.lock.Unlock() 592 r.closed = true 593 return nil 594 } 595 596 // TestWebSocketClient_MultipleWriteChannels tests two streams (STDIN, TTY resize) writing to the 597 // websocket connection at the same time to exercise the connection write lock. 598 func TestWebSocketClient_MultipleWriteChannels(t *testing.T) { 599 // Create the fake terminal size queue and the actualTerminalSizes which 600 // will be received at the opposite websocket endpoint. 601 numSizeQueue := 10000 602 sizeQueue := newTerminalSizeQueue(numSizeQueue) 603 actualTerminalSizes := make([]TerminalSize, 0, numSizeQueue) 604 // Create ReadCloser sending random data on STDIN stream over websocket connection. 605 stdinReader := randReader{randBytes: []byte{}, closed: false} 606 // Create fake WebSocket server, which will receive concurrently the STDIN stream as 607 // well as the resize stream (TerminalSizes). Store the TerminalSize data from the resize 608 // stream for subsequent validation. 609 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 610 var wg sync.WaitGroup 611 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 612 if err != nil { 613 t.Fatalf("error on webSocketServerStreams: %v", err) 614 } 615 defer conns.conn.Close() 616 // Create goroutine to loopback the STDIN stream onto the STDOUT stream. 617 wg.Add(1) 618 go func() { 619 _, err := io.Copy(conns.stdoutStream, conns.stdinStream) 620 if err != nil { 621 t.Errorf("error copying STDIN to STDOUT: %v", err) 622 } 623 wg.Done() 624 }() 625 // Read the terminal resize requests, storing them in actualTerminalSizes 626 for i := 0; i < numSizeQueue; i++ { 627 actualTerminalSize := <-conns.resizeChan 628 actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize) 629 } 630 stdinReader.Close() // Stops the random STDIN stream generation 631 wg.Wait() // Wait for all bytes copied from STDIN to STDOUT 632 })) 633 defer websocketServer.Close() 634 // Now create the WebSocket client (executor), and point it to the "websocketServer". 635 // Must add stdin, stdout, and TTY query param for the WebSocket client request. 636 websocketServer.URL = websocketServer.URL + "?" + "tty=true" + "&" + "stdin=true" + "&" + "stdout=true" 637 websocketLocation, err := url.Parse(websocketServer.URL) 638 if err != nil { 639 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 640 } 641 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 642 if err != nil { 643 t.Errorf("unexpected error creating websocket executor: %v", err) 644 } 645 var stdout bytes.Buffer 646 options := &StreamOptions{ 647 Stdin: &stdinReader, 648 Stdout: &stdout, 649 Tty: true, 650 TerminalSizeQueue: sizeQueue, 651 } 652 errorChan := make(chan error) 653 go func() { 654 errorChan <- exec.StreamWithContext(context.Background(), *options) 655 }() 656 657 select { 658 case <-time.After(wait.ForeverTestTimeout): 659 t.Fatalf("expect stream to be closed after connection is closed.") 660 case err := <-errorChan: 661 if err != nil { 662 t.Errorf("unexpected error: %v", err) 663 } 664 // Validate remote command v5 protocol was negotiated. 665 streamExec := exec.(*wsStreamExecutor) 666 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 667 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 668 } 669 } 670 // Check the random data sent on STDIN was the same returned on STDOUT *and* 671 // that a minimum amount of random data was sent and received, ensuring concurrency. 672 stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) 673 if err != nil { 674 t.Fatalf("error reading the stream: %v", err) 675 } 676 if len(stdoutBytes) == 0 { 677 t.Errorf("No STDOUT bytes processed before resize stream finished: %d", len(stdoutBytes)) 678 } 679 if !bytes.Equal(stdoutBytes, stdinReader.randBytes) { 680 t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(stdinReader.randBytes)) 681 } 682 // Validate the random TerminalSizes sent on the resize stream are the same 683 // as the actual TerminalSizes received at the websocket server. 684 if len(actualTerminalSizes) != numSizeQueue { 685 t.Errorf("expected received terminal size window (%d), got (%d)", 686 numSizeQueue, len(actualTerminalSizes)) 687 } 688 for i, actual := range actualTerminalSizes { 689 expected := sizeQueue.terminalSizes[i] 690 if !reflect.DeepEqual(expected, actual) { 691 t.Errorf("expected terminal resize window %v, got %v", expected, actual) 692 } 693 } 694 } 695 696 // TestWebSocketClient_ProtocolVersions validates that remote command subprotocol versions V2-V4 697 // (V5 is already tested elsewhere) can be negotiated. 698 func TestWebSocketClient_ProtocolVersions(t *testing.T) { 699 // Create a raw websocket server that accepts V2-V4 versions of 700 // the remote command subprotocol. 701 var upgrader = gwebsocket.Upgrader{ 702 CheckOrigin: func(r *http.Request) bool { 703 return true // Accepting all requests 704 }, 705 Subprotocols: []string{ 706 remotecommand.StreamProtocolV4Name, 707 remotecommand.StreamProtocolV3Name, 708 remotecommand.StreamProtocolV2Name, 709 }, 710 } 711 // Upgrade a raw websocket server connection. 712 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 713 conn, err := upgrader.Upgrade(w, req, nil) 714 if err != nil { 715 t.Fatalf("unable to upgrade to create websocket connection: %v", err) 716 } 717 defer conn.Close() 718 })) 719 defer websocketServer.Close() 720 721 // Set up the websocket client with the STDOUT stream. 722 websocketServer.URL = websocketServer.URL + "?" + "stdout=true" 723 websocketLocation, err := url.Parse(websocketServer.URL) 724 if err != nil { 725 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 726 } 727 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 728 if err != nil { 729 t.Errorf("unexpected error creating websocket executor: %v", err) 730 } 731 // Iterate through previous remote command protocol versions, validating the 732 // requested protocol version is the one that is negotiated. 733 versions := []string{ 734 remotecommand.StreamProtocolV4Name, 735 remotecommand.StreamProtocolV3Name, 736 remotecommand.StreamProtocolV2Name, 737 } 738 for _, requestedVersion := range versions { 739 streamExec := exec.(*wsStreamExecutor) 740 streamExec.protocols = []string{requestedVersion} 741 var stdout bytes.Buffer 742 options := &StreamOptions{ 743 Stdout: &stdout, 744 } 745 errorChan := make(chan error) 746 go func() { 747 // Start the streaming on the WebSocket "exec" client. 748 errorChan <- exec.StreamWithContext(context.Background(), *options) 749 }() 750 751 select { 752 case <-time.After(wait.ForeverTestTimeout): 753 t.Fatalf("expect stream to be closed after connection is closed.") 754 case <-errorChan: 755 // Validate remote command protocol requestedVersion was negotiated. 756 streamExec := exec.(*wsStreamExecutor) 757 if requestedVersion != streamExec.negotiated { 758 t.Fatalf("expected protocol version (%s), got (%s)", requestedVersion, streamExec.negotiated) 759 } 760 } 761 } 762 } 763 764 // TestWebSocketClient_BadHandshake tests that a "bad handshake" error occurs when 765 // the WebSocketExecutor attempts to upgrade the connection to a subprotocol version 766 // (V4) that is not supported by the websocket server (only supports V5). 767 func TestWebSocketClient_BadHandshake(t *testing.T) { 768 // Create fake WebSocket server (supports V5 subprotocol). 769 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 770 // Bad handshake means websocket server will not completely initialize. 771 _, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 772 if err == nil { 773 t.Fatalf("expected error, but received none.") 774 } 775 if !strings.Contains(err.Error(), "websocket server finished before becoming ready") { 776 t.Errorf("expected websocket server error, but got: %v", err) 777 } 778 })) 779 defer websocketServer.Close() 780 781 websocketServer.URL = websocketServer.URL + "?" + "stdout=true" 782 websocketLocation, err := url.Parse(websocketServer.URL) 783 if err != nil { 784 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 785 } 786 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 787 if err != nil { 788 t.Errorf("unexpected error creating websocket executor: %v", err) 789 } 790 streamExec := exec.(*wsStreamExecutor) 791 // Set the attempted subprotocol version to V4; websocket server only accepts V5. 792 streamExec.protocols = []string{remotecommand.StreamProtocolV4Name} 793 794 var stdout bytes.Buffer 795 options := &StreamOptions{ 796 Stdout: &stdout, 797 } 798 errorChan := make(chan error) 799 go func() { 800 // Start the streaming on the WebSocket "exec" client. 801 errorChan <- streamExec.StreamWithContext(context.Background(), *options) 802 }() 803 804 select { 805 case <-time.After(wait.ForeverTestTimeout): 806 t.Fatalf("expect stream to be closed after connection is closed.") 807 case err := <-errorChan: 808 // Expecting unable to upgrade connection -- "bad handshake" error. 809 if err == nil { 810 t.Errorf("expected error but received none") 811 } 812 if !strings.Contains(err.Error(), "bad handshake") { 813 t.Errorf("expected bad handshake error, got (%s)", err) 814 } 815 } 816 } 817 818 // TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a 819 // timeout by setting the ping period greater than the deadline. 820 func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { 821 blockRequestCtx, unblockRequest := context.WithCancel(context.Background()) 822 defer unblockRequest() 823 // Create fake WebSocket server which blocks. 824 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 825 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) 826 if err != nil { 827 t.Fatalf("error on webSocketServerStreams: %v", err) 828 } 829 defer conns.conn.Close() 830 <-blockRequestCtx.Done() 831 })) 832 defer websocketServer.Close() 833 // Create websocket client connecting to fake server. 834 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" 835 websocketLocation, err := url.Parse(websocketServer.URL) 836 if err != nil { 837 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 838 } 839 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 840 if err != nil { 841 t.Errorf("unexpected error creating websocket executor: %v", err) 842 } 843 streamExec := exec.(*wsStreamExecutor) 844 // Ping period is greater than the ping deadline, forcing the timeout to fire. 845 pingPeriod := wait.ForeverTestTimeout // this lets the heartbeat deadline expire without renewing it 846 pingDeadline := time.Second // this gives setup 1 second to establish streams 847 streamExec.heartbeatPeriod = pingPeriod 848 streamExec.heartbeatDeadline = pingDeadline 849 // Send some random data to the websocket server through STDIN. 850 randomData := make([]byte, 128) 851 if _, err := rand.Read(randomData); err != nil { 852 t.Errorf("unexpected error reading random data: %v", err) 853 } 854 options := &StreamOptions{ 855 Stdin: bytes.NewReader(randomData), 856 } 857 errorChan := make(chan error) 858 go func() { 859 // Start the streaming on the WebSocket "exec" client. 860 errorChan <- streamExec.StreamWithContext(context.Background(), *options) 861 }() 862 863 select { 864 case <-time.After(wait.ForeverTestTimeout): 865 t.Fatalf("expected heartbeat timeout, got none.") 866 case err := <-errorChan: 867 // Expecting heartbeat timeout error. 868 if err == nil { 869 t.Fatalf("expected error but received none") 870 } 871 if !strings.Contains(err.Error(), "i/o timeout") { 872 t.Errorf("expected heartbeat timeout error, got (%s)", err) 873 } 874 // Validate remote command v5 protocol was negotiated. 875 streamExec := exec.(*wsStreamExecutor) 876 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 877 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 878 } 879 } 880 } 881 882 // TestWebSocketClient_TextMessageTypeError tests when the wrong message type is returned 883 // from the other websocket endpoint. Remote command protocols use "BinaryMessage", but 884 // this test hard-codes returning a "TextMessage". 885 func TestWebSocketClient_TextMessageTypeError(t *testing.T) { 886 var upgrader = gwebsocket.Upgrader{ 887 CheckOrigin: func(r *http.Request) bool { 888 return true // Accepting all requests 889 }, 890 Subprotocols: []string{remotecommand.StreamProtocolV5Name}, 891 } 892 // Upgrade a raw websocket server connection. Returns wrong message type "TextMessage". 893 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 894 conn, err := upgrader.Upgrade(w, req, nil) 895 if err != nil { 896 t.Fatalf("unable to upgrade to create websocket connection: %v", err) 897 } 898 defer conn.Close() 899 msg := []byte("test message with wrong message type.") 900 stdOutMsg := append([]byte{remotecommand.StreamStdOut}, msg...) 901 // Wrong message type "TextMessage". 902 err = conn.WriteMessage(gwebsocket.TextMessage, stdOutMsg) 903 if err != nil { 904 t.Fatalf("error writing text message to websocket: %v", err) 905 } 906 907 })) 908 defer websocketServer.Close() 909 910 // Set up the websocket client with the STDOUT stream. 911 websocketServer.URL = websocketServer.URL + "?" + "stdout=true" 912 websocketLocation, err := url.Parse(websocketServer.URL) 913 if err != nil { 914 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 915 } 916 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 917 if err != nil { 918 t.Errorf("unexpected error creating websocket executor: %v", err) 919 } 920 var stdout bytes.Buffer 921 options := &StreamOptions{ 922 Stdout: &stdout, 923 } 924 errorChan := make(chan error) 925 go func() { 926 // Start the streaming on the WebSocket "exec" client. 927 errorChan <- exec.StreamWithContext(context.Background(), *options) 928 }() 929 930 select { 931 case <-time.After(wait.ForeverTestTimeout): 932 t.Fatalf("expect stream to be closed after connection is closed.") 933 case err := <-errorChan: 934 // Expecting bad message type error. 935 if err == nil { 936 t.Fatalf("expected error but received none") 937 } 938 if !strings.Contains(err.Error(), "unexpected message type") { 939 t.Errorf("expected bad message type error, got (%s)", err) 940 } 941 // Validate remote command v5 protocol was negotiated. 942 streamExec := exec.(*wsStreamExecutor) 943 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 944 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 945 } 946 } 947 } 948 949 // TestWebSocketClient_EmptyMessageHandled tests that the error of a completely empty message 950 // is handled correctly. If the message is completely empty, the initial read of the stream id 951 // should fail (followed by cleanup). 952 func TestWebSocketClient_EmptyMessageHandled(t *testing.T) { 953 var upgrader = gwebsocket.Upgrader{ 954 CheckOrigin: func(r *http.Request) bool { 955 return true // Accepting all requests 956 }, 957 Subprotocols: []string{remotecommand.StreamProtocolV5Name}, 958 } 959 // Upgrade a raw websocket server connection. Returns wrong message type "TextMessage". 960 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 961 conn, err := upgrader.Upgrade(w, req, nil) 962 if err != nil { 963 t.Fatalf("unable to upgrade to create websocket connection: %v", err) 964 } 965 defer conn.Close() 966 // Send completely empty message, including missing initial stream id. 967 conn.WriteMessage(gwebsocket.BinaryMessage, []byte{}) //nolint:errcheck 968 })) 969 defer websocketServer.Close() 970 971 // Set up the websocket client with the STDOUT stream. 972 websocketServer.URL = websocketServer.URL + "?" + "stdout=true" 973 websocketLocation, err := url.Parse(websocketServer.URL) 974 if err != nil { 975 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) 976 } 977 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) 978 if err != nil { 979 t.Errorf("unexpected error creating websocket executor: %v", err) 980 } 981 var stdout bytes.Buffer 982 options := &StreamOptions{ 983 Stdout: &stdout, 984 } 985 errorChan := make(chan error) 986 go func() { 987 // Start the streaming on the WebSocket "exec" client. 988 errorChan <- exec.StreamWithContext(context.Background(), *options) 989 }() 990 991 select { 992 case <-time.After(wait.ForeverTestTimeout): 993 t.Fatalf("expect stream to be closed after connection is closed.") 994 case err := <-errorChan: 995 // Expecting error reading initial stream id. 996 if err == nil { 997 t.Fatalf("expected error but received none") 998 } 999 if !strings.Contains(err.Error(), "read stream id") { 1000 t.Errorf("expected error reading stream id, got (%s)", err) 1001 } 1002 // Validate remote command v5 protocol was negotiated. 1003 streamExec := exec.(*wsStreamExecutor) 1004 if remotecommand.StreamProtocolV5Name != streamExec.negotiated { 1005 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) 1006 } 1007 } 1008 } 1009 1010 func TestWebSocketClient_ExecutorErrors(t *testing.T) { 1011 // Invalid config causes transport creation error in websocket executor constructor. 1012 config := rest.Config{ 1013 ExecProvider: &clientcmdapi.ExecConfig{}, 1014 AuthProvider: &clientcmdapi.AuthProviderConfig{}, 1015 } 1016 _, err := NewWebSocketExecutor(&config, "GET", "http://localhost") 1017 if err == nil { 1018 t.Errorf("expecting executor constructor error, but received none.") 1019 } else if !strings.Contains(err.Error(), "error creating websocket transports") { 1020 t.Errorf("expecting error creating transports, got (%s)", err.Error()) 1021 } 1022 // Verify that a nil context will cause an error in StreamWithContext 1023 exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost") 1024 if err != nil { 1025 t.Errorf("unexpected error creating websocket executor: %v", err) 1026 } 1027 errorChan := make(chan error) 1028 go func() { 1029 // Start the streaming on the WebSocket "exec" client. 1030 var ctx context.Context 1031 errorChan <- exec.StreamWithContext(ctx, StreamOptions{}) 1032 }() 1033 1034 select { 1035 case <-time.After(wait.ForeverTestTimeout): 1036 t.Fatalf("expect stream to be closed after connection is closed.") 1037 case err := <-errorChan: 1038 // Expecting error with nil context. 1039 if err == nil { 1040 t.Fatalf("expected error but received none") 1041 } 1042 if !strings.Contains(err.Error(), "nil Context") { 1043 t.Errorf("expected nil context error, got (%s)", err) 1044 } 1045 } 1046 } 1047 1048 func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { 1049 var upgrader = gwebsocket.Upgrader{ 1050 CheckOrigin: func(r *http.Request) bool { 1051 return true // Accepting all requests 1052 }, 1053 } 1054 // Upgrade a raw websocket server connection, which automatically responds to Ping. 1055 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 1056 conn, err := upgrader.Upgrade(w, req, nil) 1057 if err != nil { 1058 t.Fatalf("unable to upgrade to create websocket connection: %v", err) 1059 } 1060 defer conn.Close() 1061 for { 1062 _, _, err := conn.ReadMessage() 1063 if err != nil { 1064 break 1065 } 1066 } 1067 })) 1068 defer websocketServer.Close() 1069 // Create a raw websocket client, connecting to the websocket server. 1070 url := strings.ReplaceAll(websocketServer.URL, "http", "ws") 1071 client, _, err := gwebsocket.DefaultDialer.Dial(url, nil) 1072 if err != nil { 1073 t.Fatalf("dial: %v", err) 1074 } 1075 defer client.Close() 1076 // Create a heartbeat using the client websocket connection, and start it. 1077 // "period" is less than "deadline", so ping/pong heartbeat will succceed. 1078 var expectedMsg = "test heartbeat message" 1079 var period = 100 * time.Millisecond 1080 var deadline = 200 * time.Millisecond 1081 heartbeat := newHeartbeat(client, period, deadline) 1082 heartbeat.setMessage(expectedMsg) 1083 // Add a channel to the handler to retrieve the "pong" message. 1084 pongMsgCh := make(chan string) 1085 pongHandler := heartbeat.conn.PongHandler() 1086 heartbeat.conn.SetPongHandler(func(msg string) error { 1087 pongMsgCh <- msg 1088 return pongHandler(msg) 1089 }) 1090 go heartbeat.start() 1091 1092 var wg sync.WaitGroup 1093 wg.Add(1) 1094 go func() { 1095 defer wg.Done() 1096 for { 1097 _, _, err := client.ReadMessage() 1098 if err != nil { 1099 t.Logf("client err reading message: %v", err) 1100 return 1101 } 1102 } 1103 }() 1104 1105 select { 1106 case actualMsg := <-pongMsgCh: 1107 close(heartbeat.closer) 1108 // Validate the received pong message is the same as sent in ping. 1109 if expectedMsg != actualMsg { 1110 t.Errorf("expected received pong message (%s), got (%s)", expectedMsg, actualMsg) 1111 } 1112 case <-time.After(period * 4): 1113 // This case should not happen. 1114 close(heartbeat.closer) 1115 t.Errorf("unexpected heartbeat timeout") 1116 } 1117 wg.Wait() 1118 } 1119 1120 func TestLateStreamCreation(t *testing.T) { 1121 c := newWSStreamCreator(nil) 1122 c.closeAllStreamReaders(nil) 1123 if err := c.setStream(0, nil); err == nil { 1124 t.Fatal("expected error adding stream after closeAllStreamReaders") 1125 } 1126 } 1127 1128 func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { 1129 // Validate Stream functions. 1130 c := newWSStreamCreator(nil) 1131 headers := http.Header{} 1132 headers.Set(v1.StreamType, v1.StreamTypeStdin) 1133 s, err := c.CreateStream(headers) 1134 if err != nil { 1135 t.Errorf("unexpected stream creation error: %v", err) 1136 } 1137 expectedStreamID := uint32(remotecommand.StreamStdIn) 1138 actualStreamID := s.Identifier() 1139 if expectedStreamID != actualStreamID { 1140 t.Errorf("expecting stream id (%d), got (%d)", expectedStreamID, actualStreamID) 1141 } 1142 actualHeaders := s.Headers() 1143 if !reflect.DeepEqual(headers, actualHeaders) { 1144 t.Errorf("expecting stream headers (%v), got (%v)", headers, actualHeaders) 1145 } 1146 // Validate stream reset does not return error. 1147 err = s.Reset() 1148 if err != nil { 1149 t.Errorf("unexpected error in stream reset: %v", err) 1150 } 1151 // Validate close with nil connection is an error. 1152 err = s.Close() 1153 if err == nil { 1154 t.Errorf("expecting stream Close error, but received none") 1155 } 1156 if !strings.Contains(err.Error(), "Close() on already closed stream") { 1157 t.Errorf("expected stream close error, got (%s)", err) 1158 } 1159 // Validate write with nil connection is an error. 1160 n, err := s.Write([]byte("not written")) 1161 if n != 0 { 1162 t.Errorf("expected zero bytes written, wrote (%d) instead", n) 1163 } 1164 if err == nil { 1165 t.Errorf("expecting stream Write error, but received none") 1166 } 1167 if !strings.Contains(err.Error(), "write on closed stream") { 1168 t.Errorf("expected stream write error, got (%s)", err) 1169 } 1170 // Validate CreateStream errors -- unknown stream 1171 headers = http.Header{} 1172 headers.Set(v1.StreamType, "UNKNOWN") 1173 _, err = c.CreateStream(headers) 1174 if err == nil { 1175 t.Errorf("expecting CreateStream error, but received none") 1176 } else if !strings.Contains(err.Error(), "unknown stream type") { 1177 t.Errorf("expecting unknown stream type error, got (%s)", err.Error()) 1178 } 1179 // Validate CreateStream errors -- duplicate stream 1180 headers.Set(v1.StreamType, v1.StreamTypeError) 1181 c.streams[remotecommand.StreamErr] = &stream{} 1182 _, err = c.CreateStream(headers) 1183 if err == nil { 1184 t.Errorf("expecting CreateStream error, but received none") 1185 } else if !strings.Contains(err.Error(), "duplicate stream") { 1186 t.Errorf("expecting duplicate stream error, got (%s)", err.Error()) 1187 } 1188 } 1189 1190 // options contains details about which streams are required for 1191 // remote command execution. 1192 type options struct { 1193 stdin bool 1194 stdout bool 1195 stderr bool 1196 tty bool 1197 } 1198 1199 // Translates query params in request into options struct. 1200 func streamOptionsFromRequest(req *http.Request) *options { 1201 query := req.URL.Query() 1202 tty := query.Get("tty") == "true" 1203 stdin := query.Get("stdin") == "true" 1204 stdout := query.Get("stdout") == "true" 1205 stderr := query.Get("stderr") == "true" 1206 return &options{ 1207 stdin: stdin, 1208 stdout: stdout, 1209 stderr: stderr, 1210 tty: tty, 1211 } 1212 } 1213 1214 // websocketStreams contains the WebSocket connection and streams from a server. 1215 type websocketStreams struct { 1216 conn io.Closer 1217 stdinStream io.ReadCloser 1218 stdoutStream io.WriteCloser 1219 stderrStream io.WriteCloser 1220 writeStatus func(status *apierrors.StatusError) error 1221 resizeStream io.ReadCloser 1222 resizeChan chan TerminalSize 1223 tty bool 1224 } 1225 1226 // Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed 1227 // in the stream options. 1228 func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) { 1229 conn, err := createWebSocketStreams(req, w, opts) 1230 if err != nil { 1231 return nil, err 1232 } 1233 1234 if conn.resizeStream != nil { 1235 conn.resizeChan = make(chan TerminalSize) 1236 go handleResizeEvents(req.Context(), conn.resizeStream, conn.resizeChan) 1237 } 1238 1239 return conn, nil 1240 } 1241 1242 // Read terminal resize events off of passed stream and queue into passed channel. 1243 func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- TerminalSize) { 1244 defer close(channel) 1245 1246 decoder := json.NewDecoder(stream) 1247 for { 1248 size := TerminalSize{} 1249 if err := decoder.Decode(&size); err != nil { 1250 break 1251 } 1252 1253 select { 1254 case channel <- size: 1255 case <-ctx.Done(): 1256 // To avoid leaking this routine, exit if the http request finishes. This path 1257 // would generally be hit if starting the process fails and nothing is started to 1258 // ingest these resize events. 1259 return 1260 } 1261 } 1262 } 1263 1264 // createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2) 1265 // along with the approximate duplex value. It also creates the error (3) and resize (4) channels. 1266 func createChannels(opts *options) []wsstream.ChannelType { 1267 // open the requested channels, and always open the error channel 1268 channels := make([]wsstream.ChannelType, 5) 1269 channels[remotecommand.StreamStdIn] = readChannel(opts.stdin) 1270 channels[remotecommand.StreamStdOut] = writeChannel(opts.stdout) 1271 channels[remotecommand.StreamStdErr] = writeChannel(opts.stderr) 1272 channels[remotecommand.StreamErr] = wsstream.WriteChannel 1273 channels[remotecommand.StreamResize] = wsstream.ReadChannel 1274 return channels 1275 } 1276 1277 // readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel. 1278 func readChannel(real bool) wsstream.ChannelType { 1279 if real { 1280 return wsstream.ReadChannel 1281 } 1282 return wsstream.IgnoreChannel 1283 } 1284 1285 // writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel. 1286 func writeChannel(real bool) wsstream.ChannelType { 1287 if real { 1288 return wsstream.WriteChannel 1289 } 1290 return wsstream.IgnoreChannel 1291 } 1292 1293 // createWebSocketStreams returns a "channels" struct containing the websocket connection and 1294 // streams needed to perform an exec or an attach. 1295 func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) { 1296 channels := createChannels(opts) 1297 conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ 1298 remotecommand.StreamProtocolV5Name: { 1299 Binary: true, 1300 Channels: channels, 1301 }, 1302 }) 1303 conn.SetIdleTimeout(4 * time.Hour) 1304 // Opening the connection responds to WebSocket client, negotiating 1305 // the WebSocket upgrade connection and the subprotocol. 1306 _, streams, err := conn.Open(w, req) 1307 if err != nil { 1308 return nil, err 1309 } 1310 1311 // Send an empty message to the lowest writable channel to notify the client the connection is established 1312 //nolint:errcheck 1313 switch { 1314 case opts.stdout: 1315 streams[remotecommand.StreamStdOut].Write([]byte{}) 1316 case opts.stderr: 1317 streams[remotecommand.StreamStdErr].Write([]byte{}) 1318 default: 1319 streams[remotecommand.StreamErr].Write([]byte{}) 1320 } 1321 1322 wsStreams := &websocketStreams{ 1323 conn: conn, 1324 stdinStream: streams[remotecommand.StreamStdIn], 1325 stdoutStream: streams[remotecommand.StreamStdOut], 1326 stderrStream: streams[remotecommand.StreamStdErr], 1327 tty: opts.tty, 1328 resizeStream: streams[remotecommand.StreamResize], 1329 } 1330 1331 wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error { 1332 return func(status *apierrors.StatusError) error { 1333 bs, err := json.Marshal(status.Status()) 1334 if err != nil { 1335 return err 1336 } 1337 _, err = stream.Write(bs) 1338 return err 1339 } 1340 }(streams[remotecommand.StreamErr]) 1341 1342 return wsStreams, nil 1343 } 1344 1345 // See (https://github.com/kubernetes/kubernetes/issues/126134). 1346 func TestWebSocketClient_HTTPSProxyErrorExpected(t *testing.T) { 1347 urlStr := "http://127.0.0.1/never-used" + "?" + "stdin=true" + "&" + "stdout=true" 1348 websocketLocation, err := url.Parse(urlStr) 1349 if err != nil { 1350 t.Fatalf("Unable to parse WebSocket server URL: %s", urlStr) 1351 } 1352 // proxy url with https scheme will trigger websocket dialing error. 1353 httpsProxyFunc := func(req *http.Request) (*url.URL, error) { return url.Parse("https://127.0.0.1") } 1354 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host, Proxy: httpsProxyFunc}, "GET", urlStr) 1355 if err != nil { 1356 t.Errorf("unexpected error creating websocket executor: %v", err) 1357 } 1358 var stdout bytes.Buffer 1359 options := &StreamOptions{ 1360 Stdout: &stdout, 1361 } 1362 errorChan := make(chan error) 1363 go func() { 1364 // Start the streaming on the WebSocket "exec" client. 1365 errorChan <- exec.StreamWithContext(context.Background(), *options) 1366 }() 1367 1368 select { 1369 case <-time.After(wait.ForeverTestTimeout): 1370 t.Fatalf("expect stream to be closed after connection is closed.") 1371 case err := <-errorChan: 1372 if err == nil { 1373 t.Errorf("expected error but received none") 1374 } 1375 if !httpstream.IsHTTPSProxyError(err) { 1376 t.Errorf("expected https proxy error, got (%s)", err) 1377 } 1378 } 1379 }