github.com/rudderlabs/rudder-go-kit@v0.30.0/tcpproxy/tcpproxy.go (about)

     1  package tcpproxy
     2  
import (
	"io"
	"net"
	"os"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)
    13  
    14  type Proxy struct {
    15  	LocalAddr     string
    16  	RemoteAddr    string
    17  	BytesSent     atomic.Int64
    18  	BytesReceived atomic.Int64
    19  	Verbose       bool
    20  
    21  	wg   sync.WaitGroup
    22  	stop chan struct{}
    23  }
    24  
    25  func (p *Proxy) Start(t testing.TB) {
    26  	p.stop = make(chan struct{})
    27  
    28  	p.wg.Add(1)
    29  	defer p.wg.Done()
    30  
    31  	listener, err := net.Listen("tcp", p.LocalAddr)
    32  	require.NoError(t, err)
    33  
    34  	p.wg.Add(1)
    35  	go func() {
    36  		<-p.stop
    37  		_ = listener.Close()
    38  		p.wg.Done()
    39  	}()
    40  
    41  	for {
    42  		select {
    43  		case <-p.stop:
    44  			return
    45  
    46  		default:
    47  			connRcv, err := listener.Accept()
    48  			if err != nil {
    49  				continue // error accepting connection
    50  			}
    51  
    52  			p.wg.Add(1)
    53  			go func() {
    54  				defer p.wg.Done()
    55  				defer func() { _ = connRcv.Close() }()
    56  
    57  				connSend, err := net.Dial("tcp", p.RemoteAddr)
    58  				if err != nil {
    59  					t.Logf("Cannot dial remote: %v", err)
    60  					return // cannot dial remote, return and listen for new connections
    61  				}
    62  
    63  				defer func() { _ = connSend.Close() }()
    64  
    65  				p.wg.Add(2)
    66  				done := make(chan struct{}, 2)
    67  				go p.pipe(connRcv, connSend, &p.BytesReceived, done)
    68  				go p.pipe(connSend, connRcv, &p.BytesSent, done)
    69  				select {
    70  				case <-done: // one of the connections got terminated
    71  				case <-p.stop: // TCP proxy stopped
    72  				}
    73  			}()
    74  		}
    75  	}
    76  }
    77  
    78  func (p *Proxy) Stop() {
    79  	close(p.stop)
    80  	p.wg.Wait()
    81  }
    82  
    83  func (p *Proxy) pipe(src io.Reader, dst io.Writer, bytesMetric *atomic.Int64, done chan struct{}) {
    84  	defer p.wg.Done()
    85  
    86  	wrt, rdr := dst, src
    87  	if p.Verbose {
    88  		wrt = os.Stdout
    89  		rdr = io.TeeReader(src, dst)
    90  	}
    91  	n, _ := io.Copy(wrt, rdr) // this is a blocking call, it terminates when the connection is closed
    92  	bytesMetric.Add(n)
    93  
    94  	done <- struct{}{} // connection is closed, send signal to stop proxy
    95  }