github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/internal/tunnel/udp/tunnel.go (about)

     1  package udpTunnel
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/panjf2000/gnet"
    11  	"github.com/ronaksoft/rony/internal/metrics"
    12  	"github.com/ronaksoft/rony/internal/msg"
    13  	"github.com/ronaksoft/rony/internal/tunnel"
    14  	"github.com/ronaksoft/rony/log"
    15  	"github.com/ronaksoft/rony/pools/gopool"
    16  	"github.com/ronaksoft/rony/tools"
    17  	"go.uber.org/zap"
    18  )
    19  
    20  /*
    21     Creation Time: 2021 - Jan - 04
    22     Created by:  (ehsan)
    23     Maintainers:
    24        1.  Ehsan N. Moosa (E2)
    25     Auditor: Ehsan N. Moosa (E2)
    26     Copyright Ronak Software Group 2020
    27  */
    28  
    29  type Config struct {
    30  	ServerID      string
    31  	ListenAddress string
    32  	MaxBodySize   int
    33  	ExternalAddrs []string
    34  	Logger        log.Logger
    35  }
    36  
    37  type Tunnel struct {
    38  	tunnel.MessageHandler
    39  	cfg      Config
    40  	addrs    []string
    41  	shutdown int32 // atomic shutdown flag
    42  	connID   uint64
    43  }
    44  
    45  var _ gnet.EventHandler = (*Tunnel)(nil)
    46  
    47  func New(config Config) (*Tunnel, error) {
    48  	if config.Logger == nil {
    49  		config.Logger = log.DefaultLogger
    50  	}
    51  
    52  	t := &Tunnel{
    53  		cfg: config,
    54  	}
    55  
    56  	var hosts []string
    57  	// try to detect the ip address of the listener
    58  	ta, err := net.ResolveUDPAddr("udp", t.cfg.ListenAddress)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	if ta.Port == 0 {
    63  		ta.Port = tools.RandomInt(63000)
    64  		t.cfg.ListenAddress = ta.String()
    65  	}
    66  
    67  	if ta.IP.IsUnspecified() {
    68  		addrs, err := net.InterfaceAddrs()
    69  
    70  		if err == nil {
    71  			for _, a := range addrs {
    72  				switch x := a.(type) {
    73  				case *net.IPNet:
    74  					if x.IP.To4() == nil || x.IP.IsLoopback() {
    75  						continue
    76  					}
    77  					hosts = append(hosts, x.IP.String())
    78  				case *net.IPAddr:
    79  					if x.IP.To4() == nil || x.IP.IsLoopback() {
    80  						continue
    81  					}
    82  					hosts = append(hosts, x.IP.String())
    83  				case *net.UDPAddr:
    84  					if x.IP.To4() == nil || x.IP.IsLoopback() {
    85  						continue
    86  					}
    87  					hosts = append(hosts, x.IP.String())
    88  				}
    89  			}
    90  		}
    91  	}
    92  
    93  	for _, h := range hosts {
    94  		t.addrs = append(t.addrs, fmt.Sprintf("%s:%d", h, ta.Port))
    95  	}
    96  
    97  	return t, nil
    98  }
    99  
   100  func (t *Tunnel) nextID() uint64 {
   101  	return atomic.AddUint64(&t.connID, 1)
   102  }
   103  
   104  func (t *Tunnel) Start() {
   105  	go t.Run()
   106  	time.Sleep(time.Millisecond * 100)
   107  }
   108  
   109  func (t *Tunnel) Run() {
   110  	err := gnet.Serve(t, fmt.Sprintf("udp://%s", t.cfg.ListenAddress),
   111  		gnet.WithReusePort(true),
   112  		gnet.WithMulticore(true),
   113  		gnet.WithLockOSThread(true),
   114  		gnet.WithLogLevel(log.WarnLevel),
   115  		gnet.WithLogger(t.cfg.Logger.Sugared()),
   116  	)
   117  
   118  	if err != nil {
   119  		panic(err)
   120  	}
   121  }
   122  
   123  func (t *Tunnel) Shutdown() {
   124  	atomic.StoreInt32(&t.shutdown, 1)
   125  	ctx, cf := context.WithTimeout(context.TODO(), time.Second*30)
   126  	defer cf()
   127  	if err := gnet.Stop(ctx, fmt.Sprintf("udp://%s", t.cfg.ListenAddress)); err != nil {
   128  		t.cfg.Logger.Warn("Error On Stopping Tunnel", zap.Error(err))
   129  	}
   130  }
   131  
   132  func (t *Tunnel) Addr() []string {
   133  	if len(t.cfg.ExternalAddrs) > 0 {
   134  		return t.cfg.ExternalAddrs
   135  	}
   136  
   137  	return t.addrs
   138  }
   139  
   140  func (t *Tunnel) OnInitComplete(server gnet.Server) (action gnet.Action) {
   141  	return gnet.None
   142  }
   143  
   144  func (t *Tunnel) OnShutdown(server gnet.Server) {
   145  	t.cfg.Logger.Info("Tunnel shutdown")
   146  }
   147  
   148  func (t *Tunnel) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) {
   149  	t.cfg.Logger.Info("Tunnel connection opened")
   150  
   151  	return nil, gnet.None
   152  }
   153  
   154  func (t *Tunnel) OnClosed(c gnet.Conn, err error) (action gnet.Action) {
   155  	t.cfg.Logger.Info("Tunnel connection closed", zap.Error(err))
   156  
   157  	return gnet.None
   158  }
   159  
   160  func (t *Tunnel) PreWrite(c gnet.Conn) {
   161  	//TODO implement me
   162  	panic("implement me")
   163  }
   164  
   165  func (t *Tunnel) AfterWrite(c gnet.Conn, b []byte) {}
   166  
   167  func (t *Tunnel) React(frame []byte, c gnet.Conn) (out []byte, action gnet.Action) {
   168  	if atomic.LoadInt32(&t.shutdown) == 1 {
   169  		return nil, gnet.Shutdown
   170  	}
   171  
   172  	req := msg.PoolTunnelMessage.Get()
   173  	if err := req.Unmarshal(frame); err != nil {
   174  		t.cfg.Logger.Warn("Error On Tunnel's data received", zap.Error(err))
   175  
   176  		return nil, gnet.Close
   177  	}
   178  
   179  	conn := newConn(t.nextID(), c)
   180  	gopool.Go(func() {
   181  		metrics.IncCounter(metrics.CntTunnelIncomingMessage)
   182  		t.MessageHandler(conn, req)
   183  		msg.PoolTunnelMessage.Put(req)
   184  	})
   185  
   186  	return
   187  }
   188  
   189  func (t *Tunnel) Tick() (delay time.Duration, action gnet.Action) {
   190  	return time.Minute, gnet.None
   191  }