github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/testhelpers/proxy.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testhelpers
    16  
    17  import (
    18  	"io"
    19  	"net"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"sync"
    23  	"testing"
    24  
    25  	"github.com/gravitational/trace"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  // ProxyHandler is a http.Handler that implements a simple HTTP proxy server.
    30  type ProxyHandler struct {
    31  	sync.Mutex
    32  	count int
    33  }
    34  
    35  // ServeHTTP only accepts the CONNECT verb and will tunnel your connection to
    36  // the specified host. Also tracks the number of connections that it proxies for
    37  // debugging purposes.
    38  func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    39  	// Validate http connect parameters.
    40  	if r.Method != http.MethodConnect {
    41  		trace.WriteError(w, trace.BadParameter("%v not supported", r.Method))
    42  		return
    43  	}
    44  	if r.Host == "" {
    45  		trace.WriteError(w, trace.BadParameter("host not set"))
    46  		return
    47  	}
    48  
    49  	// Dial to the target host, this is done before hijacking the connection to
    50  	// ensure the target host is accessible.
    51  	dialer := net.Dialer{}
    52  	dconn, err := dialer.DialContext(r.Context(), "tcp", r.Host)
    53  	if err != nil {
    54  		trace.WriteError(w, err)
    55  		return
    56  	}
    57  	defer dconn.Close()
    58  
    59  	// Once the client receives 200 OK, the rest of the data will no longer be
    60  	// http, but whatever protocol is being tunneled.
    61  	w.WriteHeader(http.StatusOK)
    62  
    63  	// Hijack request so we can get underlying connection.
    64  	hj, ok := w.(http.Hijacker)
    65  	if !ok {
    66  		trace.WriteError(w, trace.AccessDenied("unable to hijack connection"))
    67  		return
    68  	}
    69  	sconn, buf, err := hj.Hijack()
    70  	if err != nil {
    71  		trace.WriteError(w, err)
    72  		return
    73  	}
    74  	defer sconn.Close()
    75  
    76  	// Success, we're proxying data now.
    77  	p.Lock()
    78  	p.count++
    79  	p.Unlock()
    80  
    81  	// Copy from src to dst and dst to src.
    82  	errc := make(chan error, 2)
    83  	replicate := func(dst io.Writer, src io.Reader) {
    84  		_, err := io.Copy(dst, src)
    85  		errc <- err
    86  	}
    87  	go replicate(sconn, dconn)
    88  	go replicate(dconn, io.MultiReader(buf, sconn))
    89  
    90  	// Wait until done.
    91  	select {
    92  	case <-r.Context().Done():
    93  	case <-errc:
    94  	}
    95  }
    96  
    97  // Count returns the number of requests that have been proxied.
    98  func (p *ProxyHandler) Count() int {
    99  	p.Lock()
   100  	defer p.Unlock()
   101  	return p.count
   102  }
   103  
   104  // Reset sets the counter for proxied requests to zero.
   105  func (p *ProxyHandler) Reset() {
   106  	p.Lock()
   107  	defer p.Unlock()
   108  	p.count = 0
   109  }
   110  
   111  // GetLocalIP gets the non-loopback IP address of this host.
   112  func GetLocalIP() (string, error) {
   113  	addrs, err := net.InterfaceAddrs()
   114  	if err != nil {
   115  		return "", trace.Wrap(err)
   116  	}
   117  	for _, addr := range addrs {
   118  		var ip net.IP
   119  		switch v := addr.(type) {
   120  		case *net.IPNet:
   121  			ip = v.IP
   122  		case *net.IPAddr:
   123  			ip = v.IP
   124  		default:
   125  			continue
   126  		}
   127  		if !ip.IsLoopback() && ip.IsPrivate() {
   128  			return ip.String(), nil
   129  		}
   130  	}
   131  	return "", trace.NotFound("No non-loopback local IP address found")
   132  }
   133  
   134  type TestServerOption func(*testing.T, *httptest.Server)
   135  
   136  func WithTestServerAddress(ip string) TestServerOption {
   137  	return func(t *testing.T, srv *httptest.Server) {
   138  		// Replace the test server's address.
   139  		_, originalPort, err := net.SplitHostPort(srv.Listener.Addr().String())
   140  		require.NoError(t, err)
   141  		require.NoError(t, srv.Listener.Close())
   142  		l, err := net.Listen("tcp", net.JoinHostPort(ip, originalPort))
   143  		require.NoError(t, err)
   144  		srv.Listener = l
   145  	}
   146  }
   147  
   148  func MakeTestServer(t *testing.T, h http.Handler, opts ...TestServerOption) *httptest.Server {
   149  	svr := httptest.NewUnstartedServer(h)
   150  	for _, opt := range opts {
   151  		opt(t, svr)
   152  	}
   153  	svr.StartTLS()
   154  	t.Cleanup(svr.Close)
   155  	return svr
   156  }