github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/testhelpers/proxy.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package testhelpers 16 17 import ( 18 "io" 19 "net" 20 "net/http" 21 "net/http/httptest" 22 "sync" 23 "testing" 24 25 "github.com/gravitational/trace" 26 "github.com/stretchr/testify/require" 27 ) 28 29 // ProxyHandler is a http.Handler that implements a simple HTTP proxy server. 30 type ProxyHandler struct { 31 sync.Mutex 32 count int 33 } 34 35 // ServeHTTP only accepts the CONNECT verb and will tunnel your connection to 36 // the specified host. Also tracks the number of connections that it proxies for 37 // debugging purposes. 38 func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 39 // Validate http connect parameters. 40 if r.Method != http.MethodConnect { 41 trace.WriteError(w, trace.BadParameter("%v not supported", r.Method)) 42 return 43 } 44 if r.Host == "" { 45 trace.WriteError(w, trace.BadParameter("host not set")) 46 return 47 } 48 49 // Dial to the target host, this is done before hijacking the connection to 50 // ensure the target host is accessible. 51 dialer := net.Dialer{} 52 dconn, err := dialer.DialContext(r.Context(), "tcp", r.Host) 53 if err != nil { 54 trace.WriteError(w, err) 55 return 56 } 57 defer dconn.Close() 58 59 // Once the client receives 200 OK, the rest of the data will no longer be 60 // http, but whatever protocol is being tunneled. 61 w.WriteHeader(http.StatusOK) 62 63 // Hijack request so we can get underlying connection. 64 hj, ok := w.(http.Hijacker) 65 if !ok { 66 trace.WriteError(w, trace.AccessDenied("unable to hijack connection")) 67 return 68 } 69 sconn, buf, err := hj.Hijack() 70 if err != nil { 71 trace.WriteError(w, err) 72 return 73 } 74 defer sconn.Close() 75 76 // Success, we're proxying data now. 77 p.Lock() 78 p.count++ 79 p.Unlock() 80 81 // Copy from src to dst and dst to src. 82 errc := make(chan error, 2) 83 replicate := func(dst io.Writer, src io.Reader) { 84 _, err := io.Copy(dst, src) 85 errc <- err 86 } 87 go replicate(sconn, dconn) 88 go replicate(dconn, io.MultiReader(buf, sconn)) 89 90 // Wait until done. 91 select { 92 case <-r.Context().Done(): 93 case <-errc: 94 } 95 } 96 97 // Count returns the number of requests that have been proxied. 98 func (p *ProxyHandler) Count() int { 99 p.Lock() 100 defer p.Unlock() 101 return p.count 102 } 103 104 // Reset sets the counter for proxied requests to zero. 105 func (p *ProxyHandler) Reset() { 106 p.Lock() 107 defer p.Unlock() 108 p.count = 0 109 } 110 111 // GetLocalIP gets the non-loopback IP address of this host. 112 func GetLocalIP() (string, error) { 113 addrs, err := net.InterfaceAddrs() 114 if err != nil { 115 return "", trace.Wrap(err) 116 } 117 for _, addr := range addrs { 118 var ip net.IP 119 switch v := addr.(type) { 120 case *net.IPNet: 121 ip = v.IP 122 case *net.IPAddr: 123 ip = v.IP 124 default: 125 continue 126 } 127 if !ip.IsLoopback() && ip.IsPrivate() { 128 return ip.String(), nil 129 } 130 } 131 return "", trace.NotFound("No non-loopback local IP address found") 132 } 133 134 type TestServerOption func(*testing.T, *httptest.Server) 135 136 func WithTestServerAddress(ip string) TestServerOption { 137 return func(t *testing.T, srv *httptest.Server) { 138 // Replace the test server's address. 139 _, originalPort, err := net.SplitHostPort(srv.Listener.Addr().String()) 140 require.NoError(t, err) 141 require.NoError(t, srv.Listener.Close()) 142 l, err := net.Listen("tcp", net.JoinHostPort(ip, originalPort)) 143 require.NoError(t, err) 144 srv.Listener = l 145 } 146 } 147 148 func MakeTestServer(t *testing.T, h http.Handler, opts ...TestServerOption) *httptest.Server { 149 svr := httptest.NewUnstartedServer(h) 150 for _, opt := range opts { 151 opt(t, svr) 152 } 153 svr.StartTLS() 154 t.Cleanup(svr.Close) 155 return svr 156 }