github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/listener_tcp.go (about)

     1  package nodes
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
     7  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
     8  	"github.com/TeaOSLab/EdgeNode/internal/goman"
     9  	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
    10  	"github.com/TeaOSLab/EdgeNode/internal/stats"
    11  	"github.com/TeaOSLab/EdgeNode/internal/utils"
    12  	"github.com/iwind/TeaGo/types"
    13  	"github.com/pires/go-proxyproto"
    14  	"net"
    15  	"strings"
    16  	"sync/atomic"
    17  )
    18  
    19  type TCPListener struct {
    20  	BaseListener
    21  
    22  	Listener net.Listener
    23  
    24  	port int
    25  }
    26  
    27  func (this *TCPListener) Serve() error {
    28  	var listener = this.Listener
    29  	if this.Group.IsTLS() {
    30  		listener = tls.NewListener(listener, this.buildTLSConfig())
    31  	}
    32  
    33  	// 获取分组端口
    34  	var groupAddr = this.Group.Addr()
    35  	var portIndex = strings.LastIndex(groupAddr, ":")
    36  	if portIndex >= 0 {
    37  		var port = groupAddr[portIndex+1:]
    38  		this.port = types.Int(port)
    39  	}
    40  
    41  	for {
    42  		conn, err := listener.Accept()
    43  		if err != nil {
    44  			break
    45  		}
    46  
    47  		atomic.AddInt64(&this.countActiveConnections, 1)
    48  
    49  		go func(conn net.Conn) {
    50  			var server = this.Group.FirstServer()
    51  			if server == nil {
    52  				return
    53  			}
    54  			err = this.handleConn(server, conn)
    55  			if err != nil {
    56  				remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil)
    57  			}
    58  			atomic.AddInt64(&this.countActiveConnections, -1)
    59  		}(conn)
    60  	}
    61  
    62  	return nil
    63  }
    64  
    65  func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
    66  	this.Group = group
    67  	this.Reset()
    68  }
    69  
    70  func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error {
    71  	if server == nil {
    72  		return errors.New("no server available")
    73  	}
    74  	if server.ReverseProxy == nil {
    75  		return errors.New("no ReverseProxy configured for the server")
    76  	}
    77  
    78  	// 绑定连接和服务
    79  	clientConn, ok := conn.(ClientConnInterface)
    80  	if ok {
    81  		var goNext = clientConn.SetServerId(server.Id)
    82  		if !goNext {
    83  			return nil
    84  		}
    85  		clientConn.SetUserId(server.UserId)
    86  
    87  		var userPlanId int64
    88  		if server.UserPlan != nil && server.UserPlan.Id > 0 {
    89  			userPlanId = server.UserPlan.Id
    90  		}
    91  		clientConn.SetUserPlanId(userPlanId)
    92  	} else {
    93  		tlsConn, ok := conn.(*tls.Conn)
    94  		if ok {
    95  			var internalConn = tlsConn.NetConn()
    96  			if internalConn != nil {
    97  				clientConn, ok = internalConn.(ClientConnInterface)
    98  				if ok {
    99  					var goNext = clientConn.SetServerId(server.Id)
   100  					if !goNext {
   101  						return nil
   102  					}
   103  					clientConn.SetUserId(server.UserId)
   104  
   105  					var userPlanId int64
   106  					if server.UserPlan != nil && server.UserPlan.Id > 0 {
   107  						userPlanId = server.UserPlan.Id
   108  					}
   109  					clientConn.SetUserPlanId(userPlanId)
   110  				}
   111  			}
   112  		}
   113  	}
   114  
   115  	// 是否已达到流量限制
   116  	if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) {
   117  		// 关闭连接
   118  		tcpConn, ok := conn.(LingerConn)
   119  		if ok {
   120  			_ = tcpConn.SetLinger(0)
   121  		}
   122  		_ = conn.Close()
   123  
   124  		// TODO 使用系统防火墙drop当前端口的数据包一段时间(1分钟)
   125  		// 不能使用阻止IP的方法,因为边缘节点只上有可能还有别的代理服务
   126  
   127  		return nil
   128  	}
   129  
   130  	// 记录域名排行
   131  	tlsConn, ok := conn.(*tls.Conn)
   132  	var recordStat = false
   133  	var serverName = ""
   134  	if ok {
   135  		serverName = tlsConn.ConnectionState().ServerName
   136  		if len(serverName) > 0 {
   137  			// 统计
   138  			stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
   139  			recordStat = true
   140  		}
   141  	}
   142  
   143  	// 统计
   144  	if !recordStat {
   145  		stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
   146  	}
   147  
   148  	// DAU统计
   149  	clientIP, _, parseErr := net.SplitHostPort(conn.RemoteAddr().String())
   150  	if parseErr == nil {
   151  		stats.SharedDAUManager.AddIP(server.Id, clientIP)
   152  	}
   153  
   154  	originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
   155  	if err != nil {
   156  		_ = conn.Close()
   157  		return err
   158  	}
   159  
   160  	var closer = func() {
   161  		_ = conn.Close()
   162  		_ = originConn.Close()
   163  	}
   164  
   165  	// PROXY Protocol
   166  	if server.ReverseProxy != nil &&
   167  		server.ReverseProxy.ProxyProtocol != nil &&
   168  		server.ReverseProxy.ProxyProtocol.IsOn &&
   169  		(server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
   170  		var remoteAddr = conn.RemoteAddr()
   171  		var transportProtocol = proxyproto.TCPv4
   172  		if strings.Contains(remoteAddr.String(), "[") {
   173  			transportProtocol = proxyproto.TCPv6
   174  		}
   175  		var header = proxyproto.Header{
   176  			Version:           byte(server.ReverseProxy.ProxyProtocol.Version),
   177  			Command:           proxyproto.PROXY,
   178  			TransportProtocol: transportProtocol,
   179  			SourceAddr:        remoteAddr,
   180  			DestinationAddr:   conn.LocalAddr(),
   181  		}
   182  		_, err = header.WriteTo(originConn)
   183  		if err != nil {
   184  			closer()
   185  			return err
   186  		}
   187  	}
   188  
   189  	// 从源站读取
   190  	goman.New(func() {
   191  		var originBuf = utils.BytePool16k.Get()
   192  		defer func() {
   193  			utils.BytePool16k.Put(originBuf)
   194  		}()
   195  		for {
   196  			n, err := originConn.Read(originBuf.Bytes)
   197  			if n > 0 {
   198  				_, err = conn.Write(originBuf.Bytes[:n])
   199  				if err != nil {
   200  					closer()
   201  					break
   202  				}
   203  
   204  				// 记录流量
   205  				if server != nil {
   206  					stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
   207  				}
   208  			}
   209  			if err != nil {
   210  				closer()
   211  				break
   212  			}
   213  		}
   214  	})
   215  
   216  	// 从客户端读取
   217  	var clientBuf = utils.BytePool16k.Get()
   218  	defer func() {
   219  		utils.BytePool16k.Put(clientBuf)
   220  	}()
   221  	for {
   222  		// 是否已达到流量限制
   223  		if this.reachedTrafficLimit() {
   224  			closer()
   225  			return nil
   226  		}
   227  
   228  		n, err := conn.Read(clientBuf.Bytes)
   229  		if n > 0 {
   230  			_, err = originConn.Write(clientBuf.Bytes[:n])
   231  			if err != nil {
   232  				break
   233  			}
   234  		}
   235  		if err != nil {
   236  			break
   237  		}
   238  	}
   239  
   240  	// 关闭连接
   241  	closer()
   242  
   243  	return nil
   244  }
   245  
   246  func (this *TCPListener) Close() error {
   247  	return this.Listener.Close()
   248  }
   249  
   250  // 连接源站
   251  func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
   252  	if reverseProxy == nil {
   253  		return nil, errors.New("no reverse proxy config")
   254  	}
   255  
   256  	var requestCall = shared.NewRequestCall()
   257  	requestCall.Domain = requestHost
   258  
   259  	var retries = 3
   260  	var addr string
   261  
   262  	var failedOriginIds []int64
   263  
   264  	for i := 0; i < retries; i++ {
   265  		var origin *serverconfigs.OriginConfig
   266  		if len(failedOriginIds) > 0 {
   267  			origin = reverseProxy.AnyOrigin(requestCall, failedOriginIds)
   268  		}
   269  		if origin == nil {
   270  			origin = reverseProxy.NextOrigin(requestCall)
   271  		}
   272  		if origin == nil {
   273  			continue
   274  		}
   275  
   276  		// 回源主机名
   277  		if len(origin.RequestHost) > 0 {
   278  			requestHost = origin.RequestHost
   279  		} else if len(reverseProxy.RequestHost) > 0 {
   280  			requestHost = reverseProxy.RequestHost
   281  		}
   282  
   283  		conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)
   284  		if err != nil {
   285  			failedOriginIds = append(failedOriginIds, origin.Id)
   286  
   287  			remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
   288  
   289  			SharedOriginStateManager.Fail(origin, requestHost, reverseProxy, func() {
   290  				reverseProxy.ResetScheduling()
   291  			})
   292  
   293  			continue
   294  		} else {
   295  			if !origin.IsOk {
   296  				SharedOriginStateManager.Success(origin, func() {
   297  					reverseProxy.ResetScheduling()
   298  				})
   299  			}
   300  
   301  			return
   302  		}
   303  	}
   304  
   305  	if err == nil {
   306  		err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
   307  	}
   308  	return
   309  }
   310  
   311  // 检查是否已经达到流量限制
   312  func (this *TCPListener) reachedTrafficLimit() bool {
   313  	var server = this.Group.FirstServer()
   314  	if server == nil {
   315  		return true
   316  	}
   317  	return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
   318  }