github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/internal/gateway/tcp/conn_ws.go (about)

     1  //go:build !windows && !appengine
     2  // +build !windows,!appengine
     3  
     4  package tcpGateway
     5  
     6  import (
     7  	"encoding/binary"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/allegro/bigcache/v2"
    14  	"github.com/gobwas/ws"
    15  	"github.com/mailru/easygo/netpoll"
    16  	wsutil "github.com/ronaksoft/rony/internal/gateway/tcp/util"
    17  	"github.com/ronaksoft/rony/internal/metrics"
    18  	"github.com/ronaksoft/rony/log"
    19  	"github.com/ronaksoft/rony/pools"
    20  	"github.com/ronaksoft/rony/tools"
    21  	"go.uber.org/zap"
    22  )
    23  
    24  /*
    25     Creation Time: 2019 - Feb - 28
    26     Created by:  (ehsan)
    27     Maintainers:
    28        1.  Ehsan N. Moosa (E2)
    29     Auditor: Ehsan N. Moosa (E2)
    30     Copyright Ronak Software Group 2020
    31  */
    32  
    33  // websocketConn
    34  type websocketConn struct {
    35  	mtx      sync.Mutex
    36  	connID   uint64
    37  	clientIP []byte
    38  
    39  	// KV Store
    40  	kvLock tools.SpinLock
    41  	kv     map[string]interface{}
    42  
    43  	// Internals
    44  	gateway      *Gateway
    45  	lastActivity int64
    46  	conn         net.Conn
    47  	desc         *netpoll.Desc
    48  	closed       bool
    49  	startTime    int64
    50  }
    51  
    52  func newWebsocketConn(g *Gateway, conn net.Conn, clientIP []byte) (*websocketConn, error) {
    53  	desc, err := netpoll.Handle(conn,
    54  		netpoll.EventRead|netpoll.EventHup|netpoll.EventOneShot,
    55  	)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	// Increment total connection counter and connection ID
    61  	totalConns := atomic.AddInt32(&g.connsTotal, 1)
    62  	connID := atomic.AddUint64(&g.connsLastID, 1)
    63  	wsConn := &websocketConn{
    64  		connID:       connID,
    65  		gateway:      g,
    66  		conn:         conn,
    67  		desc:         desc,
    68  		closed:       false,
    69  		kv:           make(map[string]interface{}, 4),
    70  		lastActivity: tools.CPUTicks(),
    71  	}
    72  	wsConn.SetClientIP(clientIP)
    73  
    74  	g.connsMtx.Lock()
    75  	g.conns[connID] = wsConn
    76  	g.connsMtx.Unlock()
    77  	if ce := g.cfg.Logger.Check(log.DebugLevel, "websocket connection created"); ce != nil {
    78  		ce.Write(
    79  			zap.Uint64("ConnID", connID),
    80  			zap.String("IP", wsConn.ClientIP()),
    81  			zap.Int32("Total", totalConns),
    82  		)
    83  	}
    84  	g.connGC.monitorConnection(connID)
    85  
    86  	return wsConn, nil
    87  }
    88  
    89  func (wc *websocketConn) registerDesc() error {
    90  	atomic.StoreInt64(&wc.startTime, tools.CPUTicks())
    91  	err := wc.gateway.poller.Start(wc.desc, wc.startEvent)
    92  	if err != nil {
    93  		if err != netpoll.ErrRegistered {
    94  			wc.release(1)
    95  
    96  			return err
    97  		}
    98  		_ = wc.gateway.poller.Stop(wc.desc)
    99  		err = wc.gateway.poller.Start(wc.desc, wc.startEvent)
   100  		if err != nil {
   101  			wc.release(1)
   102  
   103  			return err
   104  		}
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  func (wc *websocketConn) release(_ int) {
   111  	// delete the reference from the gateway's conns
   112  	g := wc.gateway
   113  	g.connsMtx.Lock()
   114  	_, ok := g.conns[wc.connID]
   115  	if !ok {
   116  		g.connsMtx.Unlock()
   117  
   118  		return
   119  	}
   120  	delete(g.conns, wc.connID)
   121  	g.connsMtx.Unlock()
   122  
   123  	// Decrease the total connection counter
   124  	totalConns := atomic.AddInt32(&g.connsTotal, -1)
   125  
   126  	if ce := g.cfg.Logger.Check(log.DebugLevel, "websocket connection removed"); ce != nil {
   127  		ce.Write(
   128  			zap.Uint64("ConnID", wc.connID),
   129  			zap.Int32("Total", totalConns),
   130  		)
   131  	}
   132  
   133  	wc.mtx.Lock()
   134  	if wc.desc != nil {
   135  		_ = wc.gateway.poller.Stop(wc.desc)
   136  		err := wc.desc.Close()
   137  		if err != nil {
   138  			if ce := g.cfg.Logger.Check(log.DebugLevel, "got error on closing desc"); ce != nil {
   139  				ce.Write(zap.Error(err))
   140  			}
   141  		}
   142  	}
   143  	_ = wc.conn.Close()
   144  
   145  	if !wc.closed {
   146  		g.delegate.OnClose(wc)
   147  		wc.closed = true
   148  		wc.conn = nil
   149  	}
   150  
   151  	wc.mtx.Unlock()
   152  }
   153  
   154  func (wc *websocketConn) startEvent(event netpoll.Event) {
   155  	if atomic.LoadInt32(&wc.gateway.stop) == 1 {
   156  		return
   157  	}
   158  
   159  	if event&netpoll.EventReadHup != 0 {
   160  		wc.release(2)
   161  
   162  		return
   163  	}
   164  
   165  	if event&netpoll.EventRead != 0 {
   166  		atomic.StoreInt64(&wc.lastActivity, tools.CPUTicks())
   167  		wc.gateway.waitGroupReaders.Add(1)
   168  
   169  		err := goPoolNB.Submit(
   170  			func() {
   171  				waitGroup := pools.AcquireWaitGroup()
   172  				err := wc.gateway.websocketReadPump(wc, waitGroup)
   173  				if err != nil {
   174  					wc.release(3)
   175  				} else {
   176  					_ = wc.gateway.poller.Resume(wc.desc)
   177  				}
   178  				waitGroup.Wait()
   179  				pools.ReleaseWaitGroup(waitGroup)
   180  				wc.gateway.waitGroupReaders.Done()
   181  			},
   182  		)
   183  		if err != nil {
   184  			wc.gateway.cfg.Logger.Warn("got error on start event go-routine pool", zap.Error(err))
   185  		}
   186  	}
   187  }
   188  
   189  func (wc *websocketConn) read(ms []wsutil.Message) ([]wsutil.Message, error) {
   190  	var err error
   191  	wc.mtx.Lock()
   192  	if wc.conn != nil {
   193  		_ = wc.conn.SetReadDeadline(time.Now().Add(defaultReadTimout))
   194  		ms, err = wsutil.ReadMessage(wc.conn, ws.StateServerSide, ms)
   195  	} else {
   196  		err = ErrConnectionClosed
   197  	}
   198  	wc.mtx.Unlock()
   199  
   200  	return ms, err
   201  }
   202  
   203  func (wc *websocketConn) write(opCode ws.OpCode, payload []byte) (err error) {
   204  	wc.mtx.Lock()
   205  	if wc.conn != nil {
   206  		_ = wc.conn.SetWriteDeadline(time.Now().Add(defaultWriteTimeout))
   207  		err = wsutil.WriteMessage(wc.conn, ws.StateServerSide, opCode, payload)
   208  	} else {
   209  		err = ErrWriteToClosedConn
   210  	}
   211  	wc.mtx.Unlock()
   212  
   213  	return
   214  }
   215  
   216  func (wc *websocketConn) Get(key string) interface{} {
   217  	wc.kvLock.Lock()
   218  	v := wc.kv[key]
   219  	wc.kvLock.Unlock()
   220  
   221  	return v
   222  }
   223  
   224  func (wc *websocketConn) Set(key string, val interface{}) {
   225  	wc.kvLock.Lock()
   226  	wc.kv[key] = val
   227  	wc.kvLock.Unlock()
   228  }
   229  
   230  func (wc *websocketConn) Walk(f func(k string, v interface{}) bool) {
   231  	wc.kvLock.Lock()
   232  	defer wc.kvLock.Unlock()
   233  
   234  	for k, v := range wc.kv {
   235  		if !f(k, v) {
   236  			return
   237  		}
   238  	}
   239  }
   240  
   241  func (wc *websocketConn) ConnID() uint64 {
   242  	return atomic.LoadUint64(&wc.connID)
   243  }
   244  
   245  func (wc *websocketConn) ClientIP() string {
   246  	return string(wc.clientIP)
   247  }
   248  
   249  func (wc *websocketConn) SetClientIP(ip []byte) {
   250  	wc.clientIP = append(wc.clientIP[:0], ip...)
   251  }
   252  
   253  // WriteBinary
   254  // Make sure you don't use payload after calling this function, because its underlying
   255  // array will be put back into the pool to be reused.
   256  func (wc *websocketConn) WriteBinary(streamID int64, payload []byte) error {
   257  	if wc == nil || wc.closed {
   258  		return ErrWriteToClosedConn
   259  	}
   260  	wc.gateway.waitGroupWriters.Add(1)
   261  
   262  	opCode := ws.OpBinary
   263  	if wc.gateway.cfg.TextDataFrame {
   264  		opCode = ws.OpText
   265  	}
   266  	wr := acquireWriteRequest(wc, opCode)
   267  	wr.CopyPayload(payload)
   268  	err := wc.gateway.websocketWritePump(wr)
   269  	if err != nil {
   270  		wc.release(4)
   271  	}
   272  	releaseWriteRequest(wr)
   273  	metrics.IncCounter(metrics.CntGatewayOutgoingWebsocketMessage)
   274  
   275  	return nil
   276  }
   277  
   278  func (wc *websocketConn) Disconnect() {
   279  	wc.release(5)
   280  }
   281  
   282  func (wc *websocketConn) Persistent() bool {
   283  	return true
   284  }
   285  
   286  // writeRequest
   287  type writeRequest struct {
   288  	wc      *websocketConn
   289  	opCode  ws.OpCode
   290  	payload []byte
   291  }
   292  
   293  func (wr *writeRequest) CopyPayload(p []byte) {
   294  	wr.payload = append(wr.payload[:0], p...)
   295  }
   296  
   297  // websocketConnGC the garbage collector of the stalled websocket connections
   298  type websocketConnGC struct {
   299  	bg     *bigcache.BigCache
   300  	gw     *Gateway
   301  	inChan chan uint64
   302  }
   303  
   304  func newWebsocketConnGC(gw *Gateway) *websocketConnGC {
   305  	gc := &websocketConnGC{
   306  		gw:     gw,
   307  		inChan: make(chan uint64, 1000),
   308  	}
   309  	bgConf := bigcache.DefaultConfig(time.Duration(gw.maxIdleTime))
   310  	bgConf.CleanWindow = time.Second
   311  	bgConf.Verbose = false
   312  	bgConf.OnRemoveWithReason = gc.onRemove
   313  	bgConf.Shards = 128
   314  	bgConf.MaxEntrySize = 8
   315  	bgConf.MaxEntriesInWindow = 100000
   316  	gc.bg, _ = bigcache.NewBigCache(bgConf)
   317  
   318  	// background job for receiving connIDs
   319  	go func() {
   320  		b := make([]byte, 8)
   321  		for connID := range gc.inChan {
   322  			binary.BigEndian.PutUint64(b, connID)
   323  			_ = gc.bg.Set(tools.ByteToStr(b), b)
   324  		}
   325  	}()
   326  
   327  	return gc
   328  }
   329  
   330  func (gc *websocketConnGC) onRemove(key string, entry []byte, reason bigcache.RemoveReason) {
   331  	switch reason {
   332  	case bigcache.Expired:
   333  		connID := binary.BigEndian.Uint64(entry)
   334  		if wsConn := gc.gw.getConnection(connID); wsConn != nil {
   335  			if tools.CPUTicks()-atomic.LoadInt64(&wsConn.lastActivity) > gc.gw.maxIdleTime {
   336  				wsConn.release(6)
   337  			} else {
   338  				gc.monitorConnection(connID)
   339  			}
   340  		}
   341  	}
   342  }
   343  
   344  func (gc *websocketConnGC) monitorConnection(connID uint64) {
   345  	gc.inChan <- connID
   346  }