github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/internal/gateway/tcp/gateway.go (about)

     1  package tcpGateway
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"net/http"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/gobwas/ws"
    12  	"github.com/mailru/easygo/netpoll"
    13  	"github.com/panjf2000/ants/v2"
    14  	"github.com/ronaksoft/rony"
    15  	"github.com/ronaksoft/rony/errors"
    16  	"github.com/ronaksoft/rony/internal/gateway/tcp/cors"
    17  	wsutil "github.com/ronaksoft/rony/internal/gateway/tcp/util"
    18  	"github.com/ronaksoft/rony/internal/metrics"
    19  	"github.com/ronaksoft/rony/log"
    20  	"github.com/ronaksoft/rony/pools"
    21  	"github.com/ronaksoft/rony/tools"
    22  	"github.com/valyala/fasthttp"
    23  	"go.uber.org/zap"
    24  )
    25  
    26  /*
    27     Creation Time: 2019 - Feb - 28
    28     Created by:  (ehsan)
    29     Maintainers:
    30        1.  Ehsan N. Moosa (E2)
    31     Auditor: Ehsan N. Moosa (E2)
    32     Copyright Ronak Software Group 2020
    33  */
    34  
    35  type UnsafeConn interface {
    36  	net.Conn
    37  	UnsafeConn() net.Conn
    38  }
    39  
// Config holds all the configuration for Gateway.
//
// New stores a copy of the struct in the Gateway; slice fields (ExternalAddrs,
// Allowed*) still share their backing arrays with the caller.
type Config struct {
	// Concurrency is passed to fasthttp as Server.Concurrency and is used as the
	// size of the two ants worker pools created by New. NOTE(review): the pools
	// are created with pre-allocation, which presumably requires a positive
	// value — confirm against the ants documentation.
	Concurrency int
	// ListenAddress is the address handed to newWrapListener.
	ListenAddress string
	// MaxBodySize is used as fasthttp's Server.MaxRequestBodySize.
	MaxBodySize int
	// MaxIdleTime, when non-zero, replaces defaultConnIdleTime as the idle limit
	// for websocket connections.
	MaxIdleTime time.Duration
	// Protocol selects the supported transports. rony.Undefined means rony.TCP;
	// only rony.Websocket, rony.Http and rony.TCP are accepted by New.
	Protocol rony.GatewayProtocol
	// ExternalAddrs, when non-empty, is reported by Addr instead of the
	// addresses detected from the listener.
	ExternalAddrs []string
	// Logger is used for all gateway logging; nil means log.DefaultLogger.
	Logger log.Logger
	// TextDataFrame if is set to TRUE then websocket data frames use OpText otherwise use OpBinary
	TextDataFrame bool

	// CORS settings, passed unchanged to cors.New.
	AllowedHeaders []string // Default Allow All
	AllowedOrigins []string // Default Allow All
	AllowedMethods []string // Default Allow All
}
    57  
// Gateway is one of the main components of the Rony framework. Basically Gateway is the component
// that connects edge.Server with the external world. Clients which are not part of our cluster MUST
// connect to our edge servers through Gateway.
// This is an implementation of gateway.Gateway interface with support for **Http** and **Websocket** connections.
//
// NOTE(review): cntReads and cntWrites are updated with sync/atomic 64-bit
// operations; on 32-bit platforms such fields must be 64-bit aligned, so do not
// reorder the fields of this struct without checking alignment.
type Gateway struct {
	// Internals
	cfg                Config               // configuration copy taken by New
	transportMode      rony.GatewayProtocol // enabled protocols; tested as a bit mask by Support
	listener           *wrapListener        // listener created from cfg.ListenAddress
	listenerAddressMtx sync.RWMutex         // guards listenerAddresses
	listenerAddresses  []string             // addresses detected by detectListenerAddress
	poller             netpoll.Poller       // netpoll poller created in New
	stop               int32                // set to 1 (atomically) by Shutdown; checked by websocketHandler
	waitGroupAcceptors *sync.WaitGroup      // one count per websocketHandler in flight; awaited by Shutdown
	waitGroupReaders   *sync.WaitGroup      // awaited by Shutdown; Add/Done happen outside this view — TODO confirm
	waitGroupWriters   *sync.WaitGroup      // Done in websocketWritePump; awaited by Shutdown
	cntReads           uint64               // successful websocket reads (atomic)
	cntWrites          uint64               // successful websocket writes (atomic)
	cors               *cors.CORS           // answers CORS requests in requestHandler
	delegate           rony.GatewayDelegate // set by Subscribe; receives OnConnect/OnMessage/OnClose

	// Websocket Internals
	upgradeHandler ws.Upgrader                // ws.DefaultUpgrader, set in New
	connGC         *websocketConnGC           // idle websocket garbage collector created in New
	maxIdleTime    int64                      // idle limit in nanoseconds (int64 of a time.Duration)
	conns          map[uint64]*websocketConn  // live websocket connections by ID; guarded by connsMtx
	connsMtx       sync.RWMutex               // guards conns
	connsTotal     int32                      // not used in this file's visible code — presumably maintained elsewhere
	connsLastID    uint64                     // not used in this file's visible code — presumably an ID counter
}
    88  
    89  func New(config Config) (*Gateway, error) {
    90  	var err error
    91  
    92  	if config.Logger == nil {
    93  		config.Logger = log.DefaultLogger
    94  	}
    95  
    96  	g := &Gateway{
    97  		cfg:                config,
    98  		maxIdleTime:        int64(defaultConnIdleTime),
    99  		waitGroupReaders:   &sync.WaitGroup{},
   100  		waitGroupWriters:   &sync.WaitGroup{},
   101  		waitGroupAcceptors: &sync.WaitGroup{},
   102  		conns:              make(map[uint64]*websocketConn, 100000),
   103  		transportMode:      rony.TCP,
   104  		cors: cors.New(cors.Config{
   105  			AllowedHeaders: config.AllowedHeaders,
   106  			AllowedMethods: config.AllowedMethods,
   107  			AllowedOrigins: config.AllowedOrigins,
   108  		}),
   109  	}
   110  
   111  	g.listener, err = newWrapListener(g.cfg.ListenAddress)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	if config.MaxIdleTime != 0 {
   117  		g.maxIdleTime = int64(config.MaxIdleTime)
   118  	}
   119  	if config.Protocol != rony.Undefined {
   120  		g.transportMode = config.Protocol
   121  	}
   122  
   123  	switch g.transportMode {
   124  	case rony.Websocket, rony.Http, rony.TCP:
   125  	default:
   126  		return nil, ErrUnsupportedProtocol
   127  	}
   128  
   129  	// initialize websocket upgrade handler
   130  	g.upgradeHandler = ws.DefaultUpgrader
   131  
   132  	// initialize idle websocket garbage collector
   133  	g.connGC = newWebsocketConnGC(g)
   134  
   135  	// set handlers
   136  	if poller, err := netpoll.New(&netpoll.Config{
   137  		OnWaitError: func(e error) {
   138  			g.cfg.Logger.Warn("Error On NetPoller Wait",
   139  				zap.Error(e),
   140  			)
   141  		},
   142  	}); err != nil {
   143  		return nil, err
   144  	} else {
   145  		g.poller = poller
   146  	}
   147  
   148  	// try to detect the ip address of the listener
   149  	err = g.detectListenerAddress()
   150  	if err != nil {
   151  		g.cfg.Logger.Warn("Rony:: Gateway got error on detecting listener addresses", zap.Error(err))
   152  
   153  		return nil, err
   154  	}
   155  
   156  	goPoolB, err = ants.NewPool(g.cfg.Concurrency,
   157  		ants.WithNonblocking(false),
   158  		ants.WithPreAlloc(true),
   159  	)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	goPoolNB, err = ants.NewPool(g.cfg.Concurrency,
   165  		ants.WithNonblocking(true),
   166  		ants.WithPreAlloc(true),
   167  	)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	// run the watchdog in background
   173  	go g.watchdog()
   174  
   175  	return g, nil
   176  }
   177  
   178  func MustNew(config Config) *Gateway {
   179  	g, err := New(config)
   180  	if err != nil {
   181  		panic(err)
   182  	}
   183  
   184  	return g
   185  }
   186  
   187  func (g *Gateway) watchdog() {
   188  	for {
   189  		metrics.SetGauge(metrics.GaugeActiveWebsocketConnections, float64(g.TotalConnections()))
   190  		err := g.detectListenerAddress()
   191  		if err != nil {
   192  			g.cfg.Logger.Warn("Gateway got error on detecting listener address", zap.Error(err))
   193  		}
   194  		time.Sleep(time.Second * 15)
   195  	}
   196  }
   197  
   198  func (g *Gateway) detectListenerAddress() error {
   199  	// try to detect the ip address of the listener
   200  	ta, err := net.ResolveTCPAddr("tcp4", g.listener.Addr().String())
   201  	if err != nil {
   202  		return err
   203  	}
   204  	listenerAddresses := make([]string, 0, 10)
   205  	if ta.IP.IsUnspecified() {
   206  		interfaceAddresses, err := net.InterfaceAddrs()
   207  		if err == nil {
   208  			for _, a := range interfaceAddresses {
   209  				switch x := a.(type) {
   210  				case *net.IPNet:
   211  					if x.IP.To4() == nil || x.IP.IsLoopback() {
   212  						continue
   213  					}
   214  					listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", x.IP.String(), ta.Port))
   215  				case *net.IPAddr:
   216  					if x.IP.To4() == nil || x.IP.IsLoopback() {
   217  						continue
   218  					}
   219  					listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", x.IP.String(), ta.Port))
   220  				case *net.TCPAddr:
   221  					if x.IP.To4() == nil || x.IP.IsLoopback() {
   222  						continue
   223  					}
   224  					listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", x.IP.String(), ta.Port))
   225  				}
   226  			}
   227  		}
   228  	} else {
   229  		listenerAddresses = append(listenerAddresses, fmt.Sprintf("%s:%d", ta.IP, ta.Port))
   230  	}
   231  	g.listenerAddressMtx.Lock()
   232  	g.listenerAddresses = append(g.listenerAddresses[:0], listenerAddresses...)
   233  	g.listenerAddressMtx.Unlock()
   234  
   235  	return nil
   236  }
   237  
   238  func (g *Gateway) Subscribe(d rony.GatewayDelegate) {
   239  	g.delegate = d
   240  }
   241  
   242  // Start is non-blocking and call the Run function in background
   243  func (g *Gateway) Start() {
   244  	go g.Run()
   245  }
   246  
   247  // Run is blocking and runs the server endless loop until a non-temporary error happens
   248  func (g *Gateway) Run() {
   249  	// initialize the fasthttp server.
   250  	server := fasthttp.Server{
   251  		Name:               "Rony TCP-Gateway",
   252  		Handler:            g.requestHandler,
   253  		Concurrency:        g.cfg.Concurrency,
   254  		KeepHijackedConns:  true,
   255  		MaxRequestBodySize: g.cfg.MaxBodySize,
   256  		DisableKeepalive:   true,
   257  		CloseOnShutdown:    true,
   258  	}
   259  
   260  	// start serving in blocking mode
   261  	err := server.Serve(g.listener)
   262  	if err != nil {
   263  		g.cfg.Logger.Warn("Error On Serve", zap.Error(err))
   264  	}
   265  }
   266  
   267  // Shutdown closes the server by stopping services in sequence, in a way that all the flying request
   268  // will be served before server shutdown.
   269  func (g *Gateway) Shutdown() {
   270  	// 1. Stop Accepting New Connections, i.e. Stop ConnectionAcceptor routines
   271  	g.cfg.Logger.Info("Connection Acceptors are closing...")
   272  	atomic.StoreInt32(&g.stop, 1)
   273  	_ = g.listener.Close()
   274  	g.waitGroupAcceptors.Wait()
   275  	g.cfg.Logger.Info("Connection Acceptors all closed")
   276  
   277  	// 2. Close all readPumps
   278  	g.cfg.Logger.Info("Read Pumpers are closing")
   279  	g.waitGroupReaders.Wait()
   280  	g.cfg.Logger.Info("Read Pumpers all closed")
   281  
   282  	// 3. Close all writePumps
   283  	g.cfg.Logger.Info("Write Pumpers are closing")
   284  	g.waitGroupWriters.Wait()
   285  	g.cfg.Logger.Info("Write Pumpers all closed")
   286  
   287  	g.cfg.Logger.Info("Stats",
   288  		zap.Uint64("Reads", g.cntReads),
   289  		zap.Uint64("Writes", g.cntWrites),
   290  	)
   291  
   292  	g.connsMtx.Lock()
   293  	for id, c := range g.conns {
   294  		g.cfg.Logger.Info("Conn Stalled",
   295  			zap.Uint64("ID", id),
   296  			zap.Duration("SinceStart", time.Duration(tools.CPUTicks()-atomic.LoadInt64(&c.startTime))),
   297  			zap.Duration("SinceLastActivity", time.Duration(tools.CPUTicks()-(atomic.LoadInt64(&c.lastActivity)))),
   298  		)
   299  	}
   300  	g.connsMtx.Unlock()
   301  }
   302  
   303  // Addr return the address which gateway is listen on
   304  func (g *Gateway) Addr() []string {
   305  	if len(g.cfg.ExternalAddrs) > 0 {
   306  		return g.cfg.ExternalAddrs
   307  	}
   308  	g.listenerAddressMtx.RLock()
   309  	addrs := g.listenerAddresses
   310  	g.listenerAddressMtx.RUnlock()
   311  
   312  	return addrs
   313  }
   314  
   315  // GetConn returns the connection identified by connID
   316  func (g *Gateway) GetConn(connID uint64) rony.Conn {
   317  	c := g.getConnection(connID)
   318  	if c == nil {
   319  		return nil
   320  	}
   321  
   322  	return c
   323  }
   324  
   325  func (g *Gateway) Support(p rony.GatewayProtocol) bool {
   326  	return g.transportMode&p == p
   327  }
   328  
   329  func (g *Gateway) TotalConnections() int {
   330  	g.connsMtx.RLock()
   331  	n := len(g.conns)
   332  	g.connsMtx.RUnlock()
   333  
   334  	return n
   335  }
   336  
   337  func (g *Gateway) Protocol() rony.GatewayProtocol {
   338  	return g.transportMode
   339  }
   340  
   341  func (g *Gateway) requestHandler(reqCtx *fasthttp.RequestCtx) {
   342  	if g.cors.Handle(reqCtx) {
   343  		return
   344  	}
   345  
   346  	// extract required information from the header of the RequestCtx
   347  	connInfo := acquireConnInfo(reqCtx)
   348  
   349  	// If this is a Http Upgrade then we Handle websocket
   350  	if connInfo.Upgrade() {
   351  		if !g.Support(rony.Websocket) {
   352  			reqCtx.SetConnectionClose()
   353  			reqCtx.SetStatusCode(http.StatusNotAcceptable)
   354  
   355  			return
   356  		}
   357  		reqCtx.HijackSetNoResponse(true)
   358  		reqCtx.Hijack(
   359  			func(c net.Conn) {
   360  				wc, _ := c.(UnsafeConn).UnsafeConn().(*wrapConn)
   361  				wc.ReadyForUpgrade()
   362  				g.waitGroupAcceptors.Add(1)
   363  				g.websocketHandler(wc, connInfo)
   364  				releaseConnInfo(connInfo)
   365  			},
   366  		)
   367  
   368  		return
   369  	}
   370  
   371  	// This is going to be an HTTP request
   372  	reqCtx.SetConnectionClose()
   373  	if !g.Support(rony.Http) {
   374  		reqCtx.SetStatusCode(http.StatusNotAcceptable)
   375  
   376  		return
   377  	}
   378  
   379  	conn := acquireHttpConn(g, reqCtx)
   380  	conn.SetClientIP(connInfo.clientIP)
   381  	conn.SetClientType(connInfo.clientType)
   382  	for k, v := range connInfo.kvs {
   383  		conn.Set(k, v)
   384  	}
   385  
   386  	metrics.IncCounter(metrics.CntGatewayIncomingHttpMessage)
   387  
   388  	g.delegate.OnConnect(conn)
   389  
   390  	g.delegate.OnMessage(conn, int64(reqCtx.ID()), reqCtx.PostBody())
   391  
   392  	g.delegate.OnClose(conn)
   393  
   394  	releaseConnInfo(connInfo)
   395  	releaseHttpConn(conn)
   396  }
   397  
   398  func (g *Gateway) websocketHandler(c net.Conn, meta *connInfo) {
   399  	defer g.waitGroupAcceptors.Done()
   400  	if atomic.LoadInt32(&g.stop) == 1 {
   401  		return
   402  	}
   403  	if _, err := g.upgradeHandler.Upgrade(c); err != nil {
   404  		if ce := g.cfg.Logger.Check(log.InfoLevel, "got error in websocket upgrade"); ce != nil {
   405  			ce.Write(
   406  				zap.String("IP", tools.B2S(meta.clientIP)),
   407  				zap.String("ClientType", tools.B2S(meta.clientType)),
   408  				zap.Error(err),
   409  			)
   410  		}
   411  		_ = c.Close()
   412  
   413  		return
   414  	}
   415  
   416  	var (
   417  		err error
   418  	)
   419  
   420  	wsConn, err := newWebsocketConn(g, c, meta.clientIP)
   421  	if err != nil {
   422  		g.cfg.Logger.Warn("got error on creating websocket connection",
   423  			zap.Error(err),
   424  			zap.Int("Total", g.TotalConnections()),
   425  		)
   426  
   427  		return
   428  	}
   429  	for k, v := range meta.kvs {
   430  		wsConn.Set(k, v)
   431  	}
   432  
   433  	g.delegate.OnConnect(wsConn)
   434  
   435  	err = wsConn.registerDesc()
   436  	if err != nil {
   437  		g.cfg.Logger.Warn("got error in registering conn desc",
   438  			zap.Error(err),
   439  			zap.Any("Conn", wsConn.conn),
   440  		)
   441  	}
   442  }
   443  
   444  func (g *Gateway) websocketReadPump(wc *websocketConn, wg *sync.WaitGroup) (err error) {
   445  	var ms []wsutil.Message
   446  	ms, err = wc.read(ms)
   447  	if err != nil {
   448  		if ce := g.cfg.Logger.Check(log.DebugLevel, "got error in websocket read pump"); ce != nil {
   449  			ce.Write(
   450  				zap.Uint64("ConnID", wc.connID),
   451  				zap.Error(err),
   452  			)
   453  		}
   454  
   455  		return errors.Wrap(ErrUnexpectedSocketRead)(err)
   456  	}
   457  	atomic.AddUint64(&g.cntReads, 1)
   458  
   459  	// Handle messages
   460  	for idx := range ms {
   461  		switch ms[idx].OpCode {
   462  		case ws.OpPong:
   463  		case ws.OpPing:
   464  			err = wc.write(ws.OpPong, ms[idx].Payload)
   465  			pools.Bytes.Put(ms[idx].Payload)
   466  		case ws.OpBinary, ws.OpText:
   467  			wg.Add(1)
   468  			_ = goPoolB.Submit(
   469  				func(idx int) func() {
   470  					return func() {
   471  						metrics.IncCounter(metrics.CntGatewayIncomingWebsocketMessage)
   472  						g.delegate.OnMessage(wc, 0, ms[idx].Payload)
   473  						pools.Bytes.Put(ms[idx].Payload)
   474  						wg.Done()
   475  					}
   476  				}(idx),
   477  			)
   478  		case ws.OpClose:
   479  			// remove the connection from the list
   480  			err = ErrOpCloseReceived
   481  		default:
   482  			g.cfg.Logger.Warn("Unknown OpCode")
   483  		}
   484  	}
   485  
   486  	return err
   487  }
   488  
   489  func (g *Gateway) websocketWritePump(wr *writeRequest) (err error) {
   490  	defer g.waitGroupWriters.Done()
   491  
   492  	switch wr.opCode {
   493  	case ws.OpBinary, ws.OpText:
   494  		err = wr.wc.write(wr.opCode, wr.payload)
   495  		if err != nil {
   496  			if ce := g.cfg.Logger.Check(log.DebugLevel, "Error in websocketWritePump"); ce != nil {
   497  				ce.Write(zap.Error(err), zap.Uint64("ConnID", wr.wc.connID))
   498  			}
   499  		} else {
   500  			atomic.AddUint64(&g.cntWrites, 1)
   501  		}
   502  	}
   503  
   504  	return
   505  }
   506  
   507  func (g *Gateway) getConnection(connID uint64) *websocketConn {
   508  	g.connsMtx.RLock()
   509  	wsConn, ok := g.conns[connID]
   510  	g.connsMtx.RUnlock()
   511  	if ok {
   512  		return wsConn
   513  	}
   514  
   515  	return nil
   516  }