github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/alpn_conn_upgrade_test.go (about)

     1  /*
     2  Copyright 2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"encoding/base64"
    24  	"errors"
    25  	"net"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/gobwas/ws"
    33  	"github.com/gravitational/trace"
    34  	"github.com/stretchr/testify/require"
    35  
    36  	"github.com/gravitational/teleport/api/constants"
    37  	"github.com/gravitational/teleport/api/fixtures"
    38  	"github.com/gravitational/teleport/api/testhelpers"
    39  	"github.com/gravitational/teleport/api/utils/pingconn"
    40  )
    41  
    42  func TestIsALPNConnUpgradeRequired(t *testing.T) {
    43  	t.Parallel()
    44  
    45  	tests := []struct {
    46  		name             string
    47  		serverProtos     []string
    48  		dialOpts         []DialOption
    49  		skipProxyURLTest bool
    50  		insecure         bool
    51  		expectedResult   bool
    52  	}{
    53  		{
    54  			name:           "upgrade required (handshake success)",
    55  			serverProtos:   nil, // Use nil for NextProtos to simulate no ALPN support.
    56  			insecure:       true,
    57  			expectedResult: true,
    58  		},
    59  		{
    60  			name:           "upgrade not required (proto negotiated)",
    61  			serverProtos:   []string{string(constants.ALPNSNIProtocolReverseTunnel)},
    62  			insecure:       true,
    63  			expectedResult: false,
    64  		},
    65  		{
    66  			name:           "upgrade required (handshake with no ALPN error)",
    67  			serverProtos:   []string{"unknown"},
    68  			insecure:       true,
    69  			expectedResult: true,
    70  		},
    71  		{
    72  			name: "upgrade required (unadvertised ALPN error)",
    73  			dialOpts: []DialOption{
    74  				// Use a fake dialer to simulate this error.
    75  				withBaseDialer(ContextDialerFunc(func(context.Context, string, string) (net.Conn, error) {
    76  					return nil, trace.Errorf("tls: server selected unadvertised ALPN protocol")
    77  				})),
    78  			},
    79  			serverProtos:     []string{"h2"}, // Doesn't matter here since not hitting server.
    80  			expectedResult:   true,
    81  			skipProxyURLTest: true,
    82  		},
    83  		{
    84  			name:           "upgrade not required (other handshake error)",
    85  			serverProtos:   []string{string(constants.ALPNSNIProtocolReverseTunnel)},
    86  			insecure:       false, // to cause handshake error
    87  			expectedResult: false,
    88  		},
    89  	}
    90  
    91  	ctx := context.Background()
    92  	forwardProxy, forwardProxyURL := mustStartForwardProxy(t)
    93  
    94  	for _, test := range tests {
    95  		t.Run(test.name, func(t *testing.T) {
    96  			server := mustStartMockALPNServer(t, test.serverProtos)
    97  
    98  			t.Run("direct", func(t *testing.T) {
    99  				require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(ctx, server.Addr().String(), test.insecure, test.dialOpts...))
   100  			})
   101  
   102  			if test.skipProxyURLTest {
   103  				return
   104  			}
   105  
   106  			t.Run("with ProxyURL", func(t *testing.T) {
   107  				countBeforeTest := forwardProxy.Count()
   108  				dialOpts := append(test.dialOpts, withProxyURL(forwardProxyURL))
   109  				require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(ctx, server.Addr().String(), test.insecure, dialOpts...))
   110  				require.Equal(t, countBeforeTest+1, forwardProxy.Count())
   111  			})
   112  		})
   113  	}
   114  }
   115  
   116  func TestIsALPNConnUpgradeRequiredByEnv(t *testing.T) {
   117  	t.Parallel()
   118  
   119  	addr := "example.teleport.com:443"
   120  	tests := []struct {
   121  		name     string
   122  		envValue string
   123  		require  require.BoolAssertionFunc
   124  	}{
   125  		{
   126  			name:     "upgraded required (for all addr)",
   127  			envValue: "yes",
   128  			require:  require.True,
   129  		},
   130  		{
   131  			name:     "upgraded required (for target addr)",
   132  			envValue: "0;example.teleport.com:443=1",
   133  			require:  require.True,
   134  		},
   135  		{
   136  			name:     "upgraded not required (for all addr)",
   137  			envValue: "false",
   138  			require:  require.False,
   139  		},
   140  		{
   141  			name:     "upgraded not required (no addr match)",
   142  			envValue: "another.teleport.com:443=true",
   143  			require:  require.False,
   144  		},
   145  		{
   146  			name:     "upgraded not required (for target addr)",
   147  			envValue: "another.teleport.com:443=true,example.teleport.com:443=false",
   148  			require:  require.False,
   149  		},
   150  	}
   151  
   152  	for _, test := range tests {
   153  		t.Run(test.name, func(t *testing.T) {
   154  			test.require(t, isALPNConnUpgradeRequiredByEnv(addr, test.envValue))
   155  		})
   156  	}
   157  }
   158  
   159  func TestALPNConnUpgradeDialer(t *testing.T) {
   160  	t.Parallel()
   161  
   162  	tests := []struct {
   163  		name          string
   164  		serverHandler http.Handler
   165  		withPing      bool
   166  		wantError     bool
   167  	}{
   168  		{
   169  			// TODO(greedy52) DELETE in 17.0
   170  			name:          "connection upgrade (legacy)",
   171  			serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
   172  		},
   173  		{
   174  			// TODO(greedy52) DELETE in 17.0
   175  			name:          "connection upgrade with ping (legacy)",
   176  			serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
   177  			withPing:      true,
   178  		},
   179  		{
   180  			name:          "connection upgrade (WebSocket)",
   181  			serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
   182  		},
   183  		{
   184  			name:          "connection upgrade with ping (WebSocket)",
   185  			serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
   186  			withPing:      true,
   187  		},
   188  		{
   189  			name:          "connection upgrade API not found",
   190  			serverHandler: http.NotFoundHandler(),
   191  			wantError:     true,
   192  		},
   193  	}
   194  
   195  	for _, test := range tests {
   196  		test := test
   197  		t.Run(test.name, func(t *testing.T) {
   198  			t.Parallel()
   199  			ctx := context.Background()
   200  
   201  			server := httptest.NewTLSServer(test.serverHandler)
   202  			t.Cleanup(server.Close)
   203  			addr, err := url.Parse(server.URL)
   204  			require.NoError(t, err)
   205  			pool := x509.NewCertPool()
   206  			pool.AddCert(server.Certificate())
   207  
   208  			tlsConfig := &tls.Config{RootCAs: pool}
   209  			directDialer := newDirectDialer(0, 5*time.Second)
   210  
   211  			t.Run("direct", func(t *testing.T) {
   212  				dialer := newALPNConnUpgradeDialer(directDialer, tlsConfig, test.withPing)
   213  				conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
   214  				if test.wantError {
   215  					require.Error(t, err)
   216  					return
   217  				}
   218  				require.NoError(t, err)
   219  				defer conn.Close()
   220  
   221  				mustReadConnData(t, conn, "hello")
   222  			})
   223  
   224  			t.Run("with ProxyURL", func(t *testing.T) {
   225  				forwardProxy, forwardProxyURL := mustStartForwardProxy(t)
   226  				countBeforeTest := forwardProxy.Count()
   227  
   228  				proxyURLDialer := newProxyURLDialer(forwardProxyURL, directDialer)
   229  				dialer := newALPNConnUpgradeDialer(proxyURLDialer, tlsConfig, test.withPing)
   230  				conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
   231  				if test.wantError {
   232  					require.Error(t, err)
   233  					return
   234  				}
   235  				require.NoError(t, err)
   236  				defer conn.Close()
   237  
   238  				mustReadConnData(t, conn, "hello")
   239  				require.Equal(t, countBeforeTest+1, forwardProxy.Count())
   240  			})
   241  		})
   242  	}
   243  }
   244  
   245  func mustReadConnData(t *testing.T, conn net.Conn, wantText string) {
   246  	t.Helper()
   247  
   248  	require.NotEmpty(t, wantText)
   249  
   250  	// Use a small buffer.
   251  	bufferSize := len(wantText) - 1
   252  	data := make([]byte, bufferSize)
   253  	n, err := conn.Read(data)
   254  	require.NoError(t, err)
   255  	require.Equal(t, bufferSize, n)
   256  	actualText := string(data)
   257  
   258  	// Now read it again to get the full text. This tests
   259  	// websocketALPNClientConn.readBuffer is implemented correctly.
   260  	data = make([]byte, bufferSize)
   261  	n, err = conn.Read(data)
   262  	require.NoError(t, err)
   263  	require.Equal(t, 1, n)
   264  	actualText += string(data[:1])
   265  
   266  	require.Equal(t, wantText, actualText)
   267  }
   268  
   269  type mockALPNServer struct {
   270  	net.Listener
   271  	cert            tls.Certificate
   272  	supportedProtos []string
   273  }
   274  
   275  func (m *mockALPNServer) serve(ctx context.Context, t *testing.T) {
   276  	config := &tls.Config{
   277  		NextProtos:   m.supportedProtos,
   278  		Certificates: []tls.Certificate{m.cert},
   279  	}
   280  
   281  	for {
   282  		select {
   283  		case <-ctx.Done():
   284  			return
   285  		default:
   286  		}
   287  
   288  		conn, err := m.Accept()
   289  		if errors.Is(err, net.ErrClosed) {
   290  			return
   291  		}
   292  
   293  		go func() {
   294  			clientConn := tls.Server(conn, config)
   295  			clientConn.HandshakeContext(ctx)
   296  			clientConn.Close()
   297  		}()
   298  	}
   299  }
   300  
   301  func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNServer {
   302  	ctx, cancel := context.WithCancel(context.Background())
   303  	t.Cleanup(cancel)
   304  
   305  	listener, err := net.Listen("tcp", "localhost:0")
   306  	require.NoError(t, err)
   307  	t.Cleanup(func() {
   308  		listener.Close()
   309  	})
   310  
   311  	cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
   312  	require.NoError(t, err)
   313  
   314  	m := &mockALPNServer{
   315  		Listener:        listener,
   316  		cert:            cert,
   317  		supportedProtos: supportedProtos,
   318  	}
   319  	go m.serve(ctx, t)
   320  	return m
   321  }
   322  
   323  // mockLegacyConnUpgradeHandler mocks the server side implementation to handle
   324  // an upgrade request and sends back some data inside the tunnel.
   325  func mockLegacyConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
   326  	t.Helper()
   327  
   328  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   329  		require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
   330  		require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), upgradeType)
   331  		require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeTeleportHeader), upgradeType)
   332  		require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))
   333  
   334  		hj, ok := w.(http.Hijacker)
   335  		require.True(t, ok)
   336  
   337  		conn, _, err := hj.Hijack()
   338  		require.NoError(t, err)
   339  		defer conn.Close()
   340  
   341  		// Upgrade response.
   342  		response := &http.Response{
   343  			StatusCode: http.StatusSwitchingProtocols,
   344  			ProtoMajor: 1,
   345  			ProtoMinor: 1,
   346  		}
   347  		require.NoError(t, response.Write(conn))
   348  
   349  		// Upgraded.
   350  		switch upgradeType {
   351  		case constants.WebAPIConnUpgradeTypeALPNPing:
   352  			// Wrap conn with Ping and write some pings.
   353  			pingConn := pingconn.New(conn)
   354  			pingConn.WritePing()
   355  			_, err = pingConn.Write(write)
   356  			require.NoError(t, err)
   357  			pingConn.WritePing()
   358  
   359  		default:
   360  			_, err = conn.Write(write)
   361  			require.NoError(t, err)
   362  		}
   363  	})
   364  }
   365  
   366  // mockWebSocketConnUpgradeHandler mocks the server side implementation to handle
   367  // a WebSocket upgrade request and sends back some data inside the tunnel.
   368  func mockWebSocketConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
   369  	t.Helper()
   370  
   371  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   372  		require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
   373  		require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), "websocket")
   374  		require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader))
   375  		require.Equal(t, upgradeType, r.Header.Get("Sec-Websocket-Protocol"))
   376  		require.Equal(t, "13", r.Header.Get("Sec-Websocket-Version"))
   377  
   378  		challengeKey := r.Header.Get("Sec-Websocket-Key")
   379  		challengeKeyDecoded, err := base64.StdEncoding.DecodeString(challengeKey)
   380  		require.NoError(t, err)
   381  		require.Len(t, challengeKeyDecoded, 16)
   382  
   383  		hj, ok := w.(http.Hijacker)
   384  		require.True(t, ok)
   385  
   386  		conn, _, err := hj.Hijack()
   387  		require.NoError(t, err)
   388  		defer conn.Close()
   389  
   390  		// Upgrade response.
   391  		response := &http.Response{
   392  			StatusCode: http.StatusSwitchingProtocols,
   393  			ProtoMajor: 1,
   394  			ProtoMinor: 1,
   395  			Header:     make(http.Header),
   396  		}
   397  		response.Header.Set("Upgrade", "websocket")
   398  		response.Header.Set("Sec-WebSocket-Protocol", upgradeType)
   399  		response.Header.Set("Sec-WebSocket-Accept", computeWebSocketAcceptKey(challengeKey))
   400  		require.NoError(t, response.Write(conn))
   401  
   402  		// Upgraded.
   403  		frame := ws.NewFrame(ws.OpBinary, true, write)
   404  		frame.Header.Masked = true
   405  		require.NoError(t, ws.WriteFrame(conn, frame))
   406  	})
   407  }
   408  
   409  func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
   410  	t.Helper()
   411  
   412  	listener, err := net.Listen("tcp", "localhost:0")
   413  	require.NoError(t, err)
   414  	t.Cleanup(func() {
   415  		listener.Close()
   416  	})
   417  
   418  	url, err := url.Parse("http://" + listener.Addr().String())
   419  	require.NoError(t, err)
   420  
   421  	handler := &testhelpers.ProxyHandler{}
   422  	go http.Serve(listener, handler)
   423  	return handler, url
   424  }
   425  
   426  func Test_connUpgradeMode(t *testing.T) {
   427  	tests := []struct {
   428  		envVarValue      string
   429  		wantUseWebSocket require.BoolAssertionFunc
   430  		wantUseLegacy    require.BoolAssertionFunc
   431  	}{
   432  		{
   433  			envVarValue:      "",
   434  			wantUseWebSocket: require.True,
   435  			wantUseLegacy:    require.True,
   436  		},
   437  		{
   438  			envVarValue:      "WebSocket",
   439  			wantUseWebSocket: require.True,
   440  			wantUseLegacy:    require.False,
   441  		},
   442  		{
   443  			envVarValue:      "websocket",
   444  			wantUseWebSocket: require.True,
   445  			wantUseLegacy:    require.False,
   446  		},
   447  		{
   448  			envVarValue:      "legacy",
   449  			wantUseWebSocket: require.False,
   450  			wantUseLegacy:    require.True,
   451  		},
   452  		{
   453  			envVarValue:      "default",
   454  			wantUseWebSocket: require.True,
   455  			wantUseLegacy:    require.True,
   456  		},
   457  	}
   458  
   459  	for _, test := range tests {
   460  		mode := connUpgradeMode(test.envVarValue)
   461  		test.wantUseWebSocket(t, mode.useWebSocket())
   462  		test.wantUseLegacy(t, mode.useLegacy())
   463  	}
   464  }