github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/proxy_test.go (about) 1 /* 2 Copyright 2017 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package utils 18 19 import ( 20 "crypto/tls" 21 "fmt" 22 "net" 23 "net/http" 24 "net/http/httptest" 25 "net/url" 26 "strings" 27 "testing" 28 29 "github.com/gravitational/trace" 30 "github.com/stretchr/testify/require" 31 "golang.org/x/net/http/httpproxy" 32 ) 33 34 func TestGetProxyAddress(t *testing.T) { 35 type env struct { 36 name string 37 val string 38 } 39 var tests = []struct { 40 info string 41 env []env 42 targetAddr string 43 proxyAddr string 44 }{ 45 { 46 info: "valid, can be raw host:port", 47 env: []env{{name: "http_proxy", val: "proxy:1234"}}, 48 proxyAddr: "proxy:1234", 49 targetAddr: "192.168.1.1:3030", 50 }, 51 { 52 info: "valid, raw host:port works for https", 53 env: []env{{name: "HTTPS_PROXY", val: "proxy:1234"}}, 54 proxyAddr: "proxy:1234", 55 targetAddr: "192.168.1.1:3030", 56 }, 57 { 58 info: "valid, correct full url", 59 env: []env{{name: "https_proxy", val: "https://proxy:1234"}}, 60 proxyAddr: "proxy:1234", 61 targetAddr: "192.168.1.1:3030", 62 }, 63 { 64 info: "valid, http endpoint can be set in https_proxy", 65 env: []env{{name: "https_proxy", val: "http://proxy:1234"}}, 66 proxyAddr: "proxy:1234", 67 targetAddr: "192.168.1.1:3030", 68 }, 69 { 70 info: "valid, socks5 endpoint can be set in https_proxy", 71 env: []env{{name: "https_proxy", val: "socks5://proxy:1234"}}, 72 proxyAddr: "proxy:1234", 73 targetAddr: "192.168.1.1:3030", 74 }, 75 { 76 info: "valid, http endpoint can be set in https_proxy, but no_proxy override matches domain", 77 env: []env{ 78 {name: "https_proxy", val: "http://proxy:1234"}, 79 {name: "no_proxy", val: "proxy"}}, 80 proxyAddr: "", 81 targetAddr: "proxy:1234", 82 }, 83 { 84 info: "valid, http endpoint can be set in https_proxy, but no_proxy override matches ip", 85 env: []env{ 86 {name: "https_proxy", val: "http://proxy:1234"}, 87 {name: "no_proxy", val: "192.168.1.1"}}, 88 proxyAddr: "", 89 targetAddr: "192.168.1.1:1234", 90 }, 91 { 92 info: "valid, http endpoint can be set in https_proxy, but no_proxy override matches subdomain", 93 env: []env{ 94 {name: "https_proxy", val: "http://proxy:1234"}, 95 {name: "no_proxy", val: ".example.com"}}, 96 proxyAddr: "", 97 targetAddr: "bla.example.com:1234", 98 }, 99 { 100 info: "valid, no_proxy blocks matching port", 101 env: []env{ 102 {name: "https_proxy", val: "proxy:9999"}, 103 {name: "no_proxy", val: "example.com:1234"}, 104 }, 105 proxyAddr: "", 106 targetAddr: "example.com:1234", 107 }, 108 { 109 info: "valid, no_proxy matches host but not port", 110 env: []env{ 111 {name: "https_proxy", val: "proxy:9999"}, 112 {name: "no_proxy", val: "example.com:1234"}, 113 }, 114 proxyAddr: "proxy:9999", 115 targetAddr: "example.com:5678", 116 }, 117 } 118 119 // used to augment test cases with auth credentials 120 authTests := []struct { 121 info string 122 user string 123 password string 124 }{ 125 {info: "no credentials", user: "", password: ""}, 126 {info: "plain password", user: "alice", password: "password"}, 127 {info: "special characters in password", user: "alice", password: " !@#$%^&*()_+-=[]{};:,.<>/?`~\"\\ abc123"}, 128 } 129 130 for i, tt := range tests { 131 for j, authTest := range authTests { 132 t.Run(fmt.Sprintf("%v %v: %v with %v", i, j, tt.info, authTest.info), func(t *testing.T) { 133 for _, env := range tt.env { 134 switch strings.ToLower(env.name) { 135 case "http_proxy", "https_proxy": 136 // add auth test credentials into http(s)_proxy env vars 137 val, err := buildProxyAddr(env.val, authTest.user, authTest.password) 138 require.NoError(t, err) 139 t.Setenv(env.name, val) 140 case "no_proxy": 141 t.Setenv(env.name, env.val) 142 } 143 } 144 p := GetProxyURL(tt.targetAddr) 145 146 // is a proxy expected? 147 if tt.proxyAddr == "" { 148 require.Nil(t, p) 149 return 150 } 151 require.NotNil(t, p) 152 require.Equal(t, tt.proxyAddr, p.Host) 153 154 // are auth credentials expected? 155 if authTest.user == "" && authTest.password == "" { 156 require.Nil(t, p.User) 157 return 158 } 159 require.NotNil(t, p.User) 160 require.Equal(t, authTest.user, p.User.Username()) 161 password, _ := p.User.Password() 162 require.Equal(t, authTest.password, password) 163 }) 164 } 165 } 166 } 167 168 func buildProxyAddr(addr, user, pass string) (string, error) { 169 if user == "" && pass == "" { 170 return addr, nil 171 } 172 userInfo := url.UserPassword(user, pass) 173 if strings.HasPrefix(addr, "http") || strings.HasPrefix(addr, "socks5") { 174 u, err := url.Parse(addr) 175 if err != nil { 176 return "", trace.Wrap(err) 177 } 178 u.User = userInfo 179 return u.String(), nil 180 } 181 return fmt.Sprintf("%v@%v", userInfo.String(), addr), nil 182 } 183 184 func TestProxyAwareRoundTripper(t *testing.T) { 185 t.Setenv("HTTP_PROXY", "http://localhost:8888") 186 transport := &http.Transport{ 187 TLSClientConfig: &tls.Config{ 188 InsecureSkipVerify: true, 189 }, 190 Proxy: func(req *http.Request) (*url.URL, error) { 191 return httpproxy.FromEnvironment().ProxyFunc()(req.URL) 192 }, 193 } 194 rt := NewHTTPRoundTripper(transport, nil) 195 req, err := http.NewRequest(http.MethodGet, "https://localhost:9999", nil) 196 require.NoError(t, err) 197 // Don't care about response, only if the scheme changed. 198 //nolint:bodyclose // resp should be nil, so there will be no body to close. 199 _, err = rt.RoundTrip(req) 200 require.Error(t, err) 201 require.Equal(t, "http", req.URL.Scheme) 202 } 203 204 // TestHttpRoundTripperDowngrade tests that the round tripper downgrades https requests to http 205 // when HTTP_PROXY is set to "http://localhost:*" (i.e. there's an http proxy running on localhost). 206 func TestHttpRoundTripperDowngrade(t *testing.T) { 207 testCases := []struct { 208 desc string 209 setHTTPProxy bool 210 shouldHitProxy bool 211 }{ 212 { 213 desc: "hits http proxy if insecure and localhost http proxy is set", 214 setHTTPProxy: true, 215 shouldHitProxy: true, 216 }, 217 { 218 desc: "does not hit http proxy if insecure and localhost http proxy is not set", 219 setHTTPProxy: false, 220 shouldHitProxy: false, 221 }, 222 } 223 224 for _, tc := range testCases { 225 t.Run(tc.desc, func(t *testing.T) { 226 newHandler := func(runningAtProxy bool, wasHit *bool) http.HandlerFunc { 227 return func(w http.ResponseWriter, r *http.Request) { 228 *wasHit = true 229 if tc.shouldHitProxy { 230 // If the request should hit the proxy, then: 231 // - this handler is running at the proxy, and 232 // - the scheme should be http. 233 require.True(t, runningAtProxy) 234 require.Equal(t, "http", r.URL.Scheme) 235 } 236 w.WriteHeader(http.StatusOK) 237 } 238 } 239 240 // Start localhost http proxy. 241 runningAtProxy := true 242 loopback := true 243 https := false 244 httpProxyWasHit := false 245 httpProxy, err := newServer(newHandler(runningAtProxy, &httpProxyWasHit), loopback, https) 246 require.NoError(t, err) 247 defer httpProxy.Close() 248 249 // Start non-localhost https server. 250 runningAtProxy = false 251 loopback = false 252 https = true 253 httpsSrvWasHit := false 254 httpsSrv, err := newServer(newHandler(runningAtProxy, &httpsSrvWasHit), loopback, https) 255 require.NoError(t, err) 256 defer httpsSrv.Close() 257 258 if tc.setHTTPProxy { 259 // url.Parse won't correctly parse an absolute URL without a scheme. 260 u, err := url.Parse("http://" + httpProxy.Listener.Addr().String()) 261 require.NoError(t, err) 262 _, port, err := net.SplitHostPort(u.Host) 263 require.NoError(t, err) 264 265 // Set HTTP_PROXY to "http://localhost:*". 266 t.Setenv("HTTP_PROXY", fmt.Sprintf("http://localhost:%s", port)) 267 } 268 269 clt := newClient(t, nil) 270 271 // Perform any request. 272 // Set addr to the https server. If HTTP_PROXY was set above, 273 // the http proxy should be hit regardless. 274 addr := httpsSrv.Listener.Addr().String() 275 request(t, clt, addr) 276 277 // Validate that the correct server was hit. 278 require.Equal(t, tc.shouldHitProxy, httpProxyWasHit) 279 require.Equal(t, !tc.shouldHitProxy, httpsSrvWasHit) 280 }) 281 } 282 } 283 284 // TestHttpRoundTripperExtraHeaders tests that the round tripper adds the extra headers set. 285 func TestHttpRoundTripperExtraHeaders(t *testing.T) { 286 testCases := []struct { 287 desc string 288 extraHeaders map[string]string 289 expectHeaders func(*testing.T, http.Header) 290 }{ 291 { 292 desc: "extra headers are added", 293 extraHeaders: map[string]string{ 294 "header1": "value1", 295 "header2": "value2", 296 }, 297 expectHeaders: func(t *testing.T, headers http.Header) { 298 require.Equal(t, []string{"value1"}, headers.Values("header1")) 299 require.Equal(t, []string{"value2"}, headers.Values("header2")) 300 }, 301 }, 302 { 303 desc: "extra headers do not overwrite existing headers", 304 extraHeaders: map[string]string{ 305 "header1": "value1", 306 "Content-Type": "value2", 307 }, 308 expectHeaders: func(t *testing.T, headers http.Header) { 309 require.Equal(t, []string{"value1"}, headers.Values("header1")) 310 require.Equal(t, []string{"application/json", "value2"}, headers.Values("Content-Type")) 311 }, 312 }, 313 } 314 315 for _, tc := range testCases { 316 t.Run(tc.desc, func(t *testing.T) { 317 var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { 318 tc.expectHeaders(t, r.Header) 319 w.WriteHeader(http.StatusOK) 320 } 321 322 // Start localhost https server. 323 loopback := true 324 tls := true 325 httpsSrv, err := newServer(handler, loopback, tls) 326 require.NoError(t, err) 327 defer httpsSrv.Close() 328 329 clt := newClient(t, tc.extraHeaders) 330 331 // Perform any request. 332 // Set the address to the localhost https server. 333 addr := httpsSrv.Listener.Addr().String() 334 request(t, clt, addr) 335 }) 336 } 337 } 338 339 // newServer starts a new server that: 340 // - runs TLS if `https` 341 // - uses a loopback listener if `loopback` 342 func newServer(handler http.HandlerFunc, loopback bool, https bool) (*httptest.Server, error) { 343 srv := httptest.NewUnstartedServer(handler) 344 345 if !loopback { 346 // Replace the test-supplied loopback listener with the first available 347 // non-loopback address. 348 srv.Listener.Close() 349 l, err := net.Listen("tcp", "0.0.0.0:0") 350 if err != nil { 351 return nil, err 352 } 353 srv.Listener = l 354 } 355 356 if https { 357 srv.StartTLS() 358 } else { 359 srv.Start() 360 } 361 return srv, nil 362 } 363 364 // newClient creates a new https roundtrip client. 365 func newClient(t *testing.T, extraHeaders map[string]string) *http.Client { 366 transport := &http.Transport{ 367 TLSClientConfig: &tls.Config{ 368 // Setting insecure ensures that https requests succeed. 369 InsecureSkipVerify: true, 370 }, 371 Proxy: func(req *http.Request) (*url.URL, error) { 372 return httpproxy.FromEnvironment().ProxyFunc()(req.URL) 373 }, 374 } 375 return &http.Client{ 376 Transport: NewHTTPRoundTripper(transport, extraHeaders), 377 } 378 } 379 380 // request perform a POST request. 381 func request(t *testing.T, clt *http.Client, addr string) { 382 url := "https://" + addr + "/v1/content" 383 resp, err := clt.Post(url, "application/json", nil) 384 require.NoError(t, err) 385 require.NoError(t, resp.Body.Close()) 386 } 387 388 func TestParse(t *testing.T) { 389 successTests := []struct { 390 name, addr, scheme, host, path string 391 }{ 392 {name: "scheme-host-port", addr: "http://example.com:8080", scheme: "http", host: "example.com:8080", path: ""}, 393 {name: "host-port", addr: "example.com:8080", scheme: "", host: "example.com:8080", path: ""}, 394 {name: "scheme-ip4-port", addr: "http://127.0.0.1:8080", scheme: "http", host: "127.0.0.1:8080", path: ""}, 395 {name: "ip4-port", addr: "127.0.0.1:8080", scheme: "", host: "127.0.0.1:8080", path: ""}, 396 {name: "scheme-ip6-port", addr: "http://[::1]:8080", scheme: "http", host: "[::1]:8080", path: ""}, 397 {name: "ip6-port", addr: "[::1]:8080", scheme: "", host: "[::1]:8080"}, 398 {name: "host/path", addr: "example.com/path/to/somewhere", scheme: "", host: "example.com", path: "/path/to/somewhere"}, 399 } 400 for _, tc := range successTests { 401 t.Run(fmt.Sprintf("should parse: %s", tc.name), func(t *testing.T) { 402 u, err := ParseURL(tc.addr) 403 require.NoError(t, err) 404 errMsg := fmt.Sprintf("(%v, %v, %v)", u.Scheme, u.Host, u.Path) 405 require.Equal(t, tc.scheme, u.Scheme, errMsg) 406 require.Equal(t, tc.host, u.Host, errMsg) 407 require.Equal(t, tc.path, u.Path) 408 }) 409 } 410 411 failTests := []struct { 412 name, addr string 413 }{ 414 {name: "invalid char in host without scheme", addr: "bad addr"}, 415 } 416 for _, tc := range failTests { 417 t.Run(fmt.Sprintf("should not parse: %s", tc.name), func(t *testing.T) { 418 u, err := ParseURL(tc.addr) 419 require.Error(t, err, u) 420 }) 421 } 422 423 t.Run("empty addr", func(t *testing.T) { 424 u, err := ParseURL("") 425 require.NoError(t, err) 426 require.Nil(t, u) 427 }) 428 }