github.com/mailgun/holster/v4@v4.20.0/udp/proxy.go (about)

     1  package udp
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/sirupsen/logrus"
    11  )
    12  
    13  type Proxy struct {
    14  	conf     ProxyConfig
    15  	listen   *Server
    16  	upstream *Client
    17  	mutex    sync.Mutex // ensure the handler isn't still executing when start()/stop() are called
    18  	// A list of addresses that we deny udp proxy for
    19  	blocked []string
    20  }
    21  
// ProxyConfig holds the settings a Proxy is created with. Listen and Upstream
// are required: Start returns an error if either one is empty.
type ProxyConfig struct {
	// The address:port we should bind to
	Listen string
	// The address:port of the upstream udp server we should forward requests to
	Upstream string
	// How long we should wait for a response from the upstream udp server
	// after forwarding a packet to it
	UpstreamTimeout time.Duration
}
    30  
    31  func NewProxy(conf ProxyConfig) *Proxy {
    32  	return &Proxy{
    33  		conf: conf,
    34  	}
    35  }
    36  
    37  func (p *Proxy) Start() error {
    38  	p.mutex.Lock()
    39  	defer p.mutex.Unlock()
    40  
    41  	if p.conf.Listen == "" {
    42  		return errors.New("variable Listen cannot be empty")
    43  	}
    44  
    45  	if p.conf.Upstream == "" {
    46  		return errors.New("variable Upstream cannot be empty")
    47  	}
    48  
    49  	var err error
    50  	p.upstream, err = NewClient(p.conf.Upstream)
    51  	if err != nil {
    52  		return fmt.Errorf("while dialing upstream '%s' - %w", p.conf.Upstream, err)
    53  	}
    54  
    55  	p.listen, err = NewServer(ServerConfig{
    56  		BindAddress: p.conf.Listen,
    57  		Handler:     p.handler,
    58  	})
    59  	if err != nil {
    60  		return fmt.Errorf("while attempting to listen on '%s' - %w", p.conf.Listen, err)
    61  	}
    62  	return nil
    63  }
    64  
    65  func (p *Proxy) handler(conn net.PacketConn, buf []byte, addr net.Addr) {
    66  	p.mutex.Lock()
    67  	defer p.mutex.Unlock()
    68  
    69  	for _, block := range p.blocked {
    70  		if block == addr.String() {
    71  			logrus.Debugf("Blocked proxy of %d bytes from '%s' ", len(buf), addr.String())
    72  			return
    73  		}
    74  	}
    75  	logrus.Debugf("proxy %d bytes %s -> %s", len(buf), addr.String(), p.upstream.addr.String())
    76  
    77  	// Forward the request to upstream
    78  	if _, err := p.upstream.Send(buf); err != nil {
    79  		logrus.WithError(err).Errorf("failed to forward '%d' bytes from '%s' to upstream '%s'", len(buf), addr.String(), p.upstream.addr.String())
    80  		return
    81  	}
    82  
    83  	// Wait for a response until timeout
    84  	b := make([]byte, 10_000)
    85  	n, _, _ := p.upstream.Recv(b, p.conf.UpstreamTimeout)
    86  
    87  	// Nothing to send to upstream
    88  	if n == 0 {
    89  		return
    90  	}
    91  
    92  	// Send response to upstream
    93  	if _, err := conn.WriteTo(b[:n], addr); err != nil {
    94  		logrus.WithError(err).Errorf("failed to forward '%d' bytes from '%s' to downstream '%s'", n, p.upstream.addr.String(), addr.String())
    95  		return
    96  	}
    97  	logrus.Debugf("proxy %d bytes %s <- %s", n, p.upstream.addr.String(), addr.String())
    98  }
    99  
   100  func (p *Proxy) Block(addr string) {
   101  	p.mutex.Lock()
   102  	defer p.mutex.Unlock()
   103  	p.blocked = append(p.blocked, addr)
   104  }
   105  
   106  func (p *Proxy) UnBlock(addr string) {
   107  	p.mutex.Lock()
   108  	defer p.mutex.Unlock()
   109  
   110  	// Short cut
   111  	if len(p.blocked) == 1 {
   112  		p.blocked = nil
   113  	}
   114  
   115  	var blocked []string
   116  	for _, b := range p.blocked {
   117  		if b == addr {
   118  			continue
   119  		}
   120  		blocked = append(blocked, addr)
   121  	}
   122  	p.blocked = blocked
   123  }
   124  
   125  func (p *Proxy) Stop() {
   126  	p.mutex.Lock()
   127  	defer p.mutex.Unlock()
   128  	p.listen.Close()
   129  	p.upstream.Close()
   130  }