k8s.io/client-go@v0.31.1/tools/remotecommand/websocket_test.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package remotecommand
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"encoding/json"
    24  	"fmt"
    25  	"io"
    26  	"math"
    27  	mrand "math/rand"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"net/url"
    31  	"reflect"
    32  	"strings"
    33  	"sync"
    34  	"testing"
    35  	"time"
    36  
    37  	gwebsocket "github.com/gorilla/websocket"
    38  
    39  	v1 "k8s.io/api/core/v1"
    40  	apierrors "k8s.io/apimachinery/pkg/api/errors"
    41  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    42  	"k8s.io/apimachinery/pkg/util/httpstream"
    43  	"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
    44  	"k8s.io/apimachinery/pkg/util/remotecommand"
    45  	"k8s.io/apimachinery/pkg/util/wait"
    46  	"k8s.io/client-go/rest"
    47  	clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
    48  )
    49  
    50  // TestWebSocketClient_LoopbackStdinToStdout returns random data sent on the STDIN channel
    51  // back down the STDOUT channel. A subsequent comparison checks if the data
    52  // sent on the STDIN channel is the same as the data returned on the STDOUT
    53  // channel. This test can be run many times by the "stress" tool to check
    54  // if there is any data which would cause problems with the WebSocket streams.
    55  func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) {
    56  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
    57  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    58  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
    59  		if err != nil {
    60  			t.Fatalf("error on webSocketServerStreams: %v", err)
    61  		}
    62  		defer conns.conn.Close()
    63  		// Loopback the STDIN stream onto the STDOUT stream.
    64  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
    65  		if err != nil {
    66  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
    67  		}
    68  	}))
    69  	defer websocketServer.Close()
    70  
    71  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
    72  	// Must add STDIN and STDOUT query params for the WebSocket client request.
    73  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
    74  	websocketLocation, err := url.Parse(websocketServer.URL)
    75  	if err != nil {
    76  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
    77  	}
    78  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
    79  	if err != nil {
    80  		t.Errorf("unexpected error creating websocket executor: %v", err)
    81  	}
    82  	// Generate random data, and set it up to stream on STDIN. The data will be
    83  	// returned on the STDOUT buffer.
    84  	randomSize := 1024 * 1024
    85  	randomData := make([]byte, randomSize)
    86  	if _, err := rand.Read(randomData); err != nil {
    87  		t.Errorf("unexpected error reading random data: %v", err)
    88  	}
    89  	var stdout bytes.Buffer
    90  	options := &StreamOptions{
    91  		Stdin:  bytes.NewReader(randomData),
    92  		Stdout: &stdout,
    93  	}
    94  	errorChan := make(chan error)
    95  	go func() {
    96  		// Start the streaming on the WebSocket "exec" client.
    97  		errorChan <- exec.StreamWithContext(context.Background(), *options)
    98  	}()
    99  
   100  	select {
   101  	case <-time.After(wait.ForeverTestTimeout):
   102  		t.Fatalf("expect stream to be closed after connection is closed.")
   103  	case err := <-errorChan:
   104  		if err != nil {
   105  			t.Errorf("unexpected error")
   106  		}
   107  		// Validate remote command v5 protocol was negotiated.
   108  		streamExec := exec.(*wsStreamExecutor)
   109  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   110  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   111  		}
   112  	}
   113  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   114  	if err != nil {
   115  		t.Fatalf("error reading the stream: %v", err)
   116  	}
   117  	// Check the random data sent on STDIN was the same returned on STDOUT.
   118  	if !bytes.Equal(randomData, data) {
   119  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   120  	}
   121  }
   122  
   123  // TestWebSocketClient_DifferentBufferSizes runs the previous loopback (STDIN -> STDOUT) test with different
   124  // buffer sizes for reading from the opposite end of the websocket connection (in the websocket server).
   125  func TestWebSocketClient_DifferentBufferSizes(t *testing.T) {
   126  	// 1k, 4k, 64k, and 128k buffer sizes for reading STDIN at websocket server endpoint.
   127  	// The standard buffer size for io.Copy is 32k.
   128  	bufferSizes := []int{1 * 1024, 4 * 1024, 64 * 1024, 128 * 1024}
   129  	for _, bufferSize := range bufferSizes {
   130  		// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
   131  		websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   132  			conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   133  			if err != nil {
   134  				t.Fatalf("error on webSocketServerStreams: %v", err)
   135  			}
   136  			defer conns.conn.Close()
   137  			// Loopback the STDIN stream onto the STDOUT stream, using buffer with size.
   138  			buffer := make([]byte, bufferSize)
   139  			_, err = io.CopyBuffer(conns.stdoutStream, conns.stdinStream, buffer)
   140  			if err != nil {
   141  				t.Fatalf("error copying STDIN to STDOUT: %v", err)
   142  			}
   143  		}))
   144  		defer websocketServer.Close()
   145  
   146  		// Now create the WebSocket client (executor), and point it to the "websocketServer".
   147  		// Must add STDIN and STDOUT query params for the WebSocket client request.
   148  		websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
   149  		websocketLocation, err := url.Parse(websocketServer.URL)
   150  		if err != nil {
   151  			t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   152  		}
   153  		exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   154  		if err != nil {
   155  			t.Errorf("unexpected error creating websocket executor: %v", err)
   156  		}
   157  		// Generate random data, and set it up to stream on STDIN. The data will be
   158  		// returned on the STDOUT buffer.
   159  		randomSize := 1024 * 1024
   160  		randomData := make([]byte, randomSize)
   161  		if _, err := rand.Read(randomData); err != nil {
   162  			t.Errorf("unexpected error reading random data: %v", err)
   163  		}
   164  		var stdout bytes.Buffer
   165  		options := &StreamOptions{
   166  			Stdin:  bytes.NewReader(randomData),
   167  			Stdout: &stdout,
   168  		}
   169  		errorChan := make(chan error)
   170  		go func() {
   171  			// Start the streaming on the WebSocket "exec" client.
   172  			errorChan <- exec.StreamWithContext(context.Background(), *options)
   173  		}()
   174  
   175  		select {
   176  		case <-time.After(wait.ForeverTestTimeout):
   177  			t.Fatalf("expect stream to be closed after connection is closed.")
   178  		case err := <-errorChan:
   179  			if err != nil {
   180  				t.Errorf("unexpected error")
   181  			}
   182  			// Validate remote command v5 protocol was negotiated.
   183  			streamExec := exec.(*wsStreamExecutor)
   184  			if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   185  				t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   186  			}
   187  		}
   188  		data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   189  		if err != nil {
   190  			t.Errorf("error reading the stream: %v", err)
   191  			return
   192  		}
   193  		// Check all the random data sent on STDIN was the same returned on STDOUT.
   194  		if !bytes.Equal(randomData, data) {
   195  			t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   196  		}
   197  	}
   198  }
   199  
   200  // TestWebSocketClient_LoopbackStdinAsPipe uses a pipe to send random data on the STDIN
   201  // channel, then closes the pipe. The fake server simply returns all STDIN data back
   202  // onto the STDOUT channel, and the received data on STDOUT is validated against the
   203  // random data initially sent.
   204  func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) {
   205  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
   206  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   207  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   208  		if err != nil {
   209  			t.Fatalf("error on webSocketServerStreams: %v", err)
   210  		}
   211  		defer conns.conn.Close()
   212  		// Loopback the STDIN stream onto the STDOUT stream.
   213  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
   214  		if err != nil {
   215  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
   216  		}
   217  	}))
   218  	defer websocketServer.Close()
   219  
   220  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   221  	// Must add STDIN and STDOUT query params for the WebSocket client request.
   222  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
   223  	websocketLocation, err := url.Parse(websocketServer.URL)
   224  	if err != nil {
   225  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   226  	}
   227  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   228  	if err != nil {
   229  		t.Errorf("unexpected error creating websocket executor: %v", err)
   230  	}
   231  	// Generate random data, and it will be written on the STDIN pipe. The same
   232  	// data will be returned on the STDOUT channel.
   233  	randomSize := 1024 * 1024
   234  	randomData := make([]byte, randomSize)
   235  	if _, err := rand.Read(randomData); err != nil {
   236  		t.Errorf("unexpected error reading random data: %v", err)
   237  	}
   238  	reader, writer := io.Pipe()
   239  	var stdout bytes.Buffer
   240  	options := &StreamOptions{
   241  		Stdin:  reader,
   242  		Stdout: &stdout,
   243  	}
   244  	errorChan := make(chan error)
   245  	go func() {
   246  		// Start the streaming on the WebSocket "exec" client.
   247  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   248  	}()
   249  	// Write the random data onto the pipe connected to STDIN, then close the pipe.
   250  	_, err = writer.Write(randomData)
   251  	if err != nil {
   252  		t.Fatalf("unable to write random data to STDIN pipe: %v", err)
   253  	}
   254  	writer.Close()
   255  
   256  	select {
   257  	case <-time.After(wait.ForeverTestTimeout):
   258  		t.Fatalf("expect stream to be closed after connection is closed.")
   259  	case err := <-errorChan:
   260  		if err != nil {
   261  			t.Errorf("unexpected error")
   262  		}
   263  		// Validate remote command v5 protocol was negotiated.
   264  		streamExec := exec.(*wsStreamExecutor)
   265  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   266  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   267  		}
   268  	}
   269  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   270  	if err != nil {
   271  		t.Errorf("error reading the stream: %v", err)
   272  		return
   273  	}
   274  	// Check the random data sent on STDIN was the same returned on STDOUT.
   275  	if !bytes.Equal(randomData, data) {
   276  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   277  	}
   278  }
   279  
   280  // TestWebSocketClient_LoopbackStdinToStderr returns random data sent on the STDIN channel
   281  // back down the STDERR channel. A subsequent comparison checks if the data
   282  // sent on the STDIN channel is the same as the data returned on the STDERR
   283  // channel. This test can be run many times by the "stress" tool to check
   284  // if there is any data which would cause problems with the WebSocket streams.
   285  func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) {
   286  	// Create fake WebSocket server. Copy received STDIN data back onto STDERR stream.
   287  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   288  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   289  		if err != nil {
   290  			t.Fatalf("error on webSocketServerStreams: %v", err)
   291  		}
   292  		defer conns.conn.Close()
   293  		// Loopback the STDIN stream onto the STDERR stream.
   294  		_, err = io.Copy(conns.stderrStream, conns.stdinStream)
   295  		if err != nil {
   296  			t.Fatalf("error copying STDIN to STDERR: %v", err)
   297  		}
   298  	}))
   299  	defer websocketServer.Close()
   300  
   301  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   302  	// Must add STDIN and STDERR query params for the WebSocket client request.
   303  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true"
   304  	websocketLocation, err := url.Parse(websocketServer.URL)
   305  	if err != nil {
   306  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   307  	}
   308  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   309  	if err != nil {
   310  		t.Errorf("unexpected error creating websocket executor: %v", err)
   311  	}
   312  	// Generate random data, and set it up to stream on STDIN. The data will be
   313  	// returned on the STDERR buffer.
   314  	randomSize := 1024 * 1024
   315  	randomData := make([]byte, randomSize)
   316  	if _, err := rand.Read(randomData); err != nil {
   317  		t.Errorf("unexpected error reading random data: %v", err)
   318  	}
   319  	var stderr bytes.Buffer
   320  	options := &StreamOptions{
   321  		Stdin:  bytes.NewReader(randomData),
   322  		Stderr: &stderr,
   323  	}
   324  	errorChan := make(chan error)
   325  	go func() {
   326  		// Start the streaming on the WebSocket "exec" client.
   327  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   328  	}()
   329  
   330  	select {
   331  	case <-time.After(wait.ForeverTestTimeout):
   332  		t.Fatalf("expect stream to be closed after connection is closed.")
   333  	case err := <-errorChan:
   334  		if err != nil {
   335  			t.Errorf("unexpected error")
   336  		}
   337  		// Validate remote command v5 protocol was negotiated.
   338  		streamExec := exec.(*wsStreamExecutor)
   339  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   340  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   341  		}
   342  	}
   343  	data, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
   344  	if err != nil {
   345  		t.Errorf("error reading the stream: %v", err)
   346  		return
   347  	}
   348  	// Check the random data sent on STDIN was the same returned on STDERR.
   349  	if !bytes.Equal(randomData, data) {
   350  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   351  	}
   352  }
   353  
   354  // TestWebSocketClient_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from
   355  // the websocket connection at the same time.
   356  func TestWebSocketClient_MultipleReadChannels(t *testing.T) {
   357  	// Create fake WebSocket server, which uses a TeeReader to copy the same data
   358  	// onto the STDOUT stream onto the STDERR stream as well.
   359  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   360  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   361  		if err != nil {
   362  			t.Fatalf("error on webSocketServerStreams: %v", err)
   363  		}
   364  		defer conns.conn.Close()
   365  		// TeeReader copies data read on STDIN onto STDERR.
   366  		stdinReader := io.TeeReader(conns.stdinStream, conns.stderrStream)
   367  		// Also copy STDIN to STDOUT.
   368  		_, err = io.Copy(conns.stdoutStream, stdinReader)
   369  		if err != nil {
   370  			t.Errorf("error copying STDIN to STDOUT: %v", err)
   371  		}
   372  	}))
   373  	defer websocketServer.Close()
   374  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   375  	// Must add stdin, stdout, and stderr query param for the WebSocket client request.
   376  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + "&" + "stderr=true"
   377  	websocketLocation, err := url.Parse(websocketServer.URL)
   378  	if err != nil {
   379  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   380  	}
   381  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   382  	if err != nil {
   383  		t.Errorf("unexpected error creating websocket executor: %v", err)
   384  	}
   385  	// Generate 1MB of random data, and set it up to stream on STDIN. The data will be
   386  	// returned on the STDOUT and STDERR buffers.
   387  	randomSize := 1024 * 1024
   388  	randomData := make([]byte, randomSize)
   389  	if _, err := rand.Read(randomData); err != nil {
   390  		t.Errorf("unexpected error reading random data: %v", err)
   391  	}
   392  	var stdout, stderr bytes.Buffer
   393  	options := &StreamOptions{
   394  		Stdin:  bytes.NewReader(randomData),
   395  		Stdout: &stdout,
   396  		Stderr: &stderr,
   397  	}
   398  	errorChan := make(chan error)
   399  	go func() {
   400  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   401  	}()
   402  
   403  	select {
   404  	case <-time.After(wait.ForeverTestTimeout):
   405  		t.Fatalf("expect stream to be closed after connection is closed.")
   406  	case err := <-errorChan:
   407  		if err != nil {
   408  			t.Errorf("unexpected error: %v", err)
   409  		}
   410  		// Validate remote command v5 protocol was negotiated.
   411  		streamExec := exec.(*wsStreamExecutor)
   412  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   413  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   414  		}
   415  	}
   416  	// Validate the data read from the STDOUT stream is the same as sent on the STDIN stream.
   417  	stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   418  	if err != nil {
   419  		t.Fatalf("error reading the stream: %v", err)
   420  	}
   421  	if !bytes.Equal(stdoutBytes, randomData) {
   422  		t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(randomData))
   423  	}
   424  	// Validate the data read from the STDERR stream is the same as sent on the STDIN stream.
   425  	stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
   426  	if err != nil {
   427  		t.Fatalf("error reading the stream: %v", err)
   428  	}
   429  	if !bytes.Equal(stderrBytes, randomData) {
   430  		t.Errorf("unexpected data received (%d) sent (%d)", len(stderrBytes), len(randomData))
   431  	}
   432  }
   433  
   434  // Returns a random exit code in the range(1-127).
   435  func randomExitCode() int {
   436  	errorCode := mrand.Intn(128)
   437  	if errorCode == 0 {
   438  		errorCode = 1
   439  	}
   440  	return errorCode
   441  }
   442  
   443  // TestWebSocketClient_ErrorStream tests the websocket error stream by hard-coding a
   444  // structured non-zero exit code error from the websocket server to the websocket client.
   445  func TestWebSocketClient_ErrorStream(t *testing.T) {
   446  	expectedExitCode := randomExitCode()
   447  	// Create fake WebSocket server. Returns structured exit code error on error stream.
   448  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   449  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   450  		if err != nil {
   451  			t.Fatalf("error on webSocketServerStreams: %v", err)
   452  		}
   453  		defer conns.conn.Close()
   454  		_, err = io.Copy(conns.stderrStream, conns.stdinStream)
   455  		if err != nil {
   456  			t.Fatalf("error copying STDIN to STDERR: %v", err)
   457  		}
   458  		// Force an non-zero exit code error returned on the error stream.
   459  		err = conns.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
   460  			Status: metav1.StatusFailure,
   461  			Reason: remotecommand.NonZeroExitCodeReason,
   462  			Details: &metav1.StatusDetails{
   463  				Causes: []metav1.StatusCause{
   464  					{
   465  						Type:    remotecommand.ExitCodeCauseType,
   466  						Message: fmt.Sprintf("%d", expectedExitCode),
   467  					},
   468  				},
   469  			},
   470  		}})
   471  		if err != nil {
   472  			t.Fatalf("error writing status: %v", err)
   473  		}
   474  	}))
   475  	defer websocketServer.Close()
   476  
   477  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   478  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true"
   479  	websocketLocation, err := url.Parse(websocketServer.URL)
   480  	if err != nil {
   481  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   482  	}
   483  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   484  	if err != nil {
   485  		t.Errorf("unexpected error creating websocket executor: %v", err)
   486  	}
   487  	randomData := make([]byte, 256)
   488  	if _, err := rand.Read(randomData); err != nil {
   489  		t.Errorf("unexpected error reading random data: %v", err)
   490  	}
   491  	var stderr bytes.Buffer
   492  	options := &StreamOptions{
   493  		Stdin:  bytes.NewReader(randomData),
   494  		Stderr: &stderr,
   495  	}
   496  	errorChan := make(chan error)
   497  	go func() {
   498  		// Start the streaming on the WebSocket "exec" client.
   499  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   500  	}()
   501  
   502  	select {
   503  	case <-time.After(wait.ForeverTestTimeout):
   504  		t.Fatalf("expect stream to be closed after connection is closed.")
   505  	case err := <-errorChan:
   506  		// Validate remote command v5 protocol was negotiated.
   507  		streamExec := exec.(*wsStreamExecutor)
   508  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   509  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   510  		}
   511  		// Expect exit code error on error stream.
   512  		if err == nil {
   513  			t.Errorf("expected error, but received none")
   514  		}
   515  		expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode)
   516  		// Compare expected error with exit code to actual error.
   517  		if expectedError != err.Error() {
   518  			t.Errorf("expected error (%s), got (%s)", expectedError, err)
   519  		}
   520  	}
   521  }
   522  
   523  // fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of
   524  // "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice.
   525  type fakeTerminalSizeQueue struct {
   526  	maxSizes      int
   527  	terminalSizes []TerminalSize
   528  }
   529  
   530  // newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing
   531  // "max" number of random TerminalSizes created.
   532  func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue {
   533  	return &fakeTerminalSizeQueue{
   534  		maxSizes:      max,
   535  		terminalSizes: make([]TerminalSize, 0, max),
   536  	}
   537  }
   538  
   539  // Next returns a pointer to the next random TerminalSize, or nil if we have
   540  // already returned "maxSizes" TerminalSizes already. Stores the randomly
   541  // created TerminalSize in "terminalSizes" field for later validation.
   542  func (f *fakeTerminalSizeQueue) Next() *TerminalSize {
   543  	if len(f.terminalSizes) >= f.maxSizes {
   544  		return nil
   545  	}
   546  	size := randomTerminalSize()
   547  	f.terminalSizes = append(f.terminalSizes, size)
   548  	return &size
   549  }
   550  
   551  // randomTerminalSize returns a TerminalSize with random values in the
   552  // range (0-65535) for the fields Width and Height.
   553  func randomTerminalSize() TerminalSize {
   554  	randWidth := uint16(mrand.Intn(int(math.Pow(2, 16))))
   555  	randHeight := uint16(mrand.Intn(int(math.Pow(2, 16))))
   556  	return TerminalSize{
   557  		Width:  randWidth,
   558  		Height: randHeight,
   559  	}
   560  }
   561  
   562  // randReader implements the ReadCloser interface, and it continuously
   563  // returns random data until it is closed. Stores number of random
   564  // bytes generated and returned.
   565  type randReader struct {
   566  	randBytes []byte
   567  	closed    bool
   568  	lock      sync.Mutex
   569  }
   570  
   571  // Read implements the Reader interface filling the passed buffer with
   572  // random data, returning the number of bytes filled and an error
   573  // if one occurs. Return 0 and EOF if the randReader has been closed.
   574  func (r *randReader) Read(b []byte) (int, error) {
   575  	r.lock.Lock()
   576  	defer r.lock.Unlock()
   577  	if r.closed {
   578  		return 0, io.EOF
   579  	}
   580  	n, err := rand.Read(b)
   581  	c := bytes.Clone(b)
   582  	r.randBytes = append(r.randBytes, c...)
   583  	return n, err
   584  }
   585  
   586  // Close implements the Closer interface, setting the close field true.
   587  // Further calls to Read() after Close() will return 0, EOF. Returns
   588  // nil error.
   589  func (r *randReader) Close() (err error) {
   590  	r.lock.Lock()
   591  	defer r.lock.Unlock()
   592  	r.closed = true
   593  	return nil
   594  }
   595  
   596  // TestWebSocketClient_MultipleWriteChannels tests two streams (STDIN, TTY resize) writing to the
   597  // websocket connection at the same time to exercise the connection write lock.
   598  func TestWebSocketClient_MultipleWriteChannels(t *testing.T) {
   599  	// Create the fake terminal size queue and the actualTerminalSizes which
   600  	// will be received at the opposite websocket endpoint.
   601  	numSizeQueue := 10000
   602  	sizeQueue := newTerminalSizeQueue(numSizeQueue)
   603  	actualTerminalSizes := make([]TerminalSize, 0, numSizeQueue)
   604  	// Create ReadCloser sending random data on STDIN stream over websocket connection.
   605  	stdinReader := randReader{randBytes: []byte{}, closed: false}
   606  	// Create fake WebSocket server, which will receive concurrently the STDIN stream as
   607  	// well as the resize stream (TerminalSizes). Store the TerminalSize data from the resize
   608  	// stream for subsequent validation.
   609  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   610  		var wg sync.WaitGroup
   611  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   612  		if err != nil {
   613  			t.Fatalf("error on webSocketServerStreams: %v", err)
   614  		}
   615  		defer conns.conn.Close()
   616  		// Create goroutine to loopback the STDIN stream onto the STDOUT stream.
   617  		wg.Add(1)
   618  		go func() {
   619  			_, err := io.Copy(conns.stdoutStream, conns.stdinStream)
   620  			if err != nil {
   621  				t.Errorf("error copying STDIN to STDOUT: %v", err)
   622  			}
   623  			wg.Done()
   624  		}()
   625  		// Read the terminal resize requests, storing them in actualTerminalSizes
   626  		for i := 0; i < numSizeQueue; i++ {
   627  			actualTerminalSize := <-conns.resizeChan
   628  			actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize)
   629  		}
   630  		stdinReader.Close() // Stops the random STDIN stream generation
   631  		wg.Wait()           // Wait for all bytes copied from STDIN to STDOUT
   632  	}))
   633  	defer websocketServer.Close()
   634  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   635  	// Must add stdin, stdout, and TTY query param for the WebSocket client request.
   636  	websocketServer.URL = websocketServer.URL + "?" + "tty=true" + "&" + "stdin=true" + "&" + "stdout=true"
   637  	websocketLocation, err := url.Parse(websocketServer.URL)
   638  	if err != nil {
   639  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   640  	}
   641  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   642  	if err != nil {
   643  		t.Errorf("unexpected error creating websocket executor: %v", err)
   644  	}
   645  	var stdout bytes.Buffer
   646  	options := &StreamOptions{
   647  		Stdin:             &stdinReader,
   648  		Stdout:            &stdout,
   649  		Tty:               true,
   650  		TerminalSizeQueue: sizeQueue,
   651  	}
   652  	errorChan := make(chan error)
   653  	go func() {
   654  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   655  	}()
   656  
   657  	select {
   658  	case <-time.After(wait.ForeverTestTimeout):
   659  		t.Fatalf("expect stream to be closed after connection is closed.")
   660  	case err := <-errorChan:
   661  		if err != nil {
   662  			t.Errorf("unexpected error: %v", err)
   663  		}
   664  		// Validate remote command v5 protocol was negotiated.
   665  		streamExec := exec.(*wsStreamExecutor)
   666  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   667  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   668  		}
   669  	}
   670  	// Check the random data sent on STDIN was the same returned on STDOUT *and*
   671  	// that a minimum amount of random data was sent and received, ensuring concurrency.
   672  	stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   673  	if err != nil {
   674  		t.Fatalf("error reading the stream: %v", err)
   675  	}
   676  	if len(stdoutBytes) == 0 {
   677  		t.Errorf("No STDOUT bytes processed before resize stream finished: %d", len(stdoutBytes))
   678  	}
   679  	if !bytes.Equal(stdoutBytes, stdinReader.randBytes) {
   680  		t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(stdinReader.randBytes))
   681  	}
   682  	// Validate the random TerminalSizes sent on the resize stream are the same
   683  	// as the actual TerminalSizes received at the websocket server.
   684  	if len(actualTerminalSizes) != numSizeQueue {
   685  		t.Errorf("expected received terminal size window (%d), got (%d)",
   686  			numSizeQueue, len(actualTerminalSizes))
   687  	}
   688  	for i, actual := range actualTerminalSizes {
   689  		expected := sizeQueue.terminalSizes[i]
   690  		if !reflect.DeepEqual(expected, actual) {
   691  			t.Errorf("expected terminal resize window %v, got %v", expected, actual)
   692  		}
   693  	}
   694  }
   695  
   696  // TestWebSocketClient_ProtocolVersions validates that remote command subprotocol versions V2-V4
   697  // (V5 is already tested elsewhere) can be negotiated.
   698  func TestWebSocketClient_ProtocolVersions(t *testing.T) {
   699  	// Create a raw websocket server that accepts V2-V4 versions of
   700  	// the remote command subprotocol.
   701  	var upgrader = gwebsocket.Upgrader{
   702  		CheckOrigin: func(r *http.Request) bool {
   703  			return true // Accepting all requests
   704  		},
   705  		Subprotocols: []string{
   706  			remotecommand.StreamProtocolV4Name,
   707  			remotecommand.StreamProtocolV3Name,
   708  			remotecommand.StreamProtocolV2Name,
   709  		},
   710  	}
   711  	// Upgrade a raw websocket server connection.
   712  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   713  		conn, err := upgrader.Upgrade(w, req, nil)
   714  		if err != nil {
   715  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
   716  		}
   717  		defer conn.Close()
   718  	}))
   719  	defer websocketServer.Close()
   720  
   721  	// Set up the websocket client with the STDOUT stream.
   722  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   723  	websocketLocation, err := url.Parse(websocketServer.URL)
   724  	if err != nil {
   725  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   726  	}
   727  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   728  	if err != nil {
   729  		t.Errorf("unexpected error creating websocket executor: %v", err)
   730  	}
   731  	// Iterate through previous remote command protocol versions, validating the
   732  	// requested protocol version is the one that is negotiated.
   733  	versions := []string{
   734  		remotecommand.StreamProtocolV4Name,
   735  		remotecommand.StreamProtocolV3Name,
   736  		remotecommand.StreamProtocolV2Name,
   737  	}
   738  	for _, requestedVersion := range versions {
   739  		streamExec := exec.(*wsStreamExecutor)
   740  		streamExec.protocols = []string{requestedVersion}
   741  		var stdout bytes.Buffer
   742  		options := &StreamOptions{
   743  			Stdout: &stdout,
   744  		}
   745  		errorChan := make(chan error)
   746  		go func() {
   747  			// Start the streaming on the WebSocket "exec" client.
   748  			errorChan <- exec.StreamWithContext(context.Background(), *options)
   749  		}()
   750  
   751  		select {
   752  		case <-time.After(wait.ForeverTestTimeout):
   753  			t.Fatalf("expect stream to be closed after connection is closed.")
   754  		case <-errorChan:
   755  			// Validate remote command protocol requestedVersion was negotiated.
   756  			streamExec := exec.(*wsStreamExecutor)
   757  			if requestedVersion != streamExec.negotiated {
   758  				t.Fatalf("expected protocol version (%s), got (%s)", requestedVersion, streamExec.negotiated)
   759  			}
   760  		}
   761  	}
   762  }
   763  
   764  // TestWebSocketClient_BadHandshake tests that a "bad handshake" error occurs when
   765  // the WebSocketExecutor attempts to upgrade the connection to a subprotocol version
   766  // (V4) that is not supported by the websocket server (only supports V5).
   767  func TestWebSocketClient_BadHandshake(t *testing.T) {
   768  	// Create fake WebSocket server (supports V5 subprotocol).
   769  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   770  		// Bad handshake means websocket server will not completely initialize.
   771  		_, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   772  		if err == nil {
   773  			t.Fatalf("expected error, but received none.")
   774  		}
   775  		if !strings.Contains(err.Error(), "websocket server finished before becoming ready") {
   776  			t.Errorf("expected websocket server error, but got: %v", err)
   777  		}
   778  	}))
   779  	defer websocketServer.Close()
   780  
   781  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   782  	websocketLocation, err := url.Parse(websocketServer.URL)
   783  	if err != nil {
   784  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   785  	}
   786  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   787  	if err != nil {
   788  		t.Errorf("unexpected error creating websocket executor: %v", err)
   789  	}
   790  	streamExec := exec.(*wsStreamExecutor)
   791  	// Set the attempted subprotocol version to V4; websocket server only accepts V5.
   792  	streamExec.protocols = []string{remotecommand.StreamProtocolV4Name}
   793  
   794  	var stdout bytes.Buffer
   795  	options := &StreamOptions{
   796  		Stdout: &stdout,
   797  	}
   798  	errorChan := make(chan error)
   799  	go func() {
   800  		// Start the streaming on the WebSocket "exec" client.
   801  		errorChan <- streamExec.StreamWithContext(context.Background(), *options)
   802  	}()
   803  
   804  	select {
   805  	case <-time.After(wait.ForeverTestTimeout):
   806  		t.Fatalf("expect stream to be closed after connection is closed.")
   807  	case err := <-errorChan:
   808  		// Expecting unable to upgrade connection -- "bad handshake" error.
   809  		if err == nil {
   810  			t.Errorf("expected error but received none")
   811  		}
   812  		if !strings.Contains(err.Error(), "bad handshake") {
   813  			t.Errorf("expected bad handshake error, got (%s)", err)
   814  		}
   815  	}
   816  }
   817  
   818  // TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a
   819  // timeout by setting the ping period greater than the deadline.
   820  func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
   821  	blockRequestCtx, unblockRequest := context.WithCancel(context.Background())
   822  	defer unblockRequest()
   823  	// Create fake WebSocket server which blocks.
   824  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   825  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   826  		if err != nil {
   827  			t.Fatalf("error on webSocketServerStreams: %v", err)
   828  		}
   829  		defer conns.conn.Close()
   830  		<-blockRequestCtx.Done()
   831  	}))
   832  	defer websocketServer.Close()
   833  	// Create websocket client connecting to fake server.
   834  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true"
   835  	websocketLocation, err := url.Parse(websocketServer.URL)
   836  	if err != nil {
   837  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   838  	}
   839  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   840  	if err != nil {
   841  		t.Errorf("unexpected error creating websocket executor: %v", err)
   842  	}
   843  	streamExec := exec.(*wsStreamExecutor)
   844  	// Ping period is greater than the ping deadline, forcing the timeout to fire.
   845  	pingPeriod := wait.ForeverTestTimeout // this lets the heartbeat deadline expire without renewing it
   846  	pingDeadline := time.Second           // this gives setup 1 second to establish streams
   847  	streamExec.heartbeatPeriod = pingPeriod
   848  	streamExec.heartbeatDeadline = pingDeadline
   849  	// Send some random data to the websocket server through STDIN.
   850  	randomData := make([]byte, 128)
   851  	if _, err := rand.Read(randomData); err != nil {
   852  		t.Errorf("unexpected error reading random data: %v", err)
   853  	}
   854  	options := &StreamOptions{
   855  		Stdin: bytes.NewReader(randomData),
   856  	}
   857  	errorChan := make(chan error)
   858  	go func() {
   859  		// Start the streaming on the WebSocket "exec" client.
   860  		errorChan <- streamExec.StreamWithContext(context.Background(), *options)
   861  	}()
   862  
   863  	select {
   864  	case <-time.After(wait.ForeverTestTimeout):
   865  		t.Fatalf("expected heartbeat timeout, got none.")
   866  	case err := <-errorChan:
   867  		// Expecting heartbeat timeout error.
   868  		if err == nil {
   869  			t.Fatalf("expected error but received none")
   870  		}
   871  		if !strings.Contains(err.Error(), "i/o timeout") {
   872  			t.Errorf("expected heartbeat timeout error, got (%s)", err)
   873  		}
   874  		// Validate remote command v5 protocol was negotiated.
   875  		streamExec := exec.(*wsStreamExecutor)
   876  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   877  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   878  		}
   879  	}
   880  }
   881  
   882  // TestWebSocketClient_TextMessageTypeError tests when the wrong message type is returned
   883  // from the other websocket endpoint. Remote command protocols use "BinaryMessage", but
   884  // this test hard-codes returning a "TextMessage".
   885  func TestWebSocketClient_TextMessageTypeError(t *testing.T) {
   886  	var upgrader = gwebsocket.Upgrader{
   887  		CheckOrigin: func(r *http.Request) bool {
   888  			return true // Accepting all requests
   889  		},
   890  		Subprotocols: []string{remotecommand.StreamProtocolV5Name},
   891  	}
   892  	// Upgrade a raw websocket server connection. Returns wrong message type "TextMessage".
   893  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   894  		conn, err := upgrader.Upgrade(w, req, nil)
   895  		if err != nil {
   896  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
   897  		}
   898  		defer conn.Close()
   899  		msg := []byte("test message with wrong message type.")
   900  		stdOutMsg := append([]byte{remotecommand.StreamStdOut}, msg...)
   901  		// Wrong message type "TextMessage".
   902  		err = conn.WriteMessage(gwebsocket.TextMessage, stdOutMsg)
   903  		if err != nil {
   904  			t.Fatalf("error writing text message to websocket: %v", err)
   905  		}
   906  
   907  	}))
   908  	defer websocketServer.Close()
   909  
   910  	// Set up the websocket client with the STDOUT stream.
   911  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   912  	websocketLocation, err := url.Parse(websocketServer.URL)
   913  	if err != nil {
   914  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   915  	}
   916  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   917  	if err != nil {
   918  		t.Errorf("unexpected error creating websocket executor: %v", err)
   919  	}
   920  	var stdout bytes.Buffer
   921  	options := &StreamOptions{
   922  		Stdout: &stdout,
   923  	}
   924  	errorChan := make(chan error)
   925  	go func() {
   926  		// Start the streaming on the WebSocket "exec" client.
   927  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   928  	}()
   929  
   930  	select {
   931  	case <-time.After(wait.ForeverTestTimeout):
   932  		t.Fatalf("expect stream to be closed after connection is closed.")
   933  	case err := <-errorChan:
   934  		// Expecting bad message type error.
   935  		if err == nil {
   936  			t.Fatalf("expected error but received none")
   937  		}
   938  		if !strings.Contains(err.Error(), "unexpected message type") {
   939  			t.Errorf("expected bad message type error, got (%s)", err)
   940  		}
   941  		// Validate remote command v5 protocol was negotiated.
   942  		streamExec := exec.(*wsStreamExecutor)
   943  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   944  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   945  		}
   946  	}
   947  }
   948  
   949  // TestWebSocketClient_EmptyMessageHandled tests that the error of a completely empty message
   950  // is handled correctly. If the message is completely empty, the initial read of the stream id
   951  // should fail (followed by cleanup).
   952  func TestWebSocketClient_EmptyMessageHandled(t *testing.T) {
   953  	var upgrader = gwebsocket.Upgrader{
   954  		CheckOrigin: func(r *http.Request) bool {
   955  			return true // Accepting all requests
   956  		},
   957  		Subprotocols: []string{remotecommand.StreamProtocolV5Name},
   958  	}
   959  	// Upgrade a raw websocket server connection. Returns wrong message type "TextMessage".
   960  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   961  		conn, err := upgrader.Upgrade(w, req, nil)
   962  		if err != nil {
   963  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
   964  		}
   965  		defer conn.Close()
   966  		// Send completely empty message, including missing initial stream id.
   967  		conn.WriteMessage(gwebsocket.BinaryMessage, []byte{}) //nolint:errcheck
   968  	}))
   969  	defer websocketServer.Close()
   970  
   971  	// Set up the websocket client with the STDOUT stream.
   972  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   973  	websocketLocation, err := url.Parse(websocketServer.URL)
   974  	if err != nil {
   975  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   976  	}
   977  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   978  	if err != nil {
   979  		t.Errorf("unexpected error creating websocket executor: %v", err)
   980  	}
   981  	var stdout bytes.Buffer
   982  	options := &StreamOptions{
   983  		Stdout: &stdout,
   984  	}
   985  	errorChan := make(chan error)
   986  	go func() {
   987  		// Start the streaming on the WebSocket "exec" client.
   988  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   989  	}()
   990  
   991  	select {
   992  	case <-time.After(wait.ForeverTestTimeout):
   993  		t.Fatalf("expect stream to be closed after connection is closed.")
   994  	case err := <-errorChan:
   995  		// Expecting error reading initial stream id.
   996  		if err == nil {
   997  			t.Fatalf("expected error but received none")
   998  		}
   999  		if !strings.Contains(err.Error(), "read stream id") {
  1000  			t.Errorf("expected error reading stream id, got (%s)", err)
  1001  		}
  1002  		// Validate remote command v5 protocol was negotiated.
  1003  		streamExec := exec.(*wsStreamExecutor)
  1004  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
  1005  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
  1006  		}
  1007  	}
  1008  }
  1009  
  1010  func TestWebSocketClient_ExecutorErrors(t *testing.T) {
  1011  	// Invalid config causes transport creation error in websocket executor constructor.
  1012  	config := rest.Config{
  1013  		ExecProvider: &clientcmdapi.ExecConfig{},
  1014  		AuthProvider: &clientcmdapi.AuthProviderConfig{},
  1015  	}
  1016  	_, err := NewWebSocketExecutor(&config, "GET", "http://localhost")
  1017  	if err == nil {
  1018  		t.Errorf("expecting executor constructor error, but received none.")
  1019  	} else if !strings.Contains(err.Error(), "error creating websocket transports") {
  1020  		t.Errorf("expecting error creating transports, got (%s)", err.Error())
  1021  	}
  1022  	// Verify that a nil context will cause an error in StreamWithContext
  1023  	exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost")
  1024  	if err != nil {
  1025  		t.Errorf("unexpected error creating websocket executor: %v", err)
  1026  	}
  1027  	errorChan := make(chan error)
  1028  	go func() {
  1029  		// Start the streaming on the WebSocket "exec" client.
  1030  		var ctx context.Context
  1031  		errorChan <- exec.StreamWithContext(ctx, StreamOptions{})
  1032  	}()
  1033  
  1034  	select {
  1035  	case <-time.After(wait.ForeverTestTimeout):
  1036  		t.Fatalf("expect stream to be closed after connection is closed.")
  1037  	case err := <-errorChan:
  1038  		// Expecting error with nil context.
  1039  		if err == nil {
  1040  			t.Fatalf("expected error but received none")
  1041  		}
  1042  		if !strings.Contains(err.Error(), "nil Context") {
  1043  			t.Errorf("expected nil context error, got (%s)", err)
  1044  		}
  1045  	}
  1046  }
  1047  
  1048  func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
  1049  	var upgrader = gwebsocket.Upgrader{
  1050  		CheckOrigin: func(r *http.Request) bool {
  1051  			return true // Accepting all requests
  1052  		},
  1053  	}
  1054  	// Upgrade a raw websocket server connection, which automatically responds to Ping.
  1055  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  1056  		conn, err := upgrader.Upgrade(w, req, nil)
  1057  		if err != nil {
  1058  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
  1059  		}
  1060  		defer conn.Close()
  1061  		for {
  1062  			_, _, err := conn.ReadMessage()
  1063  			if err != nil {
  1064  				break
  1065  			}
  1066  		}
  1067  	}))
  1068  	defer websocketServer.Close()
  1069  	// Create a raw websocket client, connecting to the websocket server.
  1070  	url := strings.ReplaceAll(websocketServer.URL, "http", "ws")
  1071  	client, _, err := gwebsocket.DefaultDialer.Dial(url, nil)
  1072  	if err != nil {
  1073  		t.Fatalf("dial: %v", err)
  1074  	}
  1075  	defer client.Close()
  1076  	// Create a heartbeat using the client websocket connection, and start it.
  1077  	// "period" is less than "deadline", so ping/pong heartbeat will succceed.
  1078  	var expectedMsg = "test heartbeat message"
  1079  	var period = 100 * time.Millisecond
  1080  	var deadline = 200 * time.Millisecond
  1081  	heartbeat := newHeartbeat(client, period, deadline)
  1082  	heartbeat.setMessage(expectedMsg)
  1083  	// Add a channel to the handler to retrieve the "pong" message.
  1084  	pongMsgCh := make(chan string)
  1085  	pongHandler := heartbeat.conn.PongHandler()
  1086  	heartbeat.conn.SetPongHandler(func(msg string) error {
  1087  		pongMsgCh <- msg
  1088  		return pongHandler(msg)
  1089  	})
  1090  	go heartbeat.start()
  1091  
  1092  	var wg sync.WaitGroup
  1093  	wg.Add(1)
  1094  	go func() {
  1095  		defer wg.Done()
  1096  		for {
  1097  			_, _, err := client.ReadMessage()
  1098  			if err != nil {
  1099  				t.Logf("client err reading message: %v", err)
  1100  				return
  1101  			}
  1102  		}
  1103  	}()
  1104  
  1105  	select {
  1106  	case actualMsg := <-pongMsgCh:
  1107  		close(heartbeat.closer)
  1108  		// Validate the received pong message is the same as sent in ping.
  1109  		if expectedMsg != actualMsg {
  1110  			t.Errorf("expected received pong message (%s), got (%s)", expectedMsg, actualMsg)
  1111  		}
  1112  	case <-time.After(period * 4):
  1113  		// This case should not happen.
  1114  		close(heartbeat.closer)
  1115  		t.Errorf("unexpected heartbeat timeout")
  1116  	}
  1117  	wg.Wait()
  1118  }
  1119  
  1120  func TestLateStreamCreation(t *testing.T) {
  1121  	c := newWSStreamCreator(nil)
  1122  	c.closeAllStreamReaders(nil)
  1123  	if err := c.setStream(0, nil); err == nil {
  1124  		t.Fatal("expected error adding stream after closeAllStreamReaders")
  1125  	}
  1126  }
  1127  
  1128  func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
  1129  	// Validate Stream functions.
  1130  	c := newWSStreamCreator(nil)
  1131  	headers := http.Header{}
  1132  	headers.Set(v1.StreamType, v1.StreamTypeStdin)
  1133  	s, err := c.CreateStream(headers)
  1134  	if err != nil {
  1135  		t.Errorf("unexpected stream creation error: %v", err)
  1136  	}
  1137  	expectedStreamID := uint32(remotecommand.StreamStdIn)
  1138  	actualStreamID := s.Identifier()
  1139  	if expectedStreamID != actualStreamID {
  1140  		t.Errorf("expecting stream id (%d), got (%d)", expectedStreamID, actualStreamID)
  1141  	}
  1142  	actualHeaders := s.Headers()
  1143  	if !reflect.DeepEqual(headers, actualHeaders) {
  1144  		t.Errorf("expecting stream headers (%v), got (%v)", headers, actualHeaders)
  1145  	}
  1146  	// Validate stream reset does not return error.
  1147  	err = s.Reset()
  1148  	if err != nil {
  1149  		t.Errorf("unexpected error in stream reset: %v", err)
  1150  	}
  1151  	// Validate close with nil connection is an error.
  1152  	err = s.Close()
  1153  	if err == nil {
  1154  		t.Errorf("expecting stream Close error, but received none")
  1155  	}
  1156  	if !strings.Contains(err.Error(), "Close() on already closed stream") {
  1157  		t.Errorf("expected stream close error, got (%s)", err)
  1158  	}
  1159  	// Validate write with nil connection is an error.
  1160  	n, err := s.Write([]byte("not written"))
  1161  	if n != 0 {
  1162  		t.Errorf("expected zero bytes written, wrote (%d) instead", n)
  1163  	}
  1164  	if err == nil {
  1165  		t.Errorf("expecting stream Write error, but received none")
  1166  	}
  1167  	if !strings.Contains(err.Error(), "write on closed stream") {
  1168  		t.Errorf("expected stream write error, got (%s)", err)
  1169  	}
  1170  	// Validate CreateStream errors -- unknown stream
  1171  	headers = http.Header{}
  1172  	headers.Set(v1.StreamType, "UNKNOWN")
  1173  	_, err = c.CreateStream(headers)
  1174  	if err == nil {
  1175  		t.Errorf("expecting CreateStream error, but received none")
  1176  	} else if !strings.Contains(err.Error(), "unknown stream type") {
  1177  		t.Errorf("expecting unknown stream type error, got (%s)", err.Error())
  1178  	}
  1179  	// Validate CreateStream errors -- duplicate stream
  1180  	headers.Set(v1.StreamType, v1.StreamTypeError)
  1181  	c.streams[remotecommand.StreamErr] = &stream{}
  1182  	_, err = c.CreateStream(headers)
  1183  	if err == nil {
  1184  		t.Errorf("expecting CreateStream error, but received none")
  1185  	} else if !strings.Contains(err.Error(), "duplicate stream") {
  1186  		t.Errorf("expecting duplicate stream error, got (%s)", err.Error())
  1187  	}
  1188  }
  1189  
  1190  // options contains details about which streams are required for
  1191  // remote command execution.
  1192  type options struct {
  1193  	stdin  bool
  1194  	stdout bool
  1195  	stderr bool
  1196  	tty    bool
  1197  }
  1198  
  1199  // Translates query params in request into options struct.
  1200  func streamOptionsFromRequest(req *http.Request) *options {
  1201  	query := req.URL.Query()
  1202  	tty := query.Get("tty") == "true"
  1203  	stdin := query.Get("stdin") == "true"
  1204  	stdout := query.Get("stdout") == "true"
  1205  	stderr := query.Get("stderr") == "true"
  1206  	return &options{
  1207  		stdin:  stdin,
  1208  		stdout: stdout,
  1209  		stderr: stderr,
  1210  		tty:    tty,
  1211  	}
  1212  }
  1213  
  1214  // websocketStreams contains the WebSocket connection and streams from a server.
  1215  type websocketStreams struct {
  1216  	conn         io.Closer
  1217  	stdinStream  io.ReadCloser
  1218  	stdoutStream io.WriteCloser
  1219  	stderrStream io.WriteCloser
  1220  	writeStatus  func(status *apierrors.StatusError) error
  1221  	resizeStream io.ReadCloser
  1222  	resizeChan   chan TerminalSize
  1223  	tty          bool
  1224  }
  1225  
  1226  // Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed
  1227  // in the stream options.
  1228  func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) {
  1229  	conn, err := createWebSocketStreams(req, w, opts)
  1230  	if err != nil {
  1231  		return nil, err
  1232  	}
  1233  
  1234  	if conn.resizeStream != nil {
  1235  		conn.resizeChan = make(chan TerminalSize)
  1236  		go handleResizeEvents(req.Context(), conn.resizeStream, conn.resizeChan)
  1237  	}
  1238  
  1239  	return conn, nil
  1240  }
  1241  
  1242  // Read terminal resize events off of passed stream and queue into passed channel.
  1243  func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- TerminalSize) {
  1244  	defer close(channel)
  1245  
  1246  	decoder := json.NewDecoder(stream)
  1247  	for {
  1248  		size := TerminalSize{}
  1249  		if err := decoder.Decode(&size); err != nil {
  1250  			break
  1251  		}
  1252  
  1253  		select {
  1254  		case channel <- size:
  1255  		case <-ctx.Done():
  1256  			// To avoid leaking this routine, exit if the http request finishes. This path
  1257  			// would generally be hit if starting the process fails and nothing is started to
  1258  			// ingest these resize events.
  1259  			return
  1260  		}
  1261  	}
  1262  }
  1263  
  1264  // createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
  1265  // along with the approximate duplex value. It also creates the error (3) and resize (4) channels.
  1266  func createChannels(opts *options) []wsstream.ChannelType {
  1267  	// open the requested channels, and always open the error channel
  1268  	channels := make([]wsstream.ChannelType, 5)
  1269  	channels[remotecommand.StreamStdIn] = readChannel(opts.stdin)
  1270  	channels[remotecommand.StreamStdOut] = writeChannel(opts.stdout)
  1271  	channels[remotecommand.StreamStdErr] = writeChannel(opts.stderr)
  1272  	channels[remotecommand.StreamErr] = wsstream.WriteChannel
  1273  	channels[remotecommand.StreamResize] = wsstream.ReadChannel
  1274  	return channels
  1275  }
  1276  
  1277  // readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel.
  1278  func readChannel(real bool) wsstream.ChannelType {
  1279  	if real {
  1280  		return wsstream.ReadChannel
  1281  	}
  1282  	return wsstream.IgnoreChannel
  1283  }
  1284  
  1285  // writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel.
  1286  func writeChannel(real bool) wsstream.ChannelType {
  1287  	if real {
  1288  		return wsstream.WriteChannel
  1289  	}
  1290  	return wsstream.IgnoreChannel
  1291  }
  1292  
  1293  // createWebSocketStreams returns a "channels" struct containing the websocket connection and
  1294  // streams needed to perform an exec or an attach.
  1295  func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) {
  1296  	channels := createChannels(opts)
  1297  	conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
  1298  		remotecommand.StreamProtocolV5Name: {
  1299  			Binary:   true,
  1300  			Channels: channels,
  1301  		},
  1302  	})
  1303  	conn.SetIdleTimeout(4 * time.Hour)
  1304  	// Opening the connection responds to WebSocket client, negotiating
  1305  	// the WebSocket upgrade connection and the subprotocol.
  1306  	_, streams, err := conn.Open(w, req)
  1307  	if err != nil {
  1308  		return nil, err
  1309  	}
  1310  
  1311  	// Send an empty message to the lowest writable channel to notify the client the connection is established
  1312  	//nolint:errcheck
  1313  	switch {
  1314  	case opts.stdout:
  1315  		streams[remotecommand.StreamStdOut].Write([]byte{})
  1316  	case opts.stderr:
  1317  		streams[remotecommand.StreamStdErr].Write([]byte{})
  1318  	default:
  1319  		streams[remotecommand.StreamErr].Write([]byte{})
  1320  	}
  1321  
  1322  	wsStreams := &websocketStreams{
  1323  		conn:         conn,
  1324  		stdinStream:  streams[remotecommand.StreamStdIn],
  1325  		stdoutStream: streams[remotecommand.StreamStdOut],
  1326  		stderrStream: streams[remotecommand.StreamStdErr],
  1327  		tty:          opts.tty,
  1328  		resizeStream: streams[remotecommand.StreamResize],
  1329  	}
  1330  
  1331  	wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error {
  1332  		return func(status *apierrors.StatusError) error {
  1333  			bs, err := json.Marshal(status.Status())
  1334  			if err != nil {
  1335  				return err
  1336  			}
  1337  			_, err = stream.Write(bs)
  1338  			return err
  1339  		}
  1340  	}(streams[remotecommand.StreamErr])
  1341  
  1342  	return wsStreams, nil
  1343  }
  1344  
  1345  // See (https://github.com/kubernetes/kubernetes/issues/126134).
  1346  func TestWebSocketClient_HTTPSProxyErrorExpected(t *testing.T) {
  1347  	urlStr := "http://127.0.0.1/never-used" + "?" + "stdin=true" + "&" + "stdout=true"
  1348  	websocketLocation, err := url.Parse(urlStr)
  1349  	if err != nil {
  1350  		t.Fatalf("Unable to parse WebSocket server URL: %s", urlStr)
  1351  	}
  1352  	// proxy url with https scheme will trigger websocket dialing error.
  1353  	httpsProxyFunc := func(req *http.Request) (*url.URL, error) { return url.Parse("https://127.0.0.1") }
  1354  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host, Proxy: httpsProxyFunc}, "GET", urlStr)
  1355  	if err != nil {
  1356  		t.Errorf("unexpected error creating websocket executor: %v", err)
  1357  	}
  1358  	var stdout bytes.Buffer
  1359  	options := &StreamOptions{
  1360  		Stdout: &stdout,
  1361  	}
  1362  	errorChan := make(chan error)
  1363  	go func() {
  1364  		// Start the streaming on the WebSocket "exec" client.
  1365  		errorChan <- exec.StreamWithContext(context.Background(), *options)
  1366  	}()
  1367  
  1368  	select {
  1369  	case <-time.After(wait.ForeverTestTimeout):
  1370  		t.Fatalf("expect stream to be closed after connection is closed.")
  1371  	case err := <-errorChan:
  1372  		if err == nil {
  1373  			t.Errorf("expected error but received none")
  1374  		}
  1375  		if !httpstream.IsHTTPSProxyError(err) {
  1376  			t.Errorf("expected https proxy error, got (%s)", err)
  1377  		}
  1378  	}
  1379  }