github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/listener_udp.go (about)

     1  package nodes
     2  
     3  import (
     4  	"errors"
     5  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
     6  	"github.com/TeaOSLab/EdgeNode/internal/goman"
     7  	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
     8  	"github.com/TeaOSLab/EdgeNode/internal/stats"
     9  	"github.com/TeaOSLab/EdgeNode/internal/utils"
    10  	"github.com/iwind/TeaGo/types"
    11  	"github.com/pires/go-proxyproto"
    12  	"golang.org/x/net/ipv4"
    13  	"golang.org/x/net/ipv6"
    14  	"net"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
    20  const (
    21  	UDPConnLifeSeconds = 30
    22  )
    23  
    24  type UDPPacketListener interface {
    25  	ReadFrom(b []byte) (n int, cm any, src net.Addr, err error)
    26  	WriteTo(b []byte, cm any, dst net.Addr) (n int, err error)
    27  	LocalAddr() net.Addr
    28  }
    29  
    30  type UDPIPv4Listener struct {
    31  	rawListener *ipv4.PacketConn
    32  }
    33  
    34  func NewUDPIPv4Listener(rawListener *ipv4.PacketConn) *UDPIPv4Listener {
    35  	return &UDPIPv4Listener{rawListener: rawListener}
    36  }
    37  
    38  func (this *UDPIPv4Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
    39  	return this.rawListener.ReadFrom(b)
    40  }
    41  
    42  func (this *UDPIPv4Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
    43  	return this.rawListener.WriteTo(b, cm.(*ipv4.ControlMessage), dst)
    44  }
    45  
    46  func (this *UDPIPv4Listener) LocalAddr() net.Addr {
    47  	return this.rawListener.LocalAddr()
    48  }
    49  
    50  type UDPIPv6Listener struct {
    51  	rawListener *ipv6.PacketConn
    52  }
    53  
    54  func NewUDPIPv6Listener(rawListener *ipv6.PacketConn) *UDPIPv6Listener {
    55  	return &UDPIPv6Listener{rawListener: rawListener}
    56  }
    57  
    58  func (this *UDPIPv6Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
    59  	return this.rawListener.ReadFrom(b)
    60  }
    61  
    62  func (this *UDPIPv6Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
    63  	return this.rawListener.WriteTo(b, cm.(*ipv6.ControlMessage), dst)
    64  }
    65  
    66  func (this *UDPIPv6Listener) LocalAddr() net.Addr {
    67  	return this.rawListener.LocalAddr()
    68  }
    69  
    70  type UDPListener struct {
    71  	BaseListener
    72  
    73  	IPv4Listener *ipv4.PacketConn
    74  	IPv6Listener *ipv6.PacketConn
    75  
    76  	connMap    map[string]*UDPConn
    77  	connLocker sync.Mutex
    78  	connTicker *utils.Ticker
    79  
    80  	reverseProxy *serverconfigs.ReverseProxyConfig
    81  
    82  	port int
    83  
    84  	isClosed bool
    85  }
    86  
    87  func (this *UDPListener) Serve() error {
    88  	if this.Group == nil {
    89  		return nil
    90  	}
    91  	var server = this.Group.FirstServer()
    92  	if server == nil {
    93  		return nil
    94  	}
    95  	var serverId = server.Id
    96  
    97  	var wg = &sync.WaitGroup{}
    98  	wg.Add(2) // 2 = ipv4 + ipv6
    99  
   100  	go func() {
   101  		defer wg.Done()
   102  
   103  		if this.IPv4Listener != nil {
   104  			err := this.IPv4Listener.SetControlMessage(ipv4.FlagDst, true)
   105  			if err != nil {
   106  				remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
   107  				return
   108  			}
   109  
   110  			err = this.servePacketListener(NewUDPIPv4Listener(this.IPv4Listener))
   111  			if err != nil {
   112  				remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
   113  				return
   114  			}
   115  		}
   116  	}()
   117  
   118  	go func() {
   119  		defer wg.Done()
   120  
   121  		if this.IPv6Listener != nil {
   122  			err := this.IPv6Listener.SetControlMessage(ipv6.FlagDst, true)
   123  			if err != nil {
   124  				remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
   125  				return
   126  			}
   127  
   128  			err = this.servePacketListener(NewUDPIPv6Listener(this.IPv6Listener))
   129  			if err != nil {
   130  				remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
   131  				return
   132  			}
   133  		}
   134  	}()
   135  
   136  	wg.Wait()
   137  
   138  	return nil
   139  }
   140  
   141  func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
   142  	// 获取分组端口
   143  	var groupAddr = this.Group.Addr()
   144  	var portIndex = strings.LastIndex(groupAddr, ":")
   145  	if portIndex >= 0 {
   146  		var port = groupAddr[portIndex+1:]
   147  		this.port = types.Int(port)
   148  	}
   149  
   150  	var firstServer = this.Group.FirstServer()
   151  	if firstServer == nil {
   152  		return errors.New("no server available")
   153  	}
   154  	this.reverseProxy = firstServer.ReverseProxy
   155  	if this.reverseProxy == nil {
   156  		return errors.New("no ReverseProxy configured for the server '" + firstServer.Name + "'")
   157  	}
   158  
   159  	this.connMap = map[string]*UDPConn{}
   160  	this.connTicker = utils.NewTicker(1 * time.Minute)
   161  	goman.New(func() {
   162  		for this.connTicker.Next() {
   163  			this.gcConns()
   164  		}
   165  	})
   166  
   167  	var buffer = make([]byte, 4*1024)
   168  	for {
   169  		if this.isClosed {
   170  			return nil
   171  		}
   172  
   173  		// 检查用户状态
   174  		if firstServer.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(firstServer.UserId) {
   175  			return nil
   176  		}
   177  
   178  		n, cm, clientAddr, err := listener.ReadFrom(buffer)
   179  		if err != nil {
   180  			if this.isClosed {
   181  				return nil
   182  			}
   183  			return err
   184  		}
   185  
   186  		if n > 0 {
   187  			this.connLocker.Lock()
   188  			conn, ok := this.connMap[clientAddr.String()]
   189  			this.connLocker.Unlock()
   190  			if ok && !conn.IsOk() {
   191  				_ = conn.Close()
   192  				ok = false
   193  			}
   194  			if !ok {
   195  				originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, listener.LocalAddr(), clientAddr)
   196  				if err != nil {
   197  					remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
   198  					continue
   199  				}
   200  				if originConn == nil {
   201  					remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
   202  					continue
   203  				}
   204  				conn = NewUDPConn(firstServer, clientAddr, listener, cm, originConn.(*net.UDPConn))
   205  				this.connLocker.Lock()
   206  				this.connMap[clientAddr.String()] = conn
   207  				this.connLocker.Unlock()
   208  			}
   209  			_, _ = conn.Write(buffer[:n])
   210  		}
   211  	}
   212  }
   213  
   214  func (this *UDPListener) Close() error {
   215  	this.isClosed = true
   216  
   217  	if this.connTicker != nil {
   218  		this.connTicker.Stop()
   219  	}
   220  
   221  	// 关闭所有连接
   222  	this.connLocker.Lock()
   223  	for _, conn := range this.connMap {
   224  		_ = conn.Close()
   225  	}
   226  	this.connLocker.Unlock()
   227  
   228  	var errorStrings = []string{}
   229  	if this.IPv4Listener != nil {
   230  		err := this.IPv4Listener.Close()
   231  		if err != nil {
   232  			errorStrings = append(errorStrings, err.Error())
   233  		}
   234  	}
   235  
   236  	if this.IPv6Listener != nil {
   237  		err := this.IPv6Listener.Close()
   238  		if err != nil {
   239  			errorStrings = append(errorStrings, err.Error())
   240  		}
   241  	}
   242  
   243  	if len(errorStrings) > 0 {
   244  		return errors.New(errorStrings[0])
   245  	}
   246  
   247  	return nil
   248  }
   249  
   250  func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
   251  	this.Group = group
   252  	this.Reset()
   253  
   254  	// 重置配置
   255  	var firstServer = this.Group.FirstServer()
   256  	if firstServer == nil {
   257  		return
   258  	}
   259  	this.reverseProxy = firstServer.ReverseProxy
   260  }
   261  
   262  func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, localAddr net.Addr, remoteAddr net.Addr) (conn net.Conn, err error) {
   263  	if reverseProxy == nil {
   264  		return nil, errors.New("no reverse proxy config")
   265  	}
   266  
   267  	var retries = 3
   268  	var addr string
   269  
   270  	var failedOriginIds []int64
   271  
   272  	for i := 0; i < retries; i++ {
   273  		var origin *serverconfigs.OriginConfig
   274  		if len(failedOriginIds) > 0 {
   275  			origin = reverseProxy.AnyOrigin(nil, failedOriginIds)
   276  		}
   277  		if origin == nil {
   278  			origin = reverseProxy.NextOrigin(nil)
   279  		}
   280  		if origin == nil {
   281  			continue
   282  		}
   283  
   284  		conn, addr, err = OriginConnect(origin, this.port, remoteAddr.String(), "")
   285  		if err != nil {
   286  			failedOriginIds = append(failedOriginIds, origin.Id)
   287  
   288  			remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
   289  
   290  			SharedOriginStateManager.Fail(origin, "", reverseProxy, func() {
   291  				reverseProxy.ResetScheduling()
   292  			})
   293  
   294  			continue
   295  		} else {
   296  			if !origin.IsOk {
   297  				SharedOriginStateManager.Success(origin, func() {
   298  					reverseProxy.ResetScheduling()
   299  				})
   300  			}
   301  
   302  			// PROXY Protocol
   303  			if reverseProxy != nil &&
   304  				reverseProxy.ProxyProtocol != nil &&
   305  				reverseProxy.ProxyProtocol.IsOn &&
   306  				(reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
   307  				var transportProtocol = proxyproto.UDPv4
   308  				if strings.Contains(remoteAddr.String(), "[") {
   309  					transportProtocol = proxyproto.UDPv6
   310  				}
   311  				var header = proxyproto.Header{
   312  					Version:           byte(reverseProxy.ProxyProtocol.Version),
   313  					Command:           proxyproto.PROXY,
   314  					TransportProtocol: transportProtocol,
   315  					SourceAddr:        remoteAddr,
   316  					DestinationAddr:   localAddr,
   317  				}
   318  				_, err = header.WriteTo(conn)
   319  				if err != nil {
   320  					_ = conn.Close()
   321  					return nil, err
   322  				}
   323  			}
   324  
   325  			return
   326  		}
   327  	}
   328  
   329  	if err == nil {
   330  		err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
   331  	}
   332  	return
   333  }
   334  
   335  // 回收连接
   336  func (this *UDPListener) gcConns() {
   337  	this.connLocker.Lock()
   338  	var closingConns = []*UDPConn{}
   339  	for addr, conn := range this.connMap {
   340  		if !conn.IsOk() {
   341  			closingConns = append(closingConns, conn)
   342  			delete(this.connMap, addr)
   343  		}
   344  	}
   345  	this.connLocker.Unlock()
   346  
   347  	for _, conn := range closingConns {
   348  		_ = conn.Close()
   349  	}
   350  }
   351  
   352  // UDPConn 自定义的UDP连接管理
   353  type UDPConn struct {
   354  	addr          net.Addr
   355  	proxyListener UDPPacketListener
   356  	serverConn    net.Conn
   357  	activatedAt   int64
   358  	isOk          bool
   359  	isClosed      bool
   360  }
   361  
   362  func NewUDPConn(server *serverconfigs.ServerConfig, clientAddr net.Addr, proxyListener UDPPacketListener, cm any, serverConn *net.UDPConn) *UDPConn {
   363  	var conn = &UDPConn{
   364  		addr:          clientAddr,
   365  		proxyListener: proxyListener,
   366  		serverConn:    serverConn,
   367  		activatedAt:   time.Now().Unix(),
   368  		isOk:          true,
   369  	}
   370  
   371  	// 统计
   372  	if server != nil {
   373  		stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
   374  
   375  		// DAU统计
   376  		clientIP, _, parseErr := net.SplitHostPort(clientAddr.String())
   377  		if parseErr == nil {
   378  			stats.SharedDAUManager.AddIP(server.Id, clientIP)
   379  		}
   380  	}
   381  
   382  	// 处理ControlMessage
   383  	switch controlMessage := cm.(type) {
   384  	case *ipv4.ControlMessage:
   385  		controlMessage.Src = controlMessage.Dst
   386  	case *ipv6.ControlMessage:
   387  		controlMessage.Src = controlMessage.Dst
   388  	}
   389  
   390  	goman.New(func() {
   391  		var buf = utils.BytePool4k.Get()
   392  		defer func() {
   393  			utils.BytePool4k.Put(buf)
   394  		}()
   395  
   396  		for {
   397  			n, err := serverConn.Read(buf.Bytes)
   398  			if n > 0 {
   399  				conn.activatedAt = time.Now().Unix()
   400  
   401  				_, writingErr := proxyListener.WriteTo(buf.Bytes[:n], cm, clientAddr)
   402  				if writingErr != nil {
   403  					conn.isOk = false
   404  					break
   405  				}
   406  
   407  				// 记录流量和带宽
   408  				if server != nil {
   409  					// 流量
   410  					stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
   411  
   412  					// 带宽
   413  					var userPlanId int64
   414  					if server.UserPlan != nil && server.UserPlan.Id > 0 {
   415  						userPlanId = server.UserPlan.Id
   416  					}
   417  					stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, userPlanId, server.Id, int64(n), int64(n))
   418  				}
   419  			}
   420  			if err != nil {
   421  				conn.isOk = false
   422  				break
   423  			}
   424  		}
   425  	})
   426  	return conn
   427  }
   428  
   429  func (this *UDPConn) Write(b []byte) (n int, err error) {
   430  	this.activatedAt = time.Now().Unix()
   431  	n, err = this.serverConn.Write(b)
   432  	if err != nil {
   433  		this.isOk = false
   434  	}
   435  	return
   436  }
   437  
   438  func (this *UDPConn) Close() error {
   439  	this.isOk = false
   440  	if this.isClosed {
   441  		return nil
   442  	}
   443  	this.isClosed = true
   444  	return this.serverConn.Close()
   445  }
   446  
   447  func (this *UDPConn) IsOk() bool {
   448  	if !this.isOk {
   449  		return false
   450  	}
   451  	return time.Now().Unix()-this.activatedAt < UDPConnLifeSeconds // 如果超过 N 秒没有活动我们认为是超时
   452  }