github.com/mailgun/holster/v4@v4.20.0/udp/proxy.go (about) 1 package udp 2 3 import ( 4 "errors" 5 "fmt" 6 "net" 7 "sync" 8 "time" 9 10 "github.com/sirupsen/logrus" 11 ) 12 13 type Proxy struct { 14 conf ProxyConfig 15 listen *Server 16 upstream *Client 17 mutex sync.Mutex // ensure the handler isn't still executing when start()/stop() are called 18 // A list of addresses that we deny udp proxy for 19 blocked []string 20 } 21 22 type ProxyConfig struct { 23 // The address:port we should bind too 24 Listen string 25 // The address:port of the upstream udp server we should forward requests too 26 Upstream string 27 // How long we should wait for a response from the upstream udp server 28 UpstreamTimeout time.Duration 29 } 30 31 func NewProxy(conf ProxyConfig) *Proxy { 32 return &Proxy{ 33 conf: conf, 34 } 35 } 36 37 func (p *Proxy) Start() error { 38 p.mutex.Lock() 39 defer p.mutex.Unlock() 40 41 if p.conf.Listen == "" { 42 return errors.New("variable Listen cannot be empty") 43 } 44 45 if p.conf.Upstream == "" { 46 return errors.New("variable Upstream cannot be empty") 47 } 48 49 var err error 50 p.upstream, err = NewClient(p.conf.Upstream) 51 if err != nil { 52 return fmt.Errorf("while dialing upstream '%s' - %w", p.conf.Upstream, err) 53 } 54 55 p.listen, err = NewServer(ServerConfig{ 56 BindAddress: p.conf.Listen, 57 Handler: p.handler, 58 }) 59 if err != nil { 60 return fmt.Errorf("while attempting to listen on '%s' - %w", p.conf.Listen, err) 61 } 62 return nil 63 } 64 65 func (p *Proxy) handler(conn net.PacketConn, buf []byte, addr net.Addr) { 66 p.mutex.Lock() 67 defer p.mutex.Unlock() 68 69 for _, block := range p.blocked { 70 if block == addr.String() { 71 logrus.Debugf("Blocked proxy of %d bytes from '%s' ", len(buf), addr.String()) 72 return 73 } 74 } 75 logrus.Debugf("proxy %d bytes %s -> %s", len(buf), addr.String(), p.upstream.addr.String()) 76 77 // Forward the request to upstream 78 if _, err := p.upstream.Send(buf); err != nil { 79 logrus.WithError(err).Errorf("failed to forward '%d' bytes from '%s' to upstream '%s'", len(buf), addr.String(), p.upstream.addr.String()) 80 return 81 } 82 83 // Wait for a response until timeout 84 b := make([]byte, 10_000) 85 n, _, _ := p.upstream.Recv(b, p.conf.UpstreamTimeout) 86 87 // Nothing to send to upstream 88 if n == 0 { 89 return 90 } 91 92 // Send response to upstream 93 if _, err := conn.WriteTo(b[:n], addr); err != nil { 94 logrus.WithError(err).Errorf("failed to forward '%d' bytes from '%s' to downstream '%s'", n, p.upstream.addr.String(), addr.String()) 95 return 96 } 97 logrus.Debugf("proxy %d bytes %s <- %s", n, p.upstream.addr.String(), addr.String()) 98 } 99 100 func (p *Proxy) Block(addr string) { 101 p.mutex.Lock() 102 defer p.mutex.Unlock() 103 p.blocked = append(p.blocked, addr) 104 } 105 106 func (p *Proxy) UnBlock(addr string) { 107 p.mutex.Lock() 108 defer p.mutex.Unlock() 109 110 // Short cut 111 if len(p.blocked) == 1 { 112 p.blocked = nil 113 } 114 115 var blocked []string 116 for _, b := range p.blocked { 117 if b == addr { 118 continue 119 } 120 blocked = append(blocked, addr) 121 } 122 p.blocked = blocked 123 } 124 125 func (p *Proxy) Stop() { 126 p.mutex.Lock() 127 defer p.mutex.Unlock() 128 p.listen.Close() 129 p.upstream.Close() 130 }