github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/alpn_conn_upgrade_test.go (about) 1 /* 2 Copyright 2022 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package client 18 19 import ( 20 "context" 21 "crypto/tls" 22 "crypto/x509" 23 "encoding/base64" 24 "errors" 25 "net" 26 "net/http" 27 "net/http/httptest" 28 "net/url" 29 "testing" 30 "time" 31 32 "github.com/gobwas/ws" 33 "github.com/gravitational/trace" 34 "github.com/stretchr/testify/require" 35 36 "github.com/gravitational/teleport/api/constants" 37 "github.com/gravitational/teleport/api/fixtures" 38 "github.com/gravitational/teleport/api/testhelpers" 39 "github.com/gravitational/teleport/api/utils/pingconn" 40 ) 41 42 func TestIsALPNConnUpgradeRequired(t *testing.T) { 43 t.Parallel() 44 45 tests := []struct { 46 name string 47 serverProtos []string 48 dialOpts []DialOption 49 skipProxyURLTest bool 50 insecure bool 51 expectedResult bool 52 }{ 53 { 54 name: "upgrade required (handshake success)", 55 serverProtos: nil, // Use nil for NextProtos to simulate no ALPN support. 56 insecure: true, 57 expectedResult: true, 58 }, 59 { 60 name: "upgrade not required (proto negotiated)", 61 serverProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, 62 insecure: true, 63 expectedResult: false, 64 }, 65 { 66 name: "upgrade required (handshake with no ALPN error)", 67 serverProtos: []string{"unknown"}, 68 insecure: true, 69 expectedResult: true, 70 }, 71 { 72 name: "upgrade required (unadvertised ALPN error)", 73 dialOpts: []DialOption{ 74 // Use a fake dialer to simulate this error. 75 withBaseDialer(ContextDialerFunc(func(context.Context, string, string) (net.Conn, error) { 76 return nil, trace.Errorf("tls: server selected unadvertised ALPN protocol") 77 })), 78 }, 79 serverProtos: []string{"h2"}, // Doesn't matter here since not hitting server. 80 expectedResult: true, 81 skipProxyURLTest: true, 82 }, 83 { 84 name: "upgrade not required (other handshake error)", 85 serverProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, 86 insecure: false, // to cause handshake error 87 expectedResult: false, 88 }, 89 } 90 91 ctx := context.Background() 92 forwardProxy, forwardProxyURL := mustStartForwardProxy(t) 93 94 for _, test := range tests { 95 t.Run(test.name, func(t *testing.T) { 96 server := mustStartMockALPNServer(t, test.serverProtos) 97 98 t.Run("direct", func(t *testing.T) { 99 require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(ctx, server.Addr().String(), test.insecure, test.dialOpts...)) 100 }) 101 102 if test.skipProxyURLTest { 103 return 104 } 105 106 t.Run("with ProxyURL", func(t *testing.T) { 107 countBeforeTest := forwardProxy.Count() 108 dialOpts := append(test.dialOpts, withProxyURL(forwardProxyURL)) 109 require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(ctx, server.Addr().String(), test.insecure, dialOpts...)) 110 require.Equal(t, countBeforeTest+1, forwardProxy.Count()) 111 }) 112 }) 113 } 114 } 115 116 func TestIsALPNConnUpgradeRequiredByEnv(t *testing.T) { 117 t.Parallel() 118 119 addr := "example.teleport.com:443" 120 tests := []struct { 121 name string 122 envValue string 123 require require.BoolAssertionFunc 124 }{ 125 { 126 name: "upgraded required (for all addr)", 127 envValue: "yes", 128 require: require.True, 129 }, 130 { 131 name: "upgraded required (for target addr)", 132 envValue: "0;example.teleport.com:443=1", 133 require: require.True, 134 }, 135 { 136 name: "upgraded not required (for all addr)", 137 envValue: "false", 138 require: require.False, 139 }, 140 { 141 name: "upgraded not required (no addr match)", 142 envValue: "another.teleport.com:443=true", 143 require: require.False, 144 }, 145 { 146 name: "upgraded not required (for target addr)", 147 envValue: "another.teleport.com:443=true,example.teleport.com:443=false", 148 require: require.False, 149 }, 150 } 151 152 for _, test := range tests { 153 t.Run(test.name, func(t *testing.T) { 154 test.require(t, isALPNConnUpgradeRequiredByEnv(addr, test.envValue)) 155 }) 156 } 157 } 158 159 func TestALPNConnUpgradeDialer(t *testing.T) { 160 t.Parallel() 161 162 tests := []struct { 163 name string 164 serverHandler http.Handler 165 withPing bool 166 wantError bool 167 }{ 168 { 169 // TODO(greedy52) DELETE in 17.0 170 name: "connection upgrade (legacy)", 171 serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), 172 }, 173 { 174 // TODO(greedy52) DELETE in 17.0 175 name: "connection upgrade with ping (legacy)", 176 serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), 177 withPing: true, 178 }, 179 { 180 name: "connection upgrade (WebSocket)", 181 serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), 182 }, 183 { 184 name: "connection upgrade with ping (WebSocket)", 185 serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), 186 withPing: true, 187 }, 188 { 189 name: "connection upgrade API not found", 190 serverHandler: http.NotFoundHandler(), 191 wantError: true, 192 }, 193 } 194 195 for _, test := range tests { 196 test := test 197 t.Run(test.name, func(t *testing.T) { 198 t.Parallel() 199 ctx := context.Background() 200 201 server := httptest.NewTLSServer(test.serverHandler) 202 t.Cleanup(server.Close) 203 addr, err := url.Parse(server.URL) 204 require.NoError(t, err) 205 pool := x509.NewCertPool() 206 pool.AddCert(server.Certificate()) 207 208 tlsConfig := &tls.Config{RootCAs: pool} 209 directDialer := newDirectDialer(0, 5*time.Second) 210 211 t.Run("direct", func(t *testing.T) { 212 dialer := newALPNConnUpgradeDialer(directDialer, tlsConfig, test.withPing) 213 conn, err := dialer.DialContext(ctx, "tcp", addr.Host) 214 if test.wantError { 215 require.Error(t, err) 216 return 217 } 218 require.NoError(t, err) 219 defer conn.Close() 220 221 mustReadConnData(t, conn, "hello") 222 }) 223 224 t.Run("with ProxyURL", func(t *testing.T) { 225 forwardProxy, forwardProxyURL := mustStartForwardProxy(t) 226 countBeforeTest := forwardProxy.Count() 227 228 proxyURLDialer := newProxyURLDialer(forwardProxyURL, directDialer) 229 dialer := newALPNConnUpgradeDialer(proxyURLDialer, tlsConfig, test.withPing) 230 conn, err := dialer.DialContext(ctx, "tcp", addr.Host) 231 if test.wantError { 232 require.Error(t, err) 233 return 234 } 235 require.NoError(t, err) 236 defer conn.Close() 237 238 mustReadConnData(t, conn, "hello") 239 require.Equal(t, countBeforeTest+1, forwardProxy.Count()) 240 }) 241 }) 242 } 243 } 244 245 func mustReadConnData(t *testing.T, conn net.Conn, wantText string) { 246 t.Helper() 247 248 require.NotEmpty(t, wantText) 249 250 // Use a small buffer. 251 bufferSize := len(wantText) - 1 252 data := make([]byte, bufferSize) 253 n, err := conn.Read(data) 254 require.NoError(t, err) 255 require.Equal(t, bufferSize, n) 256 actualText := string(data) 257 258 // Now read it again to get the full text. This tests 259 // websocketALPNClientConn.readBuffer is implemented correctly. 260 data = make([]byte, bufferSize) 261 n, err = conn.Read(data) 262 require.NoError(t, err) 263 require.Equal(t, 1, n) 264 actualText += string(data[:1]) 265 266 require.Equal(t, wantText, actualText) 267 } 268 269 type mockALPNServer struct { 270 net.Listener 271 cert tls.Certificate 272 supportedProtos []string 273 } 274 275 func (m *mockALPNServer) serve(ctx context.Context, t *testing.T) { 276 config := &tls.Config{ 277 NextProtos: m.supportedProtos, 278 Certificates: []tls.Certificate{m.cert}, 279 } 280 281 for { 282 select { 283 case <-ctx.Done(): 284 return 285 default: 286 } 287 288 conn, err := m.Accept() 289 if errors.Is(err, net.ErrClosed) { 290 return 291 } 292 293 go func() { 294 clientConn := tls.Server(conn, config) 295 clientConn.HandshakeContext(ctx) 296 clientConn.Close() 297 }() 298 } 299 } 300 301 func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNServer { 302 ctx, cancel := context.WithCancel(context.Background()) 303 t.Cleanup(cancel) 304 305 listener, err := net.Listen("tcp", "localhost:0") 306 require.NoError(t, err) 307 t.Cleanup(func() { 308 listener.Close() 309 }) 310 311 cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) 312 require.NoError(t, err) 313 314 m := &mockALPNServer{ 315 Listener: listener, 316 cert: cert, 317 supportedProtos: supportedProtos, 318 } 319 go m.serve(ctx, t) 320 return m 321 } 322 323 // mockLegacyConnUpgradeHandler mocks the server side implementation to handle 324 // an upgrade request and sends back some data inside the tunnel. 325 func mockLegacyConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { 326 t.Helper() 327 328 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 329 require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) 330 require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), upgradeType) 331 require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeTeleportHeader), upgradeType) 332 require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader)) 333 334 hj, ok := w.(http.Hijacker) 335 require.True(t, ok) 336 337 conn, _, err := hj.Hijack() 338 require.NoError(t, err) 339 defer conn.Close() 340 341 // Upgrade response. 342 response := &http.Response{ 343 StatusCode: http.StatusSwitchingProtocols, 344 ProtoMajor: 1, 345 ProtoMinor: 1, 346 } 347 require.NoError(t, response.Write(conn)) 348 349 // Upgraded. 350 switch upgradeType { 351 case constants.WebAPIConnUpgradeTypeALPNPing: 352 // Wrap conn with Ping and write some pings. 353 pingConn := pingconn.New(conn) 354 pingConn.WritePing() 355 _, err = pingConn.Write(write) 356 require.NoError(t, err) 357 pingConn.WritePing() 358 359 default: 360 _, err = conn.Write(write) 361 require.NoError(t, err) 362 } 363 }) 364 } 365 366 // mockWebSocketConnUpgradeHandler mocks the server side implementation to handle 367 // a WebSocket upgrade request and sends back some data inside the tunnel. 368 func mockWebSocketConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { 369 t.Helper() 370 371 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 372 require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) 373 require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), "websocket") 374 require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader)) 375 require.Equal(t, upgradeType, r.Header.Get("Sec-Websocket-Protocol")) 376 require.Equal(t, "13", r.Header.Get("Sec-Websocket-Version")) 377 378 challengeKey := r.Header.Get("Sec-Websocket-Key") 379 challengeKeyDecoded, err := base64.StdEncoding.DecodeString(challengeKey) 380 require.NoError(t, err) 381 require.Len(t, challengeKeyDecoded, 16) 382 383 hj, ok := w.(http.Hijacker) 384 require.True(t, ok) 385 386 conn, _, err := hj.Hijack() 387 require.NoError(t, err) 388 defer conn.Close() 389 390 // Upgrade response. 391 response := &http.Response{ 392 StatusCode: http.StatusSwitchingProtocols, 393 ProtoMajor: 1, 394 ProtoMinor: 1, 395 Header: make(http.Header), 396 } 397 response.Header.Set("Upgrade", "websocket") 398 response.Header.Set("Sec-WebSocket-Protocol", upgradeType) 399 response.Header.Set("Sec-WebSocket-Accept", computeWebSocketAcceptKey(challengeKey)) 400 require.NoError(t, response.Write(conn)) 401 402 // Upgraded. 403 frame := ws.NewFrame(ws.OpBinary, true, write) 404 frame.Header.Masked = true 405 require.NoError(t, ws.WriteFrame(conn, frame)) 406 }) 407 } 408 409 func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) { 410 t.Helper() 411 412 listener, err := net.Listen("tcp", "localhost:0") 413 require.NoError(t, err) 414 t.Cleanup(func() { 415 listener.Close() 416 }) 417 418 url, err := url.Parse("http://" + listener.Addr().String()) 419 require.NoError(t, err) 420 421 handler := &testhelpers.ProxyHandler{} 422 go http.Serve(listener, handler) 423 return handler, url 424 } 425 426 func Test_connUpgradeMode(t *testing.T) { 427 tests := []struct { 428 envVarValue string 429 wantUseWebSocket require.BoolAssertionFunc 430 wantUseLegacy require.BoolAssertionFunc 431 }{ 432 { 433 envVarValue: "", 434 wantUseWebSocket: require.True, 435 wantUseLegacy: require.True, 436 }, 437 { 438 envVarValue: "WebSocket", 439 wantUseWebSocket: require.True, 440 wantUseLegacy: require.False, 441 }, 442 { 443 envVarValue: "websocket", 444 wantUseWebSocket: require.True, 445 wantUseLegacy: require.False, 446 }, 447 { 448 envVarValue: "legacy", 449 wantUseWebSocket: require.False, 450 wantUseLegacy: require.True, 451 }, 452 { 453 envVarValue: "default", 454 wantUseWebSocket: require.True, 455 wantUseLegacy: require.True, 456 }, 457 } 458 459 for _, test := range tests { 460 mode := connUpgradeMode(test.envVarValue) 461 test.wantUseWebSocket(t, mode.useWebSocket()) 462 test.wantUseLegacy(t, mode.useLegacy()) 463 } 464 }