github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/listener_tcp.go (about) 1 package nodes 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" 7 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" 8 "github.com/TeaOSLab/EdgeNode/internal/goman" 9 "github.com/TeaOSLab/EdgeNode/internal/remotelogs" 10 "github.com/TeaOSLab/EdgeNode/internal/stats" 11 "github.com/TeaOSLab/EdgeNode/internal/utils" 12 "github.com/iwind/TeaGo/types" 13 "github.com/pires/go-proxyproto" 14 "net" 15 "strings" 16 "sync/atomic" 17 ) 18 19 type TCPListener struct { 20 BaseListener 21 22 Listener net.Listener 23 24 port int 25 } 26 27 func (this *TCPListener) Serve() error { 28 var listener = this.Listener 29 if this.Group.IsTLS() { 30 listener = tls.NewListener(listener, this.buildTLSConfig()) 31 } 32 33 // 获取分组端口 34 var groupAddr = this.Group.Addr() 35 var portIndex = strings.LastIndex(groupAddr, ":") 36 if portIndex >= 0 { 37 var port = groupAddr[portIndex+1:] 38 this.port = types.Int(port) 39 } 40 41 for { 42 conn, err := listener.Accept() 43 if err != nil { 44 break 45 } 46 47 atomic.AddInt64(&this.countActiveConnections, 1) 48 49 go func(conn net.Conn) { 50 var server = this.Group.FirstServer() 51 if server == nil { 52 return 53 } 54 err = this.handleConn(server, conn) 55 if err != nil { 56 remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil) 57 } 58 atomic.AddInt64(&this.countActiveConnections, -1) 59 }(conn) 60 } 61 62 return nil 63 } 64 65 func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) { 66 this.Group = group 67 this.Reset() 68 } 69 70 func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error { 71 if server == nil { 72 return errors.New("no server available") 73 } 74 if server.ReverseProxy == nil { 75 return errors.New("no ReverseProxy configured for the server") 76 } 77 78 // 绑定连接和服务 79 clientConn, ok := conn.(ClientConnInterface) 80 if ok { 81 var goNext = clientConn.SetServerId(server.Id) 82 if !goNext { 83 return nil 84 } 85 clientConn.SetUserId(server.UserId) 86 87 var userPlanId int64 88 if server.UserPlan != nil && server.UserPlan.Id > 0 { 89 userPlanId = server.UserPlan.Id 90 } 91 clientConn.SetUserPlanId(userPlanId) 92 } else { 93 tlsConn, ok := conn.(*tls.Conn) 94 if ok { 95 var internalConn = tlsConn.NetConn() 96 if internalConn != nil { 97 clientConn, ok = internalConn.(ClientConnInterface) 98 if ok { 99 var goNext = clientConn.SetServerId(server.Id) 100 if !goNext { 101 return nil 102 } 103 clientConn.SetUserId(server.UserId) 104 105 var userPlanId int64 106 if server.UserPlan != nil && server.UserPlan.Id > 0 { 107 userPlanId = server.UserPlan.Id 108 } 109 clientConn.SetUserPlanId(userPlanId) 110 } 111 } 112 } 113 } 114 115 // 是否已达到流量限制 116 if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) { 117 // 关闭连接 118 tcpConn, ok := conn.(LingerConn) 119 if ok { 120 _ = tcpConn.SetLinger(0) 121 } 122 _ = conn.Close() 123 124 // TODO 使用系统防火墙drop当前端口的数据包一段时间(1分钟) 125 // 不能使用阻止IP的方法,因为边缘节点只上有可能还有别的代理服务 126 127 return nil 128 } 129 130 // 记录域名排行 131 tlsConn, ok := conn.(*tls.Conn) 132 var recordStat = false 133 var serverName = "" 134 if ok { 135 serverName = tlsConn.ConnectionState().ServerName 136 if len(serverName) > 0 { 137 // 统计 138 stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) 139 recordStat = true 140 } 141 } 142 143 // 统计 144 if !recordStat { 145 stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) 146 } 147 148 // DAU统计 149 clientIP, _, parseErr := net.SplitHostPort(conn.RemoteAddr().String()) 150 if parseErr == nil { 151 stats.SharedDAUManager.AddIP(server.Id, clientIP) 152 } 153 154 originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String()) 155 if err != nil { 156 _ = conn.Close() 157 return err 158 } 159 160 var closer = func() { 161 _ = conn.Close() 162 _ = originConn.Close() 163 } 164 165 // PROXY Protocol 166 if server.ReverseProxy != nil && 167 server.ReverseProxy.ProxyProtocol != nil && 168 server.ReverseProxy.ProxyProtocol.IsOn && 169 (server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { 170 var remoteAddr = conn.RemoteAddr() 171 var transportProtocol = proxyproto.TCPv4 172 if strings.Contains(remoteAddr.String(), "[") { 173 transportProtocol = proxyproto.TCPv6 174 } 175 var header = proxyproto.Header{ 176 Version: byte(server.ReverseProxy.ProxyProtocol.Version), 177 Command: proxyproto.PROXY, 178 TransportProtocol: transportProtocol, 179 SourceAddr: remoteAddr, 180 DestinationAddr: conn.LocalAddr(), 181 } 182 _, err = header.WriteTo(originConn) 183 if err != nil { 184 closer() 185 return err 186 } 187 } 188 189 // 从源站读取 190 goman.New(func() { 191 var originBuf = utils.BytePool16k.Get() 192 defer func() { 193 utils.BytePool16k.Put(originBuf) 194 }() 195 for { 196 n, err := originConn.Read(originBuf.Bytes) 197 if n > 0 { 198 _, err = conn.Write(originBuf.Bytes[:n]) 199 if err != nil { 200 closer() 201 break 202 } 203 204 // 记录流量 205 if server != nil { 206 stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) 207 } 208 } 209 if err != nil { 210 closer() 211 break 212 } 213 } 214 }) 215 216 // 从客户端读取 217 var clientBuf = utils.BytePool16k.Get() 218 defer func() { 219 utils.BytePool16k.Put(clientBuf) 220 }() 221 for { 222 // 是否已达到流量限制 223 if this.reachedTrafficLimit() { 224 closer() 225 return nil 226 } 227 228 n, err := conn.Read(clientBuf.Bytes) 229 if n > 0 { 230 _, err = originConn.Write(clientBuf.Bytes[:n]) 231 if err != nil { 232 break 233 } 234 } 235 if err != nil { 236 break 237 } 238 } 239 240 // 关闭连接 241 closer() 242 243 return nil 244 } 245 246 func (this *TCPListener) Close() error { 247 return this.Listener.Close() 248 } 249 250 // 连接源站 251 func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { 252 if reverseProxy == nil { 253 return nil, errors.New("no reverse proxy config") 254 } 255 256 var requestCall = shared.NewRequestCall() 257 requestCall.Domain = requestHost 258 259 var retries = 3 260 var addr string 261 262 var failedOriginIds []int64 263 264 for i := 0; i < retries; i++ { 265 var origin *serverconfigs.OriginConfig 266 if len(failedOriginIds) > 0 { 267 origin = reverseProxy.AnyOrigin(requestCall, failedOriginIds) 268 } 269 if origin == nil { 270 origin = reverseProxy.NextOrigin(requestCall) 271 } 272 if origin == nil { 273 continue 274 } 275 276 // 回源主机名 277 if len(origin.RequestHost) > 0 { 278 requestHost = origin.RequestHost 279 } else if len(reverseProxy.RequestHost) > 0 { 280 requestHost = reverseProxy.RequestHost 281 } 282 283 conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost) 284 if err != nil { 285 failedOriginIds = append(failedOriginIds, origin.Id) 286 287 remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil) 288 289 SharedOriginStateManager.Fail(origin, requestHost, reverseProxy, func() { 290 reverseProxy.ResetScheduling() 291 }) 292 293 continue 294 } else { 295 if !origin.IsOk { 296 SharedOriginStateManager.Success(origin, func() { 297 reverseProxy.ResetScheduling() 298 }) 299 } 300 301 return 302 } 303 } 304 305 if err == nil { 306 err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used") 307 } 308 return 309 } 310 311 // 检查是否已经达到流量限制 312 func (this *TCPListener) reachedTrafficLimit() bool { 313 var server = this.Group.FirstServer() 314 if server == nil { 315 return true 316 } 317 return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid() 318 }