github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/internal/gateway/tcp/gateway.go (about) 1 package tcpGateway 2 3 import ( 4 "fmt" 5 "net" 6 "net/http" 7 "sync" 8 "sync/atomic" 9 "time" 10 11 "github.com/gobwas/ws" 12 "github.com/mailru/easygo/netpoll" 13 "github.com/panjf2000/ants/v2" 14 "github.com/ronaksoft/rony" 15 "github.com/ronaksoft/rony/errors" 16 "github.com/ronaksoft/rony/internal/gateway/tcp/cors" 17 wsutil "github.com/ronaksoft/rony/internal/gateway/tcp/util" 18 "github.com/ronaksoft/rony/internal/metrics" 19 "github.com/ronaksoft/rony/log" 20 "github.com/ronaksoft/rony/pools" 21 "github.com/ronaksoft/rony/tools" 22 "github.com/valyala/fasthttp" 23 "go.uber.org/zap" 24 ) 25 26 /* 27 Creation Time: 2019 - Feb - 28 28 Created by: (ehsan) 29 Maintainers: 30 1. Ehsan N. Moosa (E2) 31 Auditor: Ehsan N. Moosa (E2) 32 Copyright Ronak Software Group 2020 33 */ 34 35 type UnsafeConn interface { 36 net.Conn 37 UnsafeConn() net.Conn 38 } 39 40 // Config holds all the configuration for Gateway 41 type Config struct { 42 Concurrency int 43 ListenAddress string 44 MaxBodySize int 45 MaxIdleTime time.Duration 46 Protocol rony.GatewayProtocol 47 ExternalAddrs []string 48 Logger log.Logger 49 // TextDataFrame if is set to TRUE then websocket data frames use OpText otherwise use OpBinary 50 TextDataFrame bool 51 52 // CORS 53 AllowedHeaders []string // Default Allow All 54 AllowedOrigins []string // Default Allow All 55 AllowedMethods []string // Default Allow All 56 } 57 58 // Gateway is one of the main components of the Rony framework. Basically Gateway is the component 59 // that connects edge.Server with the external world. Clients which are not part of our cluster MUST 60 // connect to our edge servers through Gateway. 61 // This is an implementation of gateway.Gateway interface with support for **Http** and **Websocket** connections. 62 type Gateway struct { 63 // Internals 64 cfg Config 65 transportMode rony.GatewayProtocol 66 listener *wrapListener 67 listenerAddressMtx sync.RWMutex 68 listenerAddresses []string 69 poller netpoll.Poller 70 stop int32 71 waitGroupAcceptors *sync.WaitGroup 72 waitGroupReaders *sync.WaitGroup 73 waitGroupWriters *sync.WaitGroup 74 cntReads uint64 75 cntWrites uint64 76 cors *cors.CORS 77 delegate rony.GatewayDelegate 78 79 // Websocket Internals 80 upgradeHandler ws.Upgrader 81 connGC *websocketConnGC 82 maxIdleTime int64 83 conns map[uint64]*websocketConn 84 connsMtx sync.RWMutex 85 connsTotal int32 86 connsLastID uint64 87 } 88 89 func New(config Config) (*Gateway, error) { 90 var err error 91 92 if config.Logger == nil { 93 config.Logger = log.DefaultLogger 94 } 95 96 g := &Gateway{ 97 cfg: config, 98 maxIdleTime: int64(defaultConnIdleTime), 99 waitGroupReaders: &sync.WaitGroup{}, 100 waitGroupWriters: &sync.WaitGroup{}, 101 waitGroupAcceptors: &sync.WaitGroup{}, 102 conns: make(map[uint64]*websocketConn, 100000), 103 transportMode: rony.TCP, 104 cors: cors.New(cors.Config{ 105 AllowedHeaders: config.AllowedHeaders, 106 AllowedMethods: config.AllowedMethods, 107 AllowedOrigins: config.AllowedOrigins, 108 }), 109 } 110 111 g.listener, err = newWrapListener(g.cfg.ListenAddress) 112 if err != nil { 113 return nil, err 114 } 115 116 if config.MaxIdleTime != 0 { 117 g.maxIdleTime = int64(config.MaxIdleTime) 118 } 119 if config.Protocol != rony.Undefined { 120 g.transportMode = config.Protocol 121 } 122 123 switch g.transportMode { 124 case rony.Websocket, rony.Http, rony.TCP: 125 default: 126 return nil, ErrUnsupportedProtocol 127 } 128 129 // initialize websocket upgrade handler 130 g.upgradeHandler = ws.DefaultUpgrader 131 132 // initialize idle websocket garbage collector 133 g.connGC = newWebsocketConnGC(g) 134 135 // set handlers 136 if poller, err := netpoll.New(&netpoll.Config{ 137 OnWaitError: func(e error) { 138 g.cfg.Logger.Warn("Error On NetPoller Wait", 139 zap.Error(e), 140 ) 141 }, 142 }); err != nil { 143 return nil, err 144 } else { 145 g.poller = poller 146 } 147 148 // try to detect the ip address of the listener 149 err = g.detectListenerAddress() 150 if err != nil { 151 g.cfg.Logger.Warn("Rony:: Gateway got error on detecting listener addresses", zap.Error(err)) 152 153 return nil, err 154 } 155 156 goPoolB, err = ants.NewPool(g.cfg.Concurrency, 157 ants.WithNonblocking(false), 158 ants.WithPreAlloc(true), 159 ) 160 if err != nil { 161 return nil, err 162 } 163 164 goPoolNB, err = ants.NewPool(g.cfg.Concurrency, 165 ants.WithNonblocking(true), 166 ants.WithPreAlloc(true), 167 ) 168 if err != nil { 169 return nil, err 170 } 171 172 // run the watchdog in background 173 go g.watchdog() 174 175 return g, nil 176 } 177 178 func MustNew(config Config) *Gateway { 179 g, err := New(config) 180 if err != nil { 181 panic(err) 182 } 183 184 return g 185 } 186 187 func (g *Gateway) watchdog() { 188 for { 189 metrics.SetGauge(metrics.GaugeActiveWebsocketConnections, float64(g.TotalConnections())) 190 err := g.detectListenerAddress() 191 if err != nil { 192 g.cfg.Logger.Warn("Gateway got error on detecting listener address", zap.Error(err)) 193 } 194 time.Sleep(time.Second * 15) 195 } 196 } 197 198 func (g *Gateway) detectListenerAddress() error { 199 // try to detect the ip address of the listener 200 ta, err := net.ResolveTCPAddr("tcp4", g.listener.Addr().String()) 201 if err != nil { 202 return err 203 } 204 listenerAddresses := make([]string, 0, 10) 205 if ta.IP.IsUnspecified() { 206 interfaceAddresses, err := net.InterfaceAddrs() 207 if err == nil { 208 for _, a := range interfaceAddresses { 209 switch x := a.(type) { 210 case *net.IPNet: 211 if x.IP.To4() == nil || x.IP.IsLoopback() { 212 continue 213 } 214 listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", x.IP.String(), ta.Port)) 215 case *net.IPAddr: 216 if x.IP.To4() == nil || x.IP.IsLoopback() { 217 continue 218 } 219 listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", x.IP.String(), ta.Port)) 220 case *net.TCPAddr: 221 if x.IP.To4() == nil || x.IP.IsLoopback() { 222 continue 223 } 224 listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", x.IP.String(), ta.Port)) 225 } 226 } 227 } 228 } else { 229 listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", ta.IP, ta.Port)) 230 } 231 g.listenerAddressMtx.Lock() 232 g.listenerAddresses = append(g.listenerAddresses[:0], listenerAddresses...) 233 g.listenerAddressMtx.Unlock() 234 235 return nil 236 } 237 238 func (g *Gateway) Subscribe(d rony.GatewayDelegate) { 239 g.delegate = d 240 } 241 242 // Start is non-blocking and call the Run function in background 243 func (g *Gateway) Start() { 244 go g.Run() 245 } 246 247 // Run is blocking and runs the server endless loop until a non-temporary error happens 248 func (g *Gateway) Run() { 249 // initialize the fasthttp server. 250 server := fasthttp.Server{ 251 Name: "Rony TCP-Gateway", 252 Handler: g.requestHandler, 253 Concurrency: g.cfg.Concurrency, 254 KeepHijackedConns: true, 255 MaxRequestBodySize: g.cfg.MaxBodySize, 256 DisableKeepalive: true, 257 CloseOnShutdown: true, 258 } 259 260 // start serving in blocking mode 261 err := server.Serve(g.listener) 262 if err != nil { 263 g.cfg.Logger.Warn("Error On Serve", zap.Error(err)) 264 } 265 } 266 267 // Shutdown closes the server by stopping services in sequence, in a way that all the flying request 268 // will be served before server shutdown. 269 func (g *Gateway) Shutdown() { 270 // 1. Stop Accepting New Connections, i.e. Stop ConnectionAcceptor routines 271 g.cfg.Logger.Info("Connection Acceptors are closing...") 272 atomic.StoreInt32(&g.stop, 1) 273 _ = g.listener.Close() 274 g.waitGroupAcceptors.Wait() 275 g.cfg.Logger.Info("Connection Acceptors all closed") 276 277 // 2. Close all readPumps 278 g.cfg.Logger.Info("Read Pumpers are closing") 279 g.waitGroupReaders.Wait() 280 g.cfg.Logger.Info("Read Pumpers all closed") 281 282 // 3. Close all writePumps 283 g.cfg.Logger.Info("Write Pumpers are closing") 284 g.waitGroupWriters.Wait() 285 g.cfg.Logger.Info("Write Pumpers all closed") 286 287 g.cfg.Logger.Info("Stats", 288 zap.Uint64("Reads", g.cntReads), 289 zap.Uint64("Writes", g.cntWrites), 290 ) 291 292 g.connsMtx.Lock() 293 for id, c := range g.conns { 294 g.cfg.Logger.Info("Conn Stalled", 295 zap.Uint64("ID", id), 296 zap.Duration("SinceStart", time.Duration(tools.CPUTicks()-atomic.LoadInt64(&c.startTime))), 297 zap.Duration("SinceLastActivity", time.Duration(tools.CPUTicks()-(atomic.LoadInt64(&c.lastActivity)))), 298 ) 299 } 300 g.connsMtx.Unlock() 301 } 302 303 // Addr return the address which gateway is listen on 304 func (g *Gateway) Addr() []string { 305 if len(g.cfg.ExternalAddrs) > 0 { 306 return g.cfg.ExternalAddrs 307 } 308 g.listenerAddressMtx.RLock() 309 addrs := g.listenerAddresses 310 g.listenerAddressMtx.RUnlock() 311 312 return addrs 313 } 314 315 // GetConn returns the connection identified by connID 316 func (g *Gateway) GetConn(connID uint64) rony.Conn { 317 c := g.getConnection(connID) 318 if c == nil { 319 return nil 320 } 321 322 return c 323 } 324 325 func (g *Gateway) Support(p rony.GatewayProtocol) bool { 326 return g.transportMode&p == p 327 } 328 329 func (g *Gateway) TotalConnections() int { 330 g.connsMtx.RLock() 331 n := len(g.conns) 332 g.connsMtx.RUnlock() 333 334 return n 335 } 336 337 func (g *Gateway) Protocol() rony.GatewayProtocol { 338 return g.transportMode 339 } 340 341 func (g *Gateway) requestHandler(reqCtx *fasthttp.RequestCtx) { 342 if g.cors.Handle(reqCtx) { 343 return 344 } 345 346 // extract required information from the header of the RequestCtx 347 connInfo := acquireConnInfo(reqCtx) 348 349 // If this is a Http Upgrade then we Handle websocket 350 if connInfo.Upgrade() { 351 if !g.Support(rony.Websocket) { 352 reqCtx.SetConnectionClose() 353 reqCtx.SetStatusCode(http.StatusNotAcceptable) 354 355 return 356 } 357 reqCtx.HijackSetNoResponse(true) 358 reqCtx.Hijack( 359 func(c net.Conn) { 360 wc, _ := c.(UnsafeConn).UnsafeConn().(*wrapConn) 361 wc.ReadyForUpgrade() 362 g.waitGroupAcceptors.Add(1) 363 g.websocketHandler(wc, connInfo) 364 releaseConnInfo(connInfo) 365 }, 366 ) 367 368 return 369 } 370 371 // This is going to be an HTTP request 372 reqCtx.SetConnectionClose() 373 if !g.Support(rony.Http) { 374 reqCtx.SetStatusCode(http.StatusNotAcceptable) 375 376 return 377 } 378 379 conn := acquireHttpConn(g, reqCtx) 380 conn.SetClientIP(connInfo.clientIP) 381 conn.SetClientType(connInfo.clientType) 382 for k, v := range connInfo.kvs { 383 conn.Set(k, v) 384 } 385 386 metrics.IncCounter(metrics.CntGatewayIncomingHttpMessage) 387 388 g.delegate.OnConnect(conn) 389 390 g.delegate.OnMessage(conn, int64(reqCtx.ID()), reqCtx.PostBody()) 391 392 g.delegate.OnClose(conn) 393 394 releaseConnInfo(connInfo) 395 releaseHttpConn(conn) 396 } 397 398 func (g *Gateway) websocketHandler(c net.Conn, meta *connInfo) { 399 defer g.waitGroupAcceptors.Done() 400 if atomic.LoadInt32(&g.stop) == 1 { 401 return 402 } 403 if _, err := g.upgradeHandler.Upgrade(c); err != nil { 404 if ce := g.cfg.Logger.Check(log.InfoLevel, "got error in websocket upgrade"); ce != nil { 405 ce.Write( 406 zap.String("IP", tools.B2S(meta.clientIP)), 407 zap.String("ClientType", tools.B2S(meta.clientType)), 408 zap.Error(err), 409 ) 410 } 411 _ = c.Close() 412 413 return 414 } 415 416 var ( 417 err error 418 ) 419 420 wsConn, err := newWebsocketConn(g, c, meta.clientIP) 421 if err != nil { 422 g.cfg.Logger.Warn("got error on creating websocket connection", 423 zap.Error(err), 424 zap.Int("Total", g.TotalConnections()), 425 ) 426 427 return 428 } 429 for k, v := range meta.kvs { 430 wsConn.Set(k, v) 431 } 432 433 g.delegate.OnConnect(wsConn) 434 435 err = wsConn.registerDesc() 436 if err != nil { 437 g.cfg.Logger.Warn("got error in registering conn desc", 438 zap.Error(err), 439 zap.Any("Conn", wsConn.conn), 440 ) 441 } 442 } 443 444 func (g *Gateway) websocketReadPump(wc *websocketConn, wg *sync.WaitGroup) (err error) { 445 var ms []wsutil.Message 446 ms, err = wc.read(ms) 447 if err != nil { 448 if ce := g.cfg.Logger.Check(log.DebugLevel, "got error in websocket read pump"); ce != nil { 449 ce.Write( 450 zap.Uint64("ConnID", wc.connID), 451 zap.Error(err), 452 ) 453 } 454 455 return errors.Wrap(ErrUnexpectedSocketRead)(err) 456 } 457 atomic.AddUint64(&g.cntReads, 1) 458 459 // Handle messages 460 for idx := range ms { 461 switch ms[idx].OpCode { 462 case ws.OpPong: 463 case ws.OpPing: 464 err = wc.write(ws.OpPong, ms[idx].Payload) 465 pools.Bytes.Put(ms[idx].Payload) 466 case ws.OpBinary, ws.OpText: 467 wg.Add(1) 468 _ = goPoolB.Submit( 469 func(idx int) func() { 470 return func() { 471 metrics.IncCounter(metrics.CntGatewayIncomingWebsocketMessage) 472 g.delegate.OnMessage(wc, 0, ms[idx].Payload) 473 pools.Bytes.Put(ms[idx].Payload) 474 wg.Done() 475 } 476 }(idx), 477 ) 478 case ws.OpClose: 479 // remove the connection from the list 480 err = ErrOpCloseReceived 481 default: 482 g.cfg.Logger.Warn("Unknown OpCode") 483 } 484 } 485 486 return err 487 } 488 489 func (g *Gateway) websocketWritePump(wr *writeRequest) (err error) { 490 defer g.waitGroupWriters.Done() 491 492 switch wr.opCode { 493 case ws.OpBinary, ws.OpText: 494 err = wr.wc.write(wr.opCode, wr.payload) 495 if err != nil { 496 if ce := g.cfg.Logger.Check(log.DebugLevel, "Error in websocketWritePump"); ce != nil { 497 ce.Write(zap.Error(err), zap.Uint64("ConnID", wr.wc.connID)) 498 } 499 } else { 500 atomic.AddUint64(&g.cntWrites, 1) 501 } 502 } 503 504 return 505 } 506 507 func (g *Gateway) getConnection(connID uint64) *websocketConn { 508 g.connsMtx.RLock() 509 wsConn, ok := g.conns[connID] 510 g.connsMtx.RUnlock() 511 if ok { 512 return wsConn 513 } 514 515 return nil 516 }