github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/webclient/webclient_test.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package webclient
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"net"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"slices"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  	"github.com/stretchr/testify/require"
    32  
    33  	"github.com/gravitational/teleport/api/defaults"
    34  	apihelpers "github.com/gravitational/teleport/api/testhelpers"
    35  )
    36  
    37  func newPingHandler(path string) http.Handler {
    38  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    39  		if req.RequestURI != path {
    40  			w.WriteHeader(http.StatusNotFound)
    41  			return
    42  		}
    43  
    44  		w.Header().Set("Content-Type", "application/json")
    45  		w.WriteHeader(http.StatusOK)
    46  		json.NewEncoder(w).Encode(PingResponse{ServerVersion: "test"})
    47  	})
    48  }
    49  
    50  func TestPlainHttpFallback(t *testing.T) {
    51  	t.Parallel()
    52  
    53  	testCases := []struct {
    54  		desc            string
    55  		handler         http.Handler
    56  		actionUnderTest func(addr string, insecure bool) error
    57  	}{
    58  		{
    59  			desc:    "Ping",
    60  			handler: newPingHandler("/webapi/ping"),
    61  			actionUnderTest: func(addr string, insecure bool) error {
    62  				_, err := Ping(
    63  					&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure})
    64  				return err
    65  			},
    66  		}, {
    67  			desc:    "Find",
    68  			handler: newPingHandler("/webapi/find"),
    69  			actionUnderTest: func(addr string, insecure bool) error {
    70  				_, err := Find(&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure})
    71  				return err
    72  			},
    73  		},
    74  	}
    75  
    76  	for _, testCase := range testCases {
    77  		t.Run(testCase.desc, func(t *testing.T) {
    78  			t.Run("Allowed on insecure & loopback", func(t *testing.T) {
    79  				httpSvr := httptest.NewServer(testCase.handler)
    80  				defer httpSvr.Close()
    81  
    82  				err := testCase.actionUnderTest(httpSvr.Listener.Addr().String(), true /* insecure */)
    83  				require.NoError(t, err)
    84  			})
    85  
    86  			t.Run("Denied on secure", func(t *testing.T) {
    87  				httpSvr := httptest.NewServer(testCase.handler)
    88  				defer httpSvr.Close()
    89  
    90  				err := testCase.actionUnderTest(httpSvr.Listener.Addr().String(), false /* secure */)
    91  				require.Error(t, err)
    92  			})
    93  
    94  			t.Run("Denied on non-loopback", func(t *testing.T) {
    95  				nonLoopbackSvr := httptest.NewUnstartedServer(testCase.handler)
    96  
    97  				// replace the test-supplied loopback listener with the first available
    98  				// non-loopback address
    99  				nonLoopbackSvr.Listener.Close()
   100  				l, err := net.Listen("tcp", "0.0.0.0:0")
   101  				require.NoError(t, err)
   102  				nonLoopbackSvr.Listener = l
   103  				nonLoopbackSvr.Start()
   104  				defer nonLoopbackSvr.Close()
   105  
   106  				err = testCase.actionUnderTest(nonLoopbackSvr.Listener.Addr().String(), true /* insecure */)
   107  				require.Error(t, err)
   108  			})
   109  		})
   110  	}
   111  }
   112  
   113  func TestTunnelAddr(t *testing.T) {
   114  	cases := []struct {
   115  		name               string
   116  		settings           ProxySettings
   117  		expectedTunnelAddr string
   118  		setup              func(t *testing.T)
   119  	}{
   120  		{
   121  			name: "should use TunnelPublicAddr",
   122  			settings: ProxySettings{
   123  				SSH: SSHProxySettings{
   124  					TunnelPublicAddr: "tunnel.example.com:4024",
   125  					PublicAddr:       "public.example.com",
   126  					SSHPublicAddr:    "ssh.example.com",
   127  					TunnelListenAddr: "[::]:5024",
   128  					WebListenAddr:    "proxy.example.com",
   129  				},
   130  			},
   131  			expectedTunnelAddr: "tunnel.example.com:4024",
   132  		},
   133  		{
   134  			name: "should use SSHPublicAddr and TunnelListenAddr",
   135  			settings: ProxySettings{
   136  				SSH: SSHProxySettings{
   137  					SSHPublicAddr:    "ssh.example.com",
   138  					PublicAddr:       "public.example.com",
   139  					TunnelListenAddr: "[::]:5024",
   140  					WebListenAddr:    "proxy.example.com",
   141  				},
   142  			},
   143  			expectedTunnelAddr: "ssh.example.com:5024",
   144  		},
   145  		{
   146  			name: "should use PublicAddr and TunnelListenAddr",
   147  			settings: ProxySettings{
   148  				SSH: SSHProxySettings{
   149  					PublicAddr:       "public.example.com",
   150  					TunnelListenAddr: "[::]:5024",
   151  					WebListenAddr:    "proxy.example.com",
   152  				},
   153  			},
   154  			expectedTunnelAddr: "public.example.com:5024",
   155  		},
   156  		{
   157  			name: "should use PublicAddr and SSHProxyTunnelListenPort",
   158  			settings: ProxySettings{
   159  				SSH: SSHProxySettings{
   160  					PublicAddr:    "public.example.com",
   161  					WebListenAddr: "proxy.example.com",
   162  				},
   163  			},
   164  			expectedTunnelAddr: "public.example.com:3024",
   165  		},
   166  		{
   167  			name: "should use WebListenAddr and SSHProxyTunnelListenPort",
   168  			settings: ProxySettings{
   169  				SSH: SSHProxySettings{
   170  					WebListenAddr: "proxy.example.com",
   171  				},
   172  			},
   173  			expectedTunnelAddr: "proxy.example.com:3024",
   174  		},
   175  		{
   176  			name: "should use PublicAddr with ProxyWebPort if TLSRoutingEnabled was enabled",
   177  			settings: ProxySettings{
   178  				SSH: SSHProxySettings{
   179  					PublicAddr:       "public.example.com",
   180  					TunnelListenAddr: "[::]:5024",
   181  					TunnelPublicAddr: "tpa.example.com:3032",
   182  					WebListenAddr:    "proxy.example.com:443",
   183  				},
   184  				TLSRoutingEnabled: true,
   185  			},
   186  			expectedTunnelAddr: "public.example.com:443",
   187  		},
   188  		{
   189  			name: "should use PublicAddr with custom port if TLSRoutingEnabled was enabled",
   190  			settings: ProxySettings{
   191  				SSH: SSHProxySettings{
   192  					PublicAddr:       "public.example.com:443",
   193  					TunnelListenAddr: "[::]:5024",
   194  					TunnelPublicAddr: "tpa.example.com:3032",
   195  					WebListenAddr:    "proxy.example.com:443",
   196  				},
   197  				TLSRoutingEnabled: true,
   198  			},
   199  			expectedTunnelAddr: "public.example.com:443",
   200  		},
   201  		{
   202  			name: "should use WebListenAddr with custom ProxyWebPort if TLSRoutingEnabled was enabled",
   203  			settings: ProxySettings{
   204  				SSH: SSHProxySettings{
   205  					TunnelListenAddr: "[::]:5024",
   206  					TunnelPublicAddr: "tpa.example.com:3032",
   207  					WebListenAddr:    "proxy.example.com:443",
   208  				},
   209  				TLSRoutingEnabled: true,
   210  			},
   211  			expectedTunnelAddr: "proxy.example.com:443",
   212  		},
   213  		{
   214  			name: "should use WebListenAddr with default https port if TLSRoutingEnabled was enabled",
   215  			settings: ProxySettings{
   216  				SSH: SSHProxySettings{
   217  					TunnelListenAddr: "[::]:5024",
   218  					TunnelPublicAddr: "tpa.example.com:3032",
   219  					WebListenAddr:    "proxy.example.com",
   220  				},
   221  				TLSRoutingEnabled: true,
   222  			},
   223  			expectedTunnelAddr: "proxy.example.com:443",
   224  		},
   225  		{
   226  			name:               "TELEPORT_TUNNEL_PUBLIC_ADDR overrides tunnel address",
   227  			settings:           ProxySettings{},
   228  			expectedTunnelAddr: "tunnel.example.com:4024",
   229  			setup: func(t *testing.T) {
   230  				t.Setenv(defaults.TunnelPublicAddrEnvar, "tunnel.example.com:4024")
   231  			},
   232  		},
   233  	}
   234  
   235  	for _, tt := range cases {
   236  		t.Run(tt.name, func(t *testing.T) {
   237  			if tt.setup != nil {
   238  				tt.setup(t)
   239  			}
   240  			tunnelAddr, err := tt.settings.TunnelAddr()
   241  			require.NoError(t, err)
   242  			require.Equal(t, tt.expectedTunnelAddr, tunnelAddr)
   243  		})
   244  	}
   245  }
   246  
   247  func TestParse(t *testing.T) {
   248  	t.Parallel()
   249  
   250  	testCases := []struct {
   251  		addr     string
   252  		hostPort string
   253  		host     string
   254  		port     int
   255  	}{
   256  		{
   257  			addr:     "example.com",
   258  			hostPort: "example.com",
   259  			host:     "example.com",
   260  			port:     0,
   261  		}, {
   262  			addr:     "example.com:443",
   263  			hostPort: "example.com:443",
   264  			host:     "example.com",
   265  			port:     443,
   266  		}, {
   267  			addr:     "http://example.com:443",
   268  			hostPort: "example.com:443",
   269  			host:     "example.com",
   270  			port:     443,
   271  		}, {
   272  			addr:     "https://example.com:443",
   273  			hostPort: "example.com:443",
   274  			host:     "example.com",
   275  			port:     443,
   276  		}, {
   277  			addr:     "tcp://example.com:443",
   278  			hostPort: "example.com:443",
   279  			host:     "example.com",
   280  			port:     443,
   281  		}, {
   282  			addr:     "file://host/path",
   283  			hostPort: "",
   284  			host:     "",
   285  			port:     0,
   286  		}, {
   287  			addr:     "[::]:443",
   288  			hostPort: "[::]:443",
   289  			host:     "::",
   290  			port:     443,
   291  		}, {
   292  			addr:     "https://example.com:443/path?query=query#fragment",
   293  			hostPort: "example.com:443",
   294  			host:     "example.com",
   295  			port:     443,
   296  		},
   297  	}
   298  
   299  	for _, tc := range testCases {
   300  		t.Run(tc.addr, func(t *testing.T) {
   301  			hostPort, err := parseAndJoinHostPort(tc.addr)
   302  			if tc.hostPort == "" {
   303  				require.Error(t, err)
   304  			} else {
   305  				require.NoError(t, err)
   306  				require.Equal(t, tc.hostPort, hostPort)
   307  			}
   308  
   309  			host, _, err := ParseHostPort(tc.addr)
   310  			if tc.host == "" {
   311  				require.Error(t, err)
   312  			} else {
   313  				require.NoError(t, err)
   314  				require.Equal(t, tc.host, host)
   315  			}
   316  
   317  			port, err := parsePort(tc.addr)
   318  			if tc.port == 0 {
   319  				require.Error(t, err)
   320  			} else {
   321  				require.NoError(t, err)
   322  				require.Equal(t, tc.port, port)
   323  			}
   324  		})
   325  	}
   326  }
   327  
   328  func TestNewWebClientHTTPProxy(t *testing.T) {
   329  	proxyHandler := &apihelpers.ProxyHandler{}
   330  	proxyServer := httptest.NewServer(proxyHandler)
   331  	t.Cleanup(proxyServer.Close)
   332  
   333  	localIP, err := apihelpers.GetLocalIP()
   334  	require.NoError(t, err)
   335  	server := apihelpers.MakeTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   336  		w.WriteHeader(http.StatusOK)
   337  		w.Write([]byte("hello"))
   338  	}), apihelpers.WithTestServerAddress(localIP))
   339  	_, serverPort, err := net.SplitHostPort(server.Listener.Addr().String())
   340  	require.NoError(t, err)
   341  	serverAddr := net.JoinHostPort(localIP, serverPort)
   342  	tests := []struct {
   343  		name               string
   344  		env                map[string]string
   345  		expectedProxyCount int
   346  	}{
   347  		{
   348  			name: "use http proxy",
   349  			env: map[string]string{
   350  				"HTTPS_PROXY": proxyServer.URL,
   351  			},
   352  			expectedProxyCount: 1,
   353  		},
   354  		{
   355  			name: "ignore proxy when no_proxy is set",
   356  			env: map[string]string{
   357  				"HTTPS_PROXY": proxyServer.URL,
   358  				"NO_PROXY":    "*",
   359  			},
   360  			expectedProxyCount: 0,
   361  		},
   362  	}
   363  	for _, tc := range tests {
   364  		t.Run(tc.name, func(t *testing.T) {
   365  			t.Cleanup(proxyHandler.Reset)
   366  			for k, v := range tc.env {
   367  				t.Setenv(k, v)
   368  			}
   369  			ctx, cancel := context.WithCancel(context.Background())
   370  			t.Cleanup(cancel)
   371  			client, err := newWebClient(&Config{
   372  				Context:   ctx,
   373  				ProxyAddr: "localhost:3080", // addr doesn't matter, it won't be used
   374  				Insecure:  true,
   375  			})
   376  			require.NoError(t, err)
   377  
   378  			resp, err := client.Get("https://" + serverAddr)
   379  			require.NoError(t, err)
   380  			require.NoError(t, resp.Body.Close())
   381  			require.Equal(t, tc.expectedProxyCount, proxyHandler.Count())
   382  		})
   383  	}
   384  }
   385  
   386  func TestSSHProxyHostPort(t *testing.T) {
   387  	t.Parallel()
   388  
   389  	tests := []struct {
   390  		testName        string
   391  		inProxySettings ProxySettings
   392  		outHost         string
   393  		outPort         string
   394  	}{
   395  		{
   396  			testName: "TLS routing enabled, web public addr",
   397  			inProxySettings: ProxySettings{
   398  				SSH: SSHProxySettings{
   399  					PublicAddr:    "proxy.example.com:443",
   400  					WebListenAddr: "127.0.0.1:3080",
   401  				},
   402  				TLSRoutingEnabled: true,
   403  			},
   404  			outHost: "proxy.example.com",
   405  			outPort: "443",
   406  		},
   407  		{
   408  			testName: "TLS routing enabled, web public addr with listen addr",
   409  			inProxySettings: ProxySettings{
   410  				SSH: SSHProxySettings{
   411  					PublicAddr:    "proxy.example.com",
   412  					WebListenAddr: "127.0.0.1:443",
   413  				},
   414  				TLSRoutingEnabled: true,
   415  			},
   416  			outHost: "proxy.example.com",
   417  			outPort: "443",
   418  		},
   419  		{
   420  			testName: "TLS routing enabled, web listen addr",
   421  			inProxySettings: ProxySettings{
   422  				SSH: SSHProxySettings{
   423  					WebListenAddr: "127.0.0.1:3080",
   424  				},
   425  				TLSRoutingEnabled: true,
   426  			},
   427  			outHost: "127.0.0.1",
   428  			outPort: "3080",
   429  		},
   430  		{
   431  			testName: "TLS routing disabled, SSH public addr",
   432  			inProxySettings: ProxySettings{
   433  				SSH: SSHProxySettings{
   434  					SSHPublicAddr: "ssh.example.com:3023",
   435  					PublicAddr:    "proxy.example.com:443",
   436  					ListenAddr:    "127.0.0.1:3023",
   437  				},
   438  				TLSRoutingEnabled: false,
   439  			},
   440  			outHost: "ssh.example.com",
   441  			outPort: "3023",
   442  		},
   443  		{
   444  			testName: "TLS routing disabled, web public addr",
   445  			inProxySettings: ProxySettings{
   446  				SSH: SSHProxySettings{
   447  					PublicAddr: "proxy.example.com:443",
   448  					ListenAddr: "127.0.0.1:3023",
   449  				},
   450  				TLSRoutingEnabled: false,
   451  			},
   452  			outHost: "proxy.example.com",
   453  			outPort: "3023",
   454  		},
   455  		{
   456  			testName: "TLS routing disabled, SSH listen addr",
   457  			inProxySettings: ProxySettings{
   458  				SSH: SSHProxySettings{
   459  					ListenAddr: "127.0.0.1:3023",
   460  				},
   461  				TLSRoutingEnabled: false,
   462  			},
   463  			outHost: "127.0.0.1",
   464  			outPort: "3023",
   465  		},
   466  	}
   467  	for _, test := range tests {
   468  		t.Run(test.testName, func(t *testing.T) {
   469  			host, port, err := test.inProxySettings.SSHProxyHostPort()
   470  			require.NoError(t, err)
   471  			require.Equal(t, test.outHost, host)
   472  			require.Equal(t, test.outPort, port)
   473  		})
   474  	}
   475  }
   476  
   477  // TestWebClientClosesIdleConnections verifies that all http connections
   478  // are closed when the http.Client created by newWebClient is no longer
   479  // being used.
   480  func TestWebClientClosesIdleConnections(t *testing.T) {
   481  	expectedResponse := &PingResponse{
   482  		Proxy: ProxySettings{
   483  			TLSRoutingEnabled: true,
   484  		},
   485  		ServerVersion:    "1.2.3",
   486  		MinClientVersion: "0.1.2",
   487  		ClusterName:      "test",
   488  	}
   489  
   490  	expectedStates := []string{
   491  		http.StateNew.String(), http.StateActive.String(), http.StateClosed.String(), // the https request will fail and cause us to fallback to http
   492  		http.StateNew.String(), http.StateActive.String(), http.StateIdle.String(), http.StateClosed.String(), // the http request should be processed and closed
   493  	}
   494  
   495  	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   496  		switch r.URL.Path {
   497  		case "/webapi/find":
   498  			json.NewEncoder(w).Encode(expectedResponse)
   499  		default:
   500  			w.WriteHeader(http.StatusBadRequest)
   501  		}
   502  	}))
   503  
   504  	stateChange := make(chan string, len(expectedStates))
   505  	srv.Config.ConnState = func(conn net.Conn, state http.ConnState) {
   506  		stateChange <- state.String()
   507  	}
   508  
   509  	srv.Start()
   510  	t.Cleanup(srv.Close)
   511  
   512  	resp, err := Find(&Config{
   513  		Context:   context.Background(),
   514  		ProxyAddr: strings.TrimPrefix(srv.URL, "http://"),
   515  		Insecure:  true,
   516  	})
   517  	require.NoError(t, err)
   518  	require.Empty(t, cmp.Diff(expectedResponse, resp))
   519  
   520  	var got []string
   521  	for i := range expectedStates {
   522  		select {
   523  		case state := <-stateChange:
   524  			got = append(got, state)
   525  		case <-time.After(3 * time.Second):
   526  			t.Fatalf("timeout waiting for expected connection state %d", i)
   527  		}
   528  	}
   529  
   530  	slices.Sort(expectedStates)
   531  	slices.Sort(got)
   532  
   533  	require.Equal(t, expectedStates, got)
   534  }