k8s.io/client-go@v0.31.1/tools/remotecommand/fallback_test.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package remotecommand
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"crypto/tls"
    24  	"io"
    25  	"net/http"
    26  	"net/http/httptest"
    27  	"net/url"
    28  	"sync/atomic"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"k8s.io/apimachinery/pkg/util/httpstream"
    35  	utilnettesting "k8s.io/apimachinery/pkg/util/net/testing"
    36  	"k8s.io/apimachinery/pkg/util/remotecommand"
    37  	"k8s.io/apimachinery/pkg/util/wait"
    38  	"k8s.io/client-go/rest"
    39  )
    40  
    41  func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) {
    42  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
    43  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    44  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
    45  		if err != nil {
    46  			w.WriteHeader(http.StatusForbidden)
    47  			return
    48  		}
    49  		defer conns.conn.Close()
    50  		// Loopback the STDIN stream onto the STDOUT stream.
    51  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
    52  		require.NoError(t, err)
    53  	}))
    54  	defer websocketServer.Close()
    55  
    56  	// Now create the fallback client (executor), and point it to the "websocketServer".
    57  	// Must add STDIN and STDOUT query params for the client request.
    58  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
    59  	websocketLocation, err := url.Parse(websocketServer.URL)
    60  	require.NoError(t, err)
    61  	websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
    62  	require.NoError(t, err)
    63  	spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
    64  	require.NoError(t, err)
    65  	// Never fallback, so always use the websocketExecutor, which succeeds against websocket server.
    66  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false })
    67  	require.NoError(t, err)
    68  	// Generate random data, and set it up to stream on STDIN. The data will be
    69  	// returned on the STDOUT buffer.
    70  	randomSize := 1024 * 1024
    71  	randomData := make([]byte, randomSize)
    72  	if _, err := rand.Read(randomData); err != nil {
    73  		t.Errorf("unexpected error reading random data: %v", err)
    74  	}
    75  	var stdout bytes.Buffer
    76  	options := &StreamOptions{
    77  		Stdin:  bytes.NewReader(randomData),
    78  		Stdout: &stdout,
    79  	}
    80  	errorChan := make(chan error)
    81  	go func() {
    82  		// Start the streaming on the WebSocket "exec" client.
    83  		errorChan <- exec.StreamWithContext(context.Background(), *options)
    84  	}()
    85  
    86  	select {
    87  	case <-time.After(wait.ForeverTestTimeout):
    88  		t.Fatalf("expect stream to be closed after connection is closed.")
    89  	case err := <-errorChan:
    90  		if err != nil {
    91  			t.Errorf("unexpected error")
    92  		}
    93  	}
    94  
    95  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
    96  	if err != nil {
    97  		t.Errorf("error reading the stream: %v", err)
    98  		return
    99  	}
   100  	// Check the random data sent on STDIN was the same returned on STDOUT.
   101  	if !bytes.Equal(randomData, data) {
   102  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   103  	}
   104  }
   105  
   106  func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) {
   107  	// Create fake SPDY server. Copy received STDIN data back onto STDOUT stream.
   108  	spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   109  		var stdin, stdout bytes.Buffer
   110  		ctx, err := createHTTPStreams(w, req, &StreamOptions{
   111  			Stdin:  &stdin,
   112  			Stdout: &stdout,
   113  		})
   114  		if err != nil {
   115  			w.WriteHeader(http.StatusForbidden)
   116  			return
   117  		}
   118  		defer ctx.conn.Close()
   119  		_, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
   120  		if err != nil {
   121  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
   122  		}
   123  	}))
   124  	defer spdyServer.Close()
   125  
   126  	spdyLocation, err := url.Parse(spdyServer.URL)
   127  	require.NoError(t, err)
   128  	websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL)
   129  	require.NoError(t, err)
   130  	spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation)
   131  	require.NoError(t, err)
   132  	// Always fallback to spdyExecutor, and spdyExecutor succeeds against fake spdy server.
   133  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
   134  	require.NoError(t, err)
   135  	// Generate random data, and set it up to stream on STDIN. The data will be
   136  	// returned on the STDOUT buffer.
   137  	randomSize := 1024 * 1024
   138  	randomData := make([]byte, randomSize)
   139  	if _, err := rand.Read(randomData); err != nil {
   140  		t.Errorf("unexpected error reading random data: %v", err)
   141  	}
   142  	var stdout bytes.Buffer
   143  	options := &StreamOptions{
   144  		Stdin:  bytes.NewReader(randomData),
   145  		Stdout: &stdout,
   146  	}
   147  	errorChan := make(chan error)
   148  	go func() {
   149  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   150  	}()
   151  
   152  	select {
   153  	case <-time.After(wait.ForeverTestTimeout):
   154  		t.Fatalf("expect stream to be closed after connection is closed.")
   155  	case err := <-errorChan:
   156  		if err != nil {
   157  			t.Errorf("unexpected error")
   158  		}
   159  	}
   160  
   161  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   162  	if err != nil {
   163  		t.Errorf("error reading the stream: %v", err)
   164  		return
   165  	}
   166  	// Check the random data sent on STDIN was the same returned on STDOUT.
   167  	if !bytes.Equal(randomData, data) {
   168  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   169  	}
   170  }
   171  
   172  func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) {
   173  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
   174  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   175  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   176  		if err != nil {
   177  			w.WriteHeader(http.StatusForbidden)
   178  			return
   179  		}
   180  		defer conns.conn.Close()
   181  		// Loopback the STDIN stream onto the STDOUT stream.
   182  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
   183  		require.NoError(t, err)
   184  	}))
   185  	defer websocketServer.Close()
   186  
   187  	// Now create the fallback client (executor), and point it to the "websocketServer".
   188  	// Must add STDIN and STDOUT query params for the client request.
   189  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
   190  	websocketLocation, err := url.Parse(websocketServer.URL)
   191  	require.NoError(t, err)
   192  	websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   193  	require.NoError(t, err)
   194  	spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
   195  	require.NoError(t, err)
   196  	// Always fallback to spdyExecutor, but spdyExecutor fails against websocket server.
   197  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
   198  	require.NoError(t, err)
   199  	// Update the websocket executor to request remote command v4, which is unsupported.
   200  	fallbackExec, ok := exec.(*FallbackExecutor)
   201  	assert.True(t, ok, "error casting executor as FallbackExecutor")
   202  	websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor)
   203  	assert.True(t, ok, "error casting executor as websocket executor")
   204  	// Set the attempted subprotocol version to V4; websocket server only accepts V5.
   205  	websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name}
   206  
   207  	// Generate random data, and set it up to stream on STDIN. The data will be
   208  	// returned on the STDOUT buffer.
   209  	randomSize := 1024 * 1024
   210  	randomData := make([]byte, randomSize)
   211  	if _, err := rand.Read(randomData); err != nil {
   212  		t.Errorf("unexpected error reading random data: %v", err)
   213  	}
   214  	var stdout bytes.Buffer
   215  	options := &StreamOptions{
   216  		Stdin:  bytes.NewReader(randomData),
   217  		Stdout: &stdout,
   218  	}
   219  	errorChan := make(chan error)
   220  	go func() {
   221  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   222  	}()
   223  
   224  	select {
   225  	case <-time.After(wait.ForeverTestTimeout):
   226  		t.Fatalf("expect stream to be closed after connection is closed.")
   227  	case err := <-errorChan:
   228  		// Ensure secondary executor returned an error.
   229  		require.Error(t, err)
   230  	}
   231  }
   232  
   233  // localhostCert was generated from crypto/tls/generate_cert.go with the following command:
   234  //
   235  //	go run generate_cert.go  --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
   236  var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
   237  MIIDGTCCAgGgAwIBAgIRALL5AZcefF4kkYV1SEG6YrMwDQYJKoZIhvcNAQELBQAw
   238  EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
   239  MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEP
   240  ADCCAQoCggEBALQ/FHcyVwdFHxARbbD2KBtDUT7Eni+8ioNdjtGcmtXqBv45EC1C
   241  JOqqGJTroFGJ6Q9kQIZ9FqH5IJR2fOOJD9kOTueG4Vt1JY1rj1Kbpjefu8XleZ5L
   242  SBwIWVnN/lEsEbuKmj7N2gLt5AH3zMZiBI1mg1u9Z5ZZHYbCiTpBrwsq6cTlvR9g
   243  dyo1YkM5hRESCzsrL0aUByoo0qRMD8ZsgANJwgsiO0/M6idbxDwv1BnGwGmRYvOE
   244  Hxpy3v0Jg7GJYrvnpnifJTs4nw91N5X9pXxR7FFzi/6HTYDWRljvTb0w6XciKYAz
   245  bWZ0+cJr5F7wB7ovlbm7HrQIR7z7EIIu2d8CAwEAAaNoMGYwDgYDVR0PAQH/BAQD
   246  AgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wLgYDVR0R
   247  BCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZI
   248  hvcNAQELBQADggEBAFPPWopNEJtIA2VFAQcqN6uJK+JVFOnjGRoCrM6Xgzdm0wxY
   249  XCGjsxY5dl+V7KzdGqu858rCaq5osEBqypBpYAnS9C38VyCDA1vPS1PsN8SYv48z
   250  DyBwj+7R2qar0ADBhnhWxvYO9M72lN/wuCqFKYMeFSnJdQLv3AsrrHe9lYqOa36s
   251  8wxSwVTFTYXBzljPEnSaaJMPqFD8JXaZK1ryJPkO5OsCNQNGtatNiWAf3DcmwHAT
   252  MGYMzP0u4nw47aRz9shB8w+taPKHx2BVwE1m/yp3nHVioOjXqA1fwRQVGclCJSH1
   253  D2iq3hWVHRENgjTjANBPICLo9AZ4JfN6PH19mnU=
   254  -----END CERTIFICATE-----`)
   255  
   256  // localhostKey is the private key for localhostCert.
   257  var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
   258  MIIEogIBAAKCAQEAtD8UdzJXB0UfEBFtsPYoG0NRPsSeL7yKg12O0Zya1eoG/jkQ
   259  LUIk6qoYlOugUYnpD2RAhn0WofkglHZ844kP2Q5O54bhW3UljWuPUpumN5+7xeV5
   260  nktIHAhZWc3+USwRu4qaPs3aAu3kAffMxmIEjWaDW71nllkdhsKJOkGvCyrpxOW9
   261  H2B3KjViQzmFERILOysvRpQHKijSpEwPxmyAA0nCCyI7T8zqJ1vEPC/UGcbAaZFi
   262  84QfGnLe/QmDsYliu+emeJ8lOzifD3U3lf2lfFHsUXOL/odNgNZGWO9NvTDpdyIp
   263  gDNtZnT5wmvkXvAHui+VubsetAhHvPsQgi7Z3wIDAQABAoIBAGmw93IxjYCQ0ncc
   264  kSKMJNZfsdtJdaxuNRZ0nNNirhQzR2h403iGaZlEpmdkhzxozsWcto1l+gh+SdFk
   265  bTUK4MUZM8FlgO2dEqkLYh5BcMT7ICMZvSfJ4v21E5eqR68XVUqQKoQbNvQyxFk3
   266  EddeEGdNrkb0GDK8DKlBlzAW5ep4gjG85wSTjR+J+muUv3R0BgLBFSuQnIDM/IMB
   267  LWqsja/QbtB7yppe7jL5u8UCFdZG8BBKT9fcvFIu5PRLO3MO0uOI7LTc8+W1Xm23
   268  uv+j3SY0+v+6POjK0UlJFFi/wkSPTFIfrQO1qFBkTDQHhQ6q/7GnILYYOiGbIRg2
   269  NNuP52ECgYEAzXEoy50wSYh8xfFaBuxbm3ruuG2W49jgop7ZfoFrPWwOQKAZS441
   270  VIwV4+e5IcA6KkuYbtGSdTYqK1SMkgnUyD/VevwAqH5TJoEIGu0pDuKGwVuwqioZ
   271  frCIAV5GllKyUJ55VZNbRr2vY2fCsWbaCSCHETn6C16DNuTCe5C0JBECgYEA4JqY
   272  5GpNbMG8fOt4H7hU0Fbm2yd6SHJcQ3/9iimef7xG6ajxsYrIhg1ft+3IPHMjVI0+
   273  9brwHDnWg4bOOx/VO4VJBt6Dm/F33bndnZRkuIjfSNpLM51P+EnRdaFVHOJHwKqx
   274  uF69kihifCAG7YATgCveeXImzBUSyZUz9UrETu8CgYARNBimdFNG1RcdvEg9rC0/
   275  p9u1tfecvNySwZqU7WF9kz7eSonTueTdX521qAHowaAdSpdJMGODTTXaywm6cPhQ
   276  jIfj9JZZhbqQzt1O4+08Qdvm9TamCUB5S28YLjza+bHU7nBaqixKkDfPqzCyilpX
   277  yVGGL8SwjwmN3zop/sQXAQKBgC0JMsESQ6YcDsRpnrOVjYQc+LtW5iEitTdfsaID
   278  iGGKihmOI7B66IxgoCHMTws39wycKdSyADVYr5e97xpR3rrJlgQHmBIrz+Iow7Q2
   279  LiAGaec8xjl6QK/DdXmFuQBKqyKJ14rljFODP4QuE9WJid94bGqjpf3j99ltznZP
   280  4J8HAoGAJb4eb4lu4UGwifDzqfAPzLGCoi0fE1/hSx34lfuLcc1G+LEu9YDKoOVJ
   281  9suOh0b5K/bfEy9KrVMBBriduvdaERSD8S3pkIQaitIz0B029AbE4FLFf9lKQpP2
   282  KR8NJEkK99Vh/tew6jAMll70xFrE7aF8VLXJVE7w4sQzuvHxl9Q=
   283  -----END RSA PRIVATE KEY-----
   284  `)
   285  
   286  // See (https://github.com/kubernetes/kubernetes/issues/126134).
   287  func TestFallbackClient_WebSocketHTTPSProxyCausesSPDYFallback(t *testing.T) {
   288  	cert, err := tls.X509KeyPair(localhostCert, localhostKey)
   289  	if err != nil {
   290  		t.Errorf("https (valid hostname): proxy_test: %v", err)
   291  	}
   292  
   293  	var proxyCalled atomic.Int64
   294  	proxyHandler := utilnettesting.NewHTTPProxyHandler(t, func(req *http.Request) bool {
   295  		proxyCalled.Add(1)
   296  		return true
   297  	})
   298  	defer proxyHandler.Wait()
   299  
   300  	proxyServer := httptest.NewUnstartedServer(proxyHandler)
   301  	proxyServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
   302  	proxyServer.StartTLS()
   303  	defer proxyServer.Close() //nolint:errcheck
   304  
   305  	proxyLocation, err := url.Parse(proxyServer.URL)
   306  	require.NoError(t, err)
   307  
   308  	// Create fake SPDY server. Copy received STDIN data back onto STDOUT stream.
   309  	spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   310  		var stdin, stdout bytes.Buffer
   311  		ctx, err := createHTTPStreams(w, req, &StreamOptions{
   312  			Stdin:  &stdin,
   313  			Stdout: &stdout,
   314  		})
   315  		if err != nil {
   316  			w.WriteHeader(http.StatusForbidden)
   317  			return
   318  		}
   319  		defer ctx.conn.Close() //nolint:errcheck
   320  		_, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
   321  		if err != nil {
   322  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
   323  		}
   324  	}))
   325  	defer spdyServer.Close() //nolint:errcheck
   326  
   327  	backendLocation, err := url.Parse(spdyServer.URL)
   328  	require.NoError(t, err)
   329  
   330  	clientConfig := &rest.Config{
   331  		Host:            spdyServer.URL,
   332  		TLSClientConfig: rest.TLSClientConfig{CAData: localhostCert},
   333  		Proxy: func(req *http.Request) (*url.URL, error) {
   334  			return proxyLocation, nil
   335  		},
   336  	}
   337  
   338  	// Websocket with https proxy will fail in dialing (falling back to SPDY).
   339  	websocketExecutor, err := NewWebSocketExecutor(clientConfig, "GET", backendLocation.String())
   340  	require.NoError(t, err)
   341  	spdyExecutor, err := NewSPDYExecutor(clientConfig, "POST", backendLocation)
   342  	require.NoError(t, err)
   343  	// Fallback to spdyExecutor with websocket https proxy error; spdyExecutor succeeds against fake spdy server.
   344  	sawHTTPSProxyError := false
   345  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(err error) bool {
   346  		if httpstream.IsUpgradeFailure(err) {
   347  			t.Errorf("saw upgrade failure: %v", err)
   348  			return true
   349  		}
   350  		if httpstream.IsHTTPSProxyError(err) {
   351  			sawHTTPSProxyError = true
   352  			t.Logf("saw https proxy error: %v", err)
   353  			return true
   354  		}
   355  		return false
   356  	})
   357  	require.NoError(t, err)
   358  
   359  	// Generate random data, and set it up to stream on STDIN. The data will be
   360  	// returned on the STDOUT buffer.
   361  	randomSize := 1024 * 1024
   362  	randomData := make([]byte, randomSize)
   363  	if _, err := rand.Read(randomData); err != nil {
   364  		t.Errorf("unexpected error reading random data: %v", err)
   365  	}
   366  	var stdout bytes.Buffer
   367  	options := &StreamOptions{
   368  		Stdin:  bytes.NewReader(randomData),
   369  		Stdout: &stdout,
   370  	}
   371  	errorChan := make(chan error)
   372  	go func() {
   373  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   374  	}()
   375  
   376  	select {
   377  	case <-time.After(wait.ForeverTestTimeout):
   378  		t.Fatalf("expect stream to be closed after connection is closed.")
   379  	case err := <-errorChan:
   380  		if err != nil {
   381  			t.Errorf("unexpected error")
   382  		}
   383  	}
   384  
   385  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   386  	if err != nil {
   387  		t.Errorf("error reading the stream: %v", err)
   388  		return
   389  	}
   390  	// Check the random data sent on STDIN was the same returned on STDOUT.
   391  	if !bytes.Equal(randomData, data) {
   392  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   393  	}
   394  
   395  	// Ensure the https proxy error was observed
   396  	if !sawHTTPSProxyError {
   397  		t.Errorf("expected to see https proxy error")
   398  	}
   399  	// Ensure the proxy was called once
   400  	if e, a := int64(1), proxyCalled.Load(); e != a {
   401  		t.Errorf("expected %d proxy call, got %d", e, a)
   402  	}
   403  }