github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/webclient/webclient_test.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package webclient 18 19 import ( 20 "context" 21 "encoding/json" 22 "net" 23 "net/http" 24 "net/http/httptest" 25 "slices" 26 "strings" 27 "testing" 28 "time" 29 30 "github.com/google/go-cmp/cmp" 31 "github.com/stretchr/testify/require" 32 33 "github.com/gravitational/teleport/api/defaults" 34 apihelpers "github.com/gravitational/teleport/api/testhelpers" 35 ) 36 37 func newPingHandler(path string) http.Handler { 38 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 39 if req.RequestURI != path { 40 w.WriteHeader(http.StatusNotFound) 41 return 42 } 43 44 w.Header().Set("Content-Type", "application/json") 45 w.WriteHeader(http.StatusOK) 46 json.NewEncoder(w).Encode(PingResponse{ServerVersion: "test"}) 47 }) 48 } 49 50 func TestPlainHttpFallback(t *testing.T) { 51 t.Parallel() 52 53 testCases := []struct { 54 desc string 55 handler http.Handler 56 actionUnderTest func(addr string, insecure bool) error 57 }{ 58 { 59 desc: "Ping", 60 handler: newPingHandler("/webapi/ping"), 61 actionUnderTest: func(addr string, insecure bool) error { 62 _, err := Ping( 63 &Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure}) 64 return err 65 }, 66 }, { 67 desc: "Find", 68 handler: newPingHandler("/webapi/find"), 69 actionUnderTest: func(addr string, insecure bool) error { 70 _, err := Find(&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure}) 71 return err 72 }, 73 }, 74 } 75 76 for _, testCase := range testCases { 77 t.Run(testCase.desc, func(t *testing.T) { 78 t.Run("Allowed on insecure & loopback", func(t *testing.T) { 79 httpSvr := httptest.NewServer(testCase.handler) 80 defer httpSvr.Close() 81 82 err := testCase.actionUnderTest(httpSvr.Listener.Addr().String(), true /* insecure */) 83 require.NoError(t, err) 84 }) 85 86 t.Run("Denied on secure", func(t *testing.T) { 87 httpSvr := httptest.NewServer(testCase.handler) 88 defer httpSvr.Close() 89 90 err := testCase.actionUnderTest(httpSvr.Listener.Addr().String(), false /* secure */) 91 require.Error(t, err) 92 }) 93 94 t.Run("Denied on non-loopback", func(t *testing.T) { 95 nonLoopbackSvr := httptest.NewUnstartedServer(testCase.handler) 96 97 // replace the test-supplied loopback listener with the first available 98 // non-loopback address 99 nonLoopbackSvr.Listener.Close() 100 l, err := net.Listen("tcp", "0.0.0.0:0") 101 require.NoError(t, err) 102 nonLoopbackSvr.Listener = l 103 nonLoopbackSvr.Start() 104 defer nonLoopbackSvr.Close() 105 106 err = testCase.actionUnderTest(nonLoopbackSvr.Listener.Addr().String(), true /* insecure */) 107 require.Error(t, err) 108 }) 109 }) 110 } 111 } 112 113 func TestTunnelAddr(t *testing.T) { 114 cases := []struct { 115 name string 116 settings ProxySettings 117 expectedTunnelAddr string 118 setup func(t *testing.T) 119 }{ 120 { 121 name: "should use TunnelPublicAddr", 122 settings: ProxySettings{ 123 SSH: SSHProxySettings{ 124 TunnelPublicAddr: "tunnel.example.com:4024", 125 PublicAddr: "public.example.com", 126 SSHPublicAddr: "ssh.example.com", 127 TunnelListenAddr: "[::]:5024", 128 WebListenAddr: "proxy.example.com", 129 }, 130 }, 131 expectedTunnelAddr: "tunnel.example.com:4024", 132 }, 133 { 134 name: "should use SSHPublicAddr and TunnelListenAddr", 135 settings: ProxySettings{ 136 SSH: SSHProxySettings{ 137 SSHPublicAddr: "ssh.example.com", 138 PublicAddr: "public.example.com", 139 TunnelListenAddr: "[::]:5024", 140 WebListenAddr: "proxy.example.com", 141 }, 142 }, 143 expectedTunnelAddr: "ssh.example.com:5024", 144 }, 145 { 146 name: "should use PublicAddr and TunnelListenAddr", 147 settings: ProxySettings{ 148 SSH: SSHProxySettings{ 149 PublicAddr: "public.example.com", 150 TunnelListenAddr: "[::]:5024", 151 WebListenAddr: "proxy.example.com", 152 }, 153 }, 154 expectedTunnelAddr: "public.example.com:5024", 155 }, 156 { 157 name: "should use PublicAddr and SSHProxyTunnelListenPort", 158 settings: ProxySettings{ 159 SSH: SSHProxySettings{ 160 PublicAddr: "public.example.com", 161 WebListenAddr: "proxy.example.com", 162 }, 163 }, 164 expectedTunnelAddr: "public.example.com:3024", 165 }, 166 { 167 name: "should use WebListenAddr and SSHProxyTunnelListenPort", 168 settings: ProxySettings{ 169 SSH: SSHProxySettings{ 170 WebListenAddr: "proxy.example.com", 171 }, 172 }, 173 expectedTunnelAddr: "proxy.example.com:3024", 174 }, 175 { 176 name: "should use PublicAddr with ProxyWebPort if TLSRoutingEnabled was enabled", 177 settings: ProxySettings{ 178 SSH: SSHProxySettings{ 179 PublicAddr: "public.example.com", 180 TunnelListenAddr: "[::]:5024", 181 TunnelPublicAddr: "tpa.example.com:3032", 182 WebListenAddr: "proxy.example.com:443", 183 }, 184 TLSRoutingEnabled: true, 185 }, 186 expectedTunnelAddr: "public.example.com:443", 187 }, 188 { 189 name: "should use PublicAddr with custom port if TLSRoutingEnabled was enabled", 190 settings: ProxySettings{ 191 SSH: SSHProxySettings{ 192 PublicAddr: "public.example.com:443", 193 TunnelListenAddr: "[::]:5024", 194 TunnelPublicAddr: "tpa.example.com:3032", 195 WebListenAddr: "proxy.example.com:443", 196 }, 197 TLSRoutingEnabled: true, 198 }, 199 expectedTunnelAddr: "public.example.com:443", 200 }, 201 { 202 name: "should use WebListenAddr with custom ProxyWebPort if TLSRoutingEnabled was enabled", 203 settings: ProxySettings{ 204 SSH: SSHProxySettings{ 205 TunnelListenAddr: "[::]:5024", 206 TunnelPublicAddr: "tpa.example.com:3032", 207 WebListenAddr: "proxy.example.com:443", 208 }, 209 TLSRoutingEnabled: true, 210 }, 211 expectedTunnelAddr: "proxy.example.com:443", 212 }, 213 { 214 name: "should use WebListenAddr with default https port if TLSRoutingEnabled was enabled", 215 settings: ProxySettings{ 216 SSH: SSHProxySettings{ 217 TunnelListenAddr: "[::]:5024", 218 TunnelPublicAddr: "tpa.example.com:3032", 219 WebListenAddr: "proxy.example.com", 220 }, 221 TLSRoutingEnabled: true, 222 }, 223 expectedTunnelAddr: "proxy.example.com:443", 224 }, 225 { 226 name: "TELEPORT_TUNNEL_PUBLIC_ADDR overrides tunnel address", 227 settings: ProxySettings{}, 228 expectedTunnelAddr: "tunnel.example.com:4024", 229 setup: func(t *testing.T) { 230 t.Setenv(defaults.TunnelPublicAddrEnvar, "tunnel.example.com:4024") 231 }, 232 }, 233 } 234 235 for _, tt := range cases { 236 t.Run(tt.name, func(t *testing.T) { 237 if tt.setup != nil { 238 tt.setup(t) 239 } 240 tunnelAddr, err := tt.settings.TunnelAddr() 241 require.NoError(t, err) 242 require.Equal(t, tt.expectedTunnelAddr, tunnelAddr) 243 }) 244 } 245 } 246 247 func TestParse(t *testing.T) { 248 t.Parallel() 249 250 testCases := []struct { 251 addr string 252 hostPort string 253 host string 254 port int 255 }{ 256 { 257 addr: "example.com", 258 hostPort: "example.com", 259 host: "example.com", 260 port: 0, 261 }, { 262 addr: "example.com:443", 263 hostPort: "example.com:443", 264 host: "example.com", 265 port: 443, 266 }, { 267 addr: "http://example.com:443", 268 hostPort: "example.com:443", 269 host: "example.com", 270 port: 443, 271 }, { 272 addr: "https://example.com:443", 273 hostPort: "example.com:443", 274 host: "example.com", 275 port: 443, 276 }, { 277 addr: "tcp://example.com:443", 278 hostPort: "example.com:443", 279 host: "example.com", 280 port: 443, 281 }, { 282 addr: "file://host/path", 283 hostPort: "", 284 host: "", 285 port: 0, 286 }, { 287 addr: "[::]:443", 288 hostPort: "[::]:443", 289 host: "::", 290 port: 443, 291 }, { 292 addr: "https://example.com:443/path?query=query#fragment", 293 hostPort: "example.com:443", 294 host: "example.com", 295 port: 443, 296 }, 297 } 298 299 for _, tc := range testCases { 300 t.Run(tc.addr, func(t *testing.T) { 301 hostPort, err := parseAndJoinHostPort(tc.addr) 302 if tc.hostPort == "" { 303 require.Error(t, err) 304 } else { 305 require.NoError(t, err) 306 require.Equal(t, tc.hostPort, hostPort) 307 } 308 309 host, _, err := ParseHostPort(tc.addr) 310 if tc.host == "" { 311 require.Error(t, err) 312 } else { 313 require.NoError(t, err) 314 require.Equal(t, tc.host, host) 315 } 316 317 port, err := parsePort(tc.addr) 318 if tc.port == 0 { 319 require.Error(t, err) 320 } else { 321 require.NoError(t, err) 322 require.Equal(t, tc.port, port) 323 } 324 }) 325 } 326 } 327 328 func TestNewWebClientHTTPProxy(t *testing.T) { 329 proxyHandler := &apihelpers.ProxyHandler{} 330 proxyServer := httptest.NewServer(proxyHandler) 331 t.Cleanup(proxyServer.Close) 332 333 localIP, err := apihelpers.GetLocalIP() 334 require.NoError(t, err) 335 server := apihelpers.MakeTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 336 w.WriteHeader(http.StatusOK) 337 w.Write([]byte("hello")) 338 }), apihelpers.WithTestServerAddress(localIP)) 339 _, serverPort, err := net.SplitHostPort(server.Listener.Addr().String()) 340 require.NoError(t, err) 341 serverAddr := net.JoinHostPort(localIP, serverPort) 342 tests := []struct { 343 name string 344 env map[string]string 345 expectedProxyCount int 346 }{ 347 { 348 name: "use http proxy", 349 env: map[string]string{ 350 "HTTPS_PROXY": proxyServer.URL, 351 }, 352 expectedProxyCount: 1, 353 }, 354 { 355 name: "ignore proxy when no_proxy is set", 356 env: map[string]string{ 357 "HTTPS_PROXY": proxyServer.URL, 358 "NO_PROXY": "*", 359 }, 360 expectedProxyCount: 0, 361 }, 362 } 363 for _, tc := range tests { 364 t.Run(tc.name, func(t *testing.T) { 365 t.Cleanup(proxyHandler.Reset) 366 for k, v := range tc.env { 367 t.Setenv(k, v) 368 } 369 ctx, cancel := context.WithCancel(context.Background()) 370 t.Cleanup(cancel) 371 client, err := newWebClient(&Config{ 372 Context: ctx, 373 ProxyAddr: "localhost:3080", // addr doesn't matter, it won't be used 374 Insecure: true, 375 }) 376 require.NoError(t, err) 377 378 resp, err := client.Get("https://" + serverAddr) 379 require.NoError(t, err) 380 require.NoError(t, resp.Body.Close()) 381 require.Equal(t, tc.expectedProxyCount, proxyHandler.Count()) 382 }) 383 } 384 } 385 386 func TestSSHProxyHostPort(t *testing.T) { 387 t.Parallel() 388 389 tests := []struct { 390 testName string 391 inProxySettings ProxySettings 392 outHost string 393 outPort string 394 }{ 395 { 396 testName: "TLS routing enabled, web public addr", 397 inProxySettings: ProxySettings{ 398 SSH: SSHProxySettings{ 399 PublicAddr: "proxy.example.com:443", 400 WebListenAddr: "127.0.0.1:3080", 401 }, 402 TLSRoutingEnabled: true, 403 }, 404 outHost: "proxy.example.com", 405 outPort: "443", 406 }, 407 { 408 testName: "TLS routing enabled, web public addr with listen addr", 409 inProxySettings: ProxySettings{ 410 SSH: SSHProxySettings{ 411 PublicAddr: "proxy.example.com", 412 WebListenAddr: "127.0.0.1:443", 413 }, 414 TLSRoutingEnabled: true, 415 }, 416 outHost: "proxy.example.com", 417 outPort: "443", 418 }, 419 { 420 testName: "TLS routing enabled, web listen addr", 421 inProxySettings: ProxySettings{ 422 SSH: SSHProxySettings{ 423 WebListenAddr: "127.0.0.1:3080", 424 }, 425 TLSRoutingEnabled: true, 426 }, 427 outHost: "127.0.0.1", 428 outPort: "3080", 429 }, 430 { 431 testName: "TLS routing disabled, SSH public addr", 432 inProxySettings: ProxySettings{ 433 SSH: SSHProxySettings{ 434 SSHPublicAddr: "ssh.example.com:3023", 435 PublicAddr: "proxy.example.com:443", 436 ListenAddr: "127.0.0.1:3023", 437 }, 438 TLSRoutingEnabled: false, 439 }, 440 outHost: "ssh.example.com", 441 outPort: "3023", 442 }, 443 { 444 testName: "TLS routing disabled, web public addr", 445 inProxySettings: ProxySettings{ 446 SSH: SSHProxySettings{ 447 PublicAddr: "proxy.example.com:443", 448 ListenAddr: "127.0.0.1:3023", 449 }, 450 TLSRoutingEnabled: false, 451 }, 452 outHost: "proxy.example.com", 453 outPort: "3023", 454 }, 455 { 456 testName: "TLS routing disabled, SSH listen addr", 457 inProxySettings: ProxySettings{ 458 SSH: SSHProxySettings{ 459 ListenAddr: "127.0.0.1:3023", 460 }, 461 TLSRoutingEnabled: false, 462 }, 463 outHost: "127.0.0.1", 464 outPort: "3023", 465 }, 466 } 467 for _, test := range tests { 468 t.Run(test.testName, func(t *testing.T) { 469 host, port, err := test.inProxySettings.SSHProxyHostPort() 470 require.NoError(t, err) 471 require.Equal(t, test.outHost, host) 472 require.Equal(t, test.outPort, port) 473 }) 474 } 475 } 476 477 // TestWebClientClosesIdleConnections verifies that all http connections 478 // are closed when the http.Client created by newWebClient is no longer 479 // being used. 480 func TestWebClientClosesIdleConnections(t *testing.T) { 481 expectedResponse := &PingResponse{ 482 Proxy: ProxySettings{ 483 TLSRoutingEnabled: true, 484 }, 485 ServerVersion: "1.2.3", 486 MinClientVersion: "0.1.2", 487 ClusterName: "test", 488 } 489 490 expectedStates := []string{ 491 http.StateNew.String(), http.StateActive.String(), http.StateClosed.String(), // the https request will fail and cause us to fallback to http 492 http.StateNew.String(), http.StateActive.String(), http.StateIdle.String(), http.StateClosed.String(), // the http request should be processed and closed 493 } 494 495 srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 496 switch r.URL.Path { 497 case "/webapi/find": 498 json.NewEncoder(w).Encode(expectedResponse) 499 default: 500 w.WriteHeader(http.StatusBadRequest) 501 } 502 })) 503 504 stateChange := make(chan string, len(expectedStates)) 505 srv.Config.ConnState = func(conn net.Conn, state http.ConnState) { 506 stateChange <- state.String() 507 } 508 509 srv.Start() 510 t.Cleanup(srv.Close) 511 512 resp, err := Find(&Config{ 513 Context: context.Background(), 514 ProxyAddr: strings.TrimPrefix(srv.URL, "http://"), 515 Insecure: true, 516 }) 517 require.NoError(t, err) 518 require.Empty(t, cmp.Diff(expectedResponse, resp)) 519 520 var got []string 521 for i := range expectedStates { 522 select { 523 case state := <-stateChange: 524 got = append(got, state) 525 case <-time.After(3 * time.Second): 526 t.Fatalf("timeout waiting for expected connection state %d", i) 527 } 528 } 529 530 slices.Sort(expectedStates) 531 slices.Sort(got) 532 533 require.Equal(t, expectedStates, got) 534 }