github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/proxy_test.go (about)

     1  /*
     2  Copyright 2017 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package utils
    18  
    19  import (
    20  	"crypto/tls"
    21  	"fmt"
    22  	"net"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"net/url"
    26  	"strings"
    27  	"testing"
    28  
    29  	"github.com/gravitational/trace"
    30  	"github.com/stretchr/testify/require"
    31  	"golang.org/x/net/http/httpproxy"
    32  )
    33  
    34  func TestGetProxyAddress(t *testing.T) {
    35  	type env struct {
    36  		name string
    37  		val  string
    38  	}
    39  	var tests = []struct {
    40  		info       string
    41  		env        []env
    42  		targetAddr string
    43  		proxyAddr  string
    44  	}{
    45  		{
    46  			info:       "valid, can be raw host:port",
    47  			env:        []env{{name: "http_proxy", val: "proxy:1234"}},
    48  			proxyAddr:  "proxy:1234",
    49  			targetAddr: "192.168.1.1:3030",
    50  		},
    51  		{
    52  			info:       "valid, raw host:port works for https",
    53  			env:        []env{{name: "HTTPS_PROXY", val: "proxy:1234"}},
    54  			proxyAddr:  "proxy:1234",
    55  			targetAddr: "192.168.1.1:3030",
    56  		},
    57  		{
    58  			info:       "valid, correct full url",
    59  			env:        []env{{name: "https_proxy", val: "https://proxy:1234"}},
    60  			proxyAddr:  "proxy:1234",
    61  			targetAddr: "192.168.1.1:3030",
    62  		},
    63  		{
    64  			info:       "valid, http endpoint can be set in https_proxy",
    65  			env:        []env{{name: "https_proxy", val: "http://proxy:1234"}},
    66  			proxyAddr:  "proxy:1234",
    67  			targetAddr: "192.168.1.1:3030",
    68  		},
    69  		{
    70  			info:       "valid, socks5 endpoint can be set in https_proxy",
    71  			env:        []env{{name: "https_proxy", val: "socks5://proxy:1234"}},
    72  			proxyAddr:  "proxy:1234",
    73  			targetAddr: "192.168.1.1:3030",
    74  		},
    75  		{
    76  			info: "valid, http endpoint can be set in https_proxy, but no_proxy override matches domain",
    77  			env: []env{
    78  				{name: "https_proxy", val: "http://proxy:1234"},
    79  				{name: "no_proxy", val: "proxy"}},
    80  			proxyAddr:  "",
    81  			targetAddr: "proxy:1234",
    82  		},
    83  		{
    84  			info: "valid, http endpoint can be set in https_proxy, but no_proxy override matches ip",
    85  			env: []env{
    86  				{name: "https_proxy", val: "http://proxy:1234"},
    87  				{name: "no_proxy", val: "192.168.1.1"}},
    88  			proxyAddr:  "",
    89  			targetAddr: "192.168.1.1:1234",
    90  		},
    91  		{
    92  			info: "valid, http endpoint can be set in https_proxy, but no_proxy override matches subdomain",
    93  			env: []env{
    94  				{name: "https_proxy", val: "http://proxy:1234"},
    95  				{name: "no_proxy", val: ".example.com"}},
    96  			proxyAddr:  "",
    97  			targetAddr: "bla.example.com:1234",
    98  		},
    99  		{
   100  			info: "valid, no_proxy blocks matching port",
   101  			env: []env{
   102  				{name: "https_proxy", val: "proxy:9999"},
   103  				{name: "no_proxy", val: "example.com:1234"},
   104  			},
   105  			proxyAddr:  "",
   106  			targetAddr: "example.com:1234",
   107  		},
   108  		{
   109  			info: "valid, no_proxy matches host but not port",
   110  			env: []env{
   111  				{name: "https_proxy", val: "proxy:9999"},
   112  				{name: "no_proxy", val: "example.com:1234"},
   113  			},
   114  			proxyAddr:  "proxy:9999",
   115  			targetAddr: "example.com:5678",
   116  		},
   117  	}
   118  
   119  	// used to augment test cases with auth credentials
   120  	authTests := []struct {
   121  		info     string
   122  		user     string
   123  		password string
   124  	}{
   125  		{info: "no credentials", user: "", password: ""},
   126  		{info: "plain password", user: "alice", password: "password"},
   127  		{info: "special characters in password", user: "alice", password: " !@#$%^&*()_+-=[]{};:,.<>/?`~\"\\ abc123"},
   128  	}
   129  
   130  	for i, tt := range tests {
   131  		for j, authTest := range authTests {
   132  			t.Run(fmt.Sprintf("%v %v: %v with %v", i, j, tt.info, authTest.info), func(t *testing.T) {
   133  				for _, env := range tt.env {
   134  					switch strings.ToLower(env.name) {
   135  					case "http_proxy", "https_proxy":
   136  						// add auth test credentials into http(s)_proxy env vars
   137  						val, err := buildProxyAddr(env.val, authTest.user, authTest.password)
   138  						require.NoError(t, err)
   139  						t.Setenv(env.name, val)
   140  					case "no_proxy":
   141  						t.Setenv(env.name, env.val)
   142  					}
   143  				}
   144  				p := GetProxyURL(tt.targetAddr)
   145  
   146  				// is a proxy expected?
   147  				if tt.proxyAddr == "" {
   148  					require.Nil(t, p)
   149  					return
   150  				}
   151  				require.NotNil(t, p)
   152  				require.Equal(t, tt.proxyAddr, p.Host)
   153  
   154  				// are auth credentials expected?
   155  				if authTest.user == "" && authTest.password == "" {
   156  					require.Nil(t, p.User)
   157  					return
   158  				}
   159  				require.NotNil(t, p.User)
   160  				require.Equal(t, authTest.user, p.User.Username())
   161  				password, _ := p.User.Password()
   162  				require.Equal(t, authTest.password, password)
   163  			})
   164  		}
   165  	}
   166  }
   167  
   168  func buildProxyAddr(addr, user, pass string) (string, error) {
   169  	if user == "" && pass == "" {
   170  		return addr, nil
   171  	}
   172  	userInfo := url.UserPassword(user, pass)
   173  	if strings.HasPrefix(addr, "http") || strings.HasPrefix(addr, "socks5") {
   174  		u, err := url.Parse(addr)
   175  		if err != nil {
   176  			return "", trace.Wrap(err)
   177  		}
   178  		u.User = userInfo
   179  		return u.String(), nil
   180  	}
   181  	return fmt.Sprintf("%v@%v", userInfo.String(), addr), nil
   182  }
   183  
   184  func TestProxyAwareRoundTripper(t *testing.T) {
   185  	t.Setenv("HTTP_PROXY", "http://localhost:8888")
   186  	transport := &http.Transport{
   187  		TLSClientConfig: &tls.Config{
   188  			InsecureSkipVerify: true,
   189  		},
   190  		Proxy: func(req *http.Request) (*url.URL, error) {
   191  			return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
   192  		},
   193  	}
   194  	rt := NewHTTPRoundTripper(transport, nil)
   195  	req, err := http.NewRequest(http.MethodGet, "https://localhost:9999", nil)
   196  	require.NoError(t, err)
   197  	// Don't care about response, only if the scheme changed.
   198  	//nolint:bodyclose // resp should be nil, so there will be no body to close.
   199  	_, err = rt.RoundTrip(req)
   200  	require.Error(t, err)
   201  	require.Equal(t, "http", req.URL.Scheme)
   202  }
   203  
   204  // TestHttpRoundTripperDowngrade tests that the round tripper downgrades https requests to http
   205  // when HTTP_PROXY is set to "http://localhost:*" (i.e. there's an http proxy running on localhost).
   206  func TestHttpRoundTripperDowngrade(t *testing.T) {
   207  	testCases := []struct {
   208  		desc           string
   209  		setHTTPProxy   bool
   210  		shouldHitProxy bool
   211  	}{
   212  		{
   213  			desc:           "hits http proxy if insecure and localhost http proxy is set",
   214  			setHTTPProxy:   true,
   215  			shouldHitProxy: true,
   216  		},
   217  		{
   218  			desc:           "does not hit http proxy if insecure and localhost http proxy is not set",
   219  			setHTTPProxy:   false,
   220  			shouldHitProxy: false,
   221  		},
   222  	}
   223  
   224  	for _, tc := range testCases {
   225  		t.Run(tc.desc, func(t *testing.T) {
   226  			newHandler := func(runningAtProxy bool, wasHit *bool) http.HandlerFunc {
   227  				return func(w http.ResponseWriter, r *http.Request) {
   228  					*wasHit = true
   229  					if tc.shouldHitProxy {
   230  						// If the request should hit the proxy, then:
   231  						// - this handler is running at the proxy, and
   232  						// - the scheme should be http.
   233  						require.True(t, runningAtProxy)
   234  						require.Equal(t, "http", r.URL.Scheme)
   235  					}
   236  					w.WriteHeader(http.StatusOK)
   237  				}
   238  			}
   239  
   240  			// Start localhost http proxy.
   241  			runningAtProxy := true
   242  			loopback := true
   243  			https := false
   244  			httpProxyWasHit := false
   245  			httpProxy, err := newServer(newHandler(runningAtProxy, &httpProxyWasHit), loopback, https)
   246  			require.NoError(t, err)
   247  			defer httpProxy.Close()
   248  
   249  			// Start non-localhost https server.
   250  			runningAtProxy = false
   251  			loopback = false
   252  			https = true
   253  			httpsSrvWasHit := false
   254  			httpsSrv, err := newServer(newHandler(runningAtProxy, &httpsSrvWasHit), loopback, https)
   255  			require.NoError(t, err)
   256  			defer httpsSrv.Close()
   257  
   258  			if tc.setHTTPProxy {
   259  				// url.Parse won't correctly parse an absolute URL without a scheme.
   260  				u, err := url.Parse("http://" + httpProxy.Listener.Addr().String())
   261  				require.NoError(t, err)
   262  				_, port, err := net.SplitHostPort(u.Host)
   263  				require.NoError(t, err)
   264  
   265  				// Set HTTP_PROXY to "http://localhost:*".
   266  				t.Setenv("HTTP_PROXY", fmt.Sprintf("http://localhost:%s", port))
   267  			}
   268  
   269  			clt := newClient(t, nil)
   270  
   271  			// Perform any request.
   272  			// Set addr to the https server. If HTTP_PROXY was set above,
   273  			// the http proxy should be hit regardless.
   274  			addr := httpsSrv.Listener.Addr().String()
   275  			request(t, clt, addr)
   276  
   277  			// Validate that the correct server was hit.
   278  			require.Equal(t, tc.shouldHitProxy, httpProxyWasHit)
   279  			require.Equal(t, !tc.shouldHitProxy, httpsSrvWasHit)
   280  		})
   281  	}
   282  }
   283  
   284  // TestHttpRoundTripperExtraHeaders tests that the round tripper adds the extra headers set.
   285  func TestHttpRoundTripperExtraHeaders(t *testing.T) {
   286  	testCases := []struct {
   287  		desc          string
   288  		extraHeaders  map[string]string
   289  		expectHeaders func(*testing.T, http.Header)
   290  	}{
   291  		{
   292  			desc: "extra headers are added",
   293  			extraHeaders: map[string]string{
   294  				"header1": "value1",
   295  				"header2": "value2",
   296  			},
   297  			expectHeaders: func(t *testing.T, headers http.Header) {
   298  				require.Equal(t, []string{"value1"}, headers.Values("header1"))
   299  				require.Equal(t, []string{"value2"}, headers.Values("header2"))
   300  			},
   301  		},
   302  		{
   303  			desc: "extra headers do not overwrite existing headers",
   304  			extraHeaders: map[string]string{
   305  				"header1":      "value1",
   306  				"Content-Type": "value2",
   307  			},
   308  			expectHeaders: func(t *testing.T, headers http.Header) {
   309  				require.Equal(t, []string{"value1"}, headers.Values("header1"))
   310  				require.Equal(t, []string{"application/json", "value2"}, headers.Values("Content-Type"))
   311  			},
   312  		},
   313  	}
   314  
   315  	for _, tc := range testCases {
   316  		t.Run(tc.desc, func(t *testing.T) {
   317  			var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
   318  				tc.expectHeaders(t, r.Header)
   319  				w.WriteHeader(http.StatusOK)
   320  			}
   321  
   322  			// Start localhost https server.
   323  			loopback := true
   324  			tls := true
   325  			httpsSrv, err := newServer(handler, loopback, tls)
   326  			require.NoError(t, err)
   327  			defer httpsSrv.Close()
   328  
   329  			clt := newClient(t, tc.extraHeaders)
   330  
   331  			// Perform any request.
   332  			// Set the address to the localhost https server.
   333  			addr := httpsSrv.Listener.Addr().String()
   334  			request(t, clt, addr)
   335  		})
   336  	}
   337  }
   338  
   339  // newServer starts a new server that:
   340  // - runs TLS if `https`
   341  // - uses a loopback listener if `loopback`
   342  func newServer(handler http.HandlerFunc, loopback bool, https bool) (*httptest.Server, error) {
   343  	srv := httptest.NewUnstartedServer(handler)
   344  
   345  	if !loopback {
   346  		// Replace the test-supplied loopback listener with the first available
   347  		// non-loopback address.
   348  		srv.Listener.Close()
   349  		l, err := net.Listen("tcp", "0.0.0.0:0")
   350  		if err != nil {
   351  			return nil, err
   352  		}
   353  		srv.Listener = l
   354  	}
   355  
   356  	if https {
   357  		srv.StartTLS()
   358  	} else {
   359  		srv.Start()
   360  	}
   361  	return srv, nil
   362  }
   363  
   364  // newClient creates a new https roundtrip client.
   365  func newClient(t *testing.T, extraHeaders map[string]string) *http.Client {
   366  	transport := &http.Transport{
   367  		TLSClientConfig: &tls.Config{
   368  			// Setting insecure ensures that https requests succeed.
   369  			InsecureSkipVerify: true,
   370  		},
   371  		Proxy: func(req *http.Request) (*url.URL, error) {
   372  			return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
   373  		},
   374  	}
   375  	return &http.Client{
   376  		Transport: NewHTTPRoundTripper(transport, extraHeaders),
   377  	}
   378  }
   379  
   380  // request perform a POST request.
   381  func request(t *testing.T, clt *http.Client, addr string) {
   382  	url := "https://" + addr + "/v1/content"
   383  	resp, err := clt.Post(url, "application/json", nil)
   384  	require.NoError(t, err)
   385  	require.NoError(t, resp.Body.Close())
   386  }
   387  
   388  func TestParse(t *testing.T) {
   389  	successTests := []struct {
   390  		name, addr, scheme, host, path string
   391  	}{
   392  		{name: "scheme-host-port", addr: "http://example.com:8080", scheme: "http", host: "example.com:8080", path: ""},
   393  		{name: "host-port", addr: "example.com:8080", scheme: "", host: "example.com:8080", path: ""},
   394  		{name: "scheme-ip4-port", addr: "http://127.0.0.1:8080", scheme: "http", host: "127.0.0.1:8080", path: ""},
   395  		{name: "ip4-port", addr: "127.0.0.1:8080", scheme: "", host: "127.0.0.1:8080", path: ""},
   396  		{name: "scheme-ip6-port", addr: "http://[::1]:8080", scheme: "http", host: "[::1]:8080", path: ""},
   397  		{name: "ip6-port", addr: "[::1]:8080", scheme: "", host: "[::1]:8080"},
   398  		{name: "host/path", addr: "example.com/path/to/somewhere", scheme: "", host: "example.com", path: "/path/to/somewhere"},
   399  	}
   400  	for _, tc := range successTests {
   401  		t.Run(fmt.Sprintf("should parse: %s", tc.name), func(t *testing.T) {
   402  			u, err := ParseURL(tc.addr)
   403  			require.NoError(t, err)
   404  			errMsg := fmt.Sprintf("(%v, %v, %v)", u.Scheme, u.Host, u.Path)
   405  			require.Equal(t, tc.scheme, u.Scheme, errMsg)
   406  			require.Equal(t, tc.host, u.Host, errMsg)
   407  			require.Equal(t, tc.path, u.Path)
   408  		})
   409  	}
   410  
   411  	failTests := []struct {
   412  		name, addr string
   413  	}{
   414  		{name: "invalid char in host without scheme", addr: "bad addr"},
   415  	}
   416  	for _, tc := range failTests {
   417  		t.Run(fmt.Sprintf("should not parse: %s", tc.name), func(t *testing.T) {
   418  			u, err := ParseURL(tc.addr)
   419  			require.Error(t, err, u)
   420  		})
   421  	}
   422  
   423  	t.Run("empty addr", func(t *testing.T) {
   424  		u, err := ParseURL("")
   425  		require.NoError(t, err)
   426  		require.Nil(t, u)
   427  	})
   428  }