github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/stub/stub_ws.go (about)

     1  package stub
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/clubpay/ronykit/kit"
    13  	"github.com/clubpay/ronykit/kit/utils"
    14  	"github.com/clubpay/ronykit/kit/utils/reflector"
    15  	"github.com/fasthttp/websocket"
    16  )
    17  
    18  type (
    19  	Header              map[string]string
    20  	RPCContainerHandler func(ctx context.Context, c kit.IncomingRPCContainer)
    21  	RPCMessageHandler   func(ctx context.Context, msg kit.Message, hdr Header, err error)
    22  )
    23  
    24  type RPCPreflightHandler func(req *WebsocketRequest)
    25  
    26  type WebsocketCtx struct {
    27  	cfg wsConfig
    28  	r   *reflector.Reflector
    29  	l   kit.Logger
    30  
    31  	pendingMtx   sync.Mutex
    32  	pending      map[string]chan kit.IncomingRPCContainer
    33  	lastActivity uint32
    34  	disconnect   bool
    35  
    36  	// fasthttp entities
    37  	url  string
    38  	cMtx sync.Mutex
    39  	c    *websocket.Conn
    40  
    41  	// stats
    42  	writeBytesTotal uint64
    43  	writeBytes      uint64
    44  	readBytesTotal  uint64
    45  	readBytes       uint64
    46  }
    47  
    48  func (wCtx *WebsocketCtx) Connect(ctx context.Context, path string) error {
    49  	path = strings.TrimLeft(path, "/")
    50  	if path != "" {
    51  		wCtx.url = fmt.Sprintf("%s/%s", wCtx.url, path)
    52  	}
    53  
    54  	return wCtx.connect(ctx)
    55  }
    56  
    57  func (wCtx *WebsocketCtx) connect(ctx context.Context) error {
    58  	wCtx.l.Debugf("connect: %s", wCtx.url)
    59  
    60  	d := wCtx.cfg.dialerBuilder()
    61  	if f := wCtx.cfg.preDial; f != nil {
    62  		f(d)
    63  	}
    64  	c, rsp, err := d.DialContext(ctx, wCtx.url, wCtx.cfg.upgradeHdr)
    65  	if err != nil {
    66  		return err
    67  	}
    68  	_ = rsp.Body.Close()
    69  
    70  	wCtx.setActivity()
    71  	c.SetPongHandler(
    72  		func(appData string) error {
    73  			wCtx.l.Debugf("websocket pong received")
    74  			wCtx.setActivity()
    75  
    76  			return nil
    77  		},
    78  	)
    79  	_ = c.SetCompressionLevel(wCtx.cfg.compressLevel)
    80  
    81  	wCtx.c = c
    82  	wCtx.writeBytes = 0
    83  	wCtx.readBytes = 0
    84  
    85  	// run receiver & watchdog in the background
    86  	go wCtx.receiver(c) //nolint:contextcheck
    87  	go wCtx.watchdog(c) //nolint:contextcheck
    88  
    89  	if f := wCtx.cfg.onConnect; f != nil {
    90  		f(wCtx)
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  func (wCtx *WebsocketCtx) Disconnect() error {
    97  	wCtx.disconnect = true
    98  
    99  	return wCtx.c.Close()
   100  }
   101  
   102  func (wCtx *WebsocketCtx) setActivity() {
   103  	atomic.StoreUint32(&wCtx.lastActivity, uint32(utils.TimeUnix()))
   104  }
   105  
   106  func (wCtx *WebsocketCtx) getActivity() int64 {
   107  	return int64(atomic.LoadUint32(&wCtx.lastActivity))
   108  }
   109  
   110  func (wCtx *WebsocketCtx) watchdog(c *websocket.Conn) {
   111  	wCtx.l.Debugf("watchdog started: %s", c.LocalAddr().String())
   112  
   113  	t := time.NewTicker(wCtx.cfg.pingTime)
   114  	d := int64(wCtx.cfg.pingTime/time.Second) * 2
   115  	for range t.C {
   116  		if wCtx.disconnect {
   117  			wCtx.l.Debugf("going to disconnect: %s", c.LocalAddr().String())
   118  
   119  			_ = c.Close()
   120  
   121  			return
   122  		}
   123  
   124  		if utils.TimeUnix()-wCtx.getActivity() <= d {
   125  			wCtx.cMtx.Lock()
   126  			_ = c.WriteControl(websocket.PingMessage, nil, time.Now().Add(wCtx.cfg.writeTimeout))
   127  			wCtx.cMtx.Unlock()
   128  			wCtx.l.Debugf("websocket ping sent")
   129  
   130  			continue
   131  		}
   132  
   133  		if !wCtx.cfg.autoReconnect {
   134  			return
   135  		}
   136  
   137  		wCtx.l.Errorf("inactivity detected, reconnecting: %s", c.LocalAddr().String())
   138  		_ = c.Close()
   139  
   140  		ctx, cf := context.WithTimeout(context.Background(), wCtx.cfg.dialTimeout)
   141  		err := wCtx.connect(ctx)
   142  		cf()
   143  		if err != nil {
   144  			wCtx.l.Errorf("failed to reconnect: %s", err)
   145  
   146  			continue
   147  		}
   148  
   149  		return
   150  	}
   151  }
   152  
   153  func (wCtx *WebsocketCtx) receiver(c *websocket.Conn) {
   154  	for {
   155  		_, p, err := c.ReadMessage()
   156  		if err != nil || len(p) == 0 {
   157  			wCtx.l.Debugf("receiver shutdown: %s: %v", c.LocalAddr().String(), err)
   158  
   159  			return
   160  		}
   161  
   162  		wCtx.readBytesTotal += uint64(len(p))
   163  		wCtx.readBytes += uint64(len(p))
   164  		wCtx.setActivity()
   165  
   166  		rpcIn := wCtx.cfg.rpcInFactory()
   167  		err = rpcIn.Unmarshal(p)
   168  		if err != nil {
   169  			wCtx.l.Debugf("received unexpected message: %v", err)
   170  
   171  			continue
   172  		}
   173  
   174  		// if this is a reply message we return it to the pending channel
   175  		wCtx.pendingMtx.Lock()
   176  		ch, ok := wCtx.pending[rpcIn.GetID()]
   177  		wCtx.pendingMtx.Unlock()
   178  
   179  		if ok {
   180  			ch <- rpcIn
   181  
   182  			continue
   183  		}
   184  
   185  		ctx := context.Background()
   186  		if tp := wCtx.cfg.tracePropagator; tp != nil {
   187  			ctx = tp.Extract(ctx, containerTraceCarrier{in: rpcIn})
   188  		}
   189  
   190  		h, ok := wCtx.cfg.handlers[rpcIn.GetHdr(wCtx.cfg.predicateKey)]
   191  		if !ok {
   192  			h = wCtx.cfg.defaultHandler
   193  		}
   194  
   195  		if h == nil {
   196  			rpcIn.Release()
   197  
   198  			continue
   199  		}
   200  
   201  		select {
   202  		default:
   203  			wCtx.l.Errorf("ratelimit reached, packet dropped")
   204  		case wCtx.cfg.ratelimitChan <- struct{}{}:
   205  			wCtx.cfg.handlersWG.Add(1)
   206  			go func(ctx context.Context, rpcIn kit.IncomingRPCContainer) {
   207  				defer wCtx.recoverPanic()
   208  
   209  				h(ctx, rpcIn)
   210  				<-wCtx.cfg.ratelimitChan
   211  				wCtx.cfg.handlersWG.Done()
   212  				rpcIn.Release()
   213  			}(ctx, rpcIn)
   214  		}
   215  	}
   216  }
   217  
   218  func (wCtx *WebsocketCtx) recoverPanic() {
   219  	if r := recover(); r != nil {
   220  		wCtx.l.Errorf("panic recovered: %v", r)
   221  
   222  		if wCtx.cfg.panicRecoverFunc != nil {
   223  			wCtx.cfg.panicRecoverFunc(r)
   224  		}
   225  	}
   226  }
   227  
   228  func (wCtx *WebsocketCtx) TextMessage(
   229  	ctx context.Context, predicate string, req, res kit.Message,
   230  	cb RPCMessageHandler,
   231  ) error {
   232  	return wCtx.Do(
   233  		ctx,
   234  		WebsocketRequest{
   235  			Predicate:   predicate,
   236  			MessageType: websocket.TextMessage,
   237  			ReqMsg:      req,
   238  			ResMsg:      res,
   239  			ReqHdr:      nil,
   240  			Callback:    cb,
   241  		},
   242  	)
   243  }
   244  
   245  func (wCtx *WebsocketCtx) BinaryMessage(
   246  	ctx context.Context, predicate string, req, res kit.Message,
   247  	cb RPCMessageHandler,
   248  ) error {
   249  	return wCtx.Do(
   250  		ctx,
   251  		WebsocketRequest{
   252  			Predicate:   predicate,
   253  			MessageType: websocket.BinaryMessage,
   254  			ReqMsg:      req,
   255  			ResMsg:      res,
   256  			ReqHdr:      nil,
   257  			Callback:    cb,
   258  		},
   259  	)
   260  }
   261  
   262  // NetConn returns the underlying net.Conn, ONLY for advanced use cases
   263  func (wCtx *WebsocketCtx) NetConn() net.Conn {
   264  	return wCtx.c.NetConn()
   265  }
   266  
   267  type WebsocketStats struct {
   268  	// ReadBytes is the total number of bytes read from the current websocket connection
   269  	ReadBytes uint64
   270  	// ReadBytesTotal is the total number of bytes read since WebsocketCtx creation
   271  	ReadBytesTotal uint64
   272  	// WriteBytes is the total number of bytes written to the current websocket connection
   273  	WriteBytes uint64
   274  	// WriteBytesTotal is the total number of bytes written since WebsocketCtx creation
   275  	WriteBytesTotal uint64
   276  }
   277  
   278  func (wCtx *WebsocketCtx) Stats() WebsocketStats {
   279  	wCtx.cMtx.Lock()
   280  	defer wCtx.cMtx.Unlock()
   281  
   282  	return WebsocketStats{
   283  		ReadBytes:       wCtx.readBytes,
   284  		ReadBytesTotal:  wCtx.readBytesTotal,
   285  		WriteBytes:      wCtx.writeBytes,
   286  		WriteBytesTotal: wCtx.writeBytesTotal,
   287  	}
   288  }
   289  
   290  type WebsocketRequest struct {
   291  	// ID is optional, if you don't set it, a random string will be generated
   292  	ID string
   293  	// Predicate is the routing key for the message, which will be added to the kit.OutgoingRPCContainer
   294  	Predicate string
   295  	// MessageType is the type of the message, either websocket.TextMessage or websocket.BinaryMessage
   296  	MessageType int
   297  	ReqMsg      kit.Message
   298  	// ResMsg is the message that will be used to unmarshal the response.
   299  	// You should pass a pointer to the struct that you want to unmarshal the response into.
   300  	// If Callback is nil, then this field will be ignored.
   301  	ResMsg kit.Message
   302  	// ReqHdr is the headers that will be added to the kit.OutgoingRPCContainer
   303  	ReqHdr Header
   304  	// Callback is the callback that will be called when the response is received.
   305  	// If this is nil, the response will be ignored. However, the response will be caught by
   306  	// the default handler if it is set.
   307  	Callback RPCMessageHandler
   308  }
   309  
   310  // Do send a message to the websocket server and waits for the response. If the callback
   311  // is not nil, then make sure you provide a context with deadline or timeout, otherwise
   312  // you will leak goroutines.
   313  func (wCtx *WebsocketCtx) Do(ctx context.Context, req WebsocketRequest) error {
   314  	// run preflights
   315  	for _, pre := range wCtx.cfg.preflights {
   316  		pre(&req)
   317  	}
   318  
   319  	outC := wCtx.cfg.rpcOutFactory()
   320  	if req.ID == "" {
   321  		req.ID = utils.RandomDigit(10)
   322  	}
   323  	outC.InjectMessage(req.ReqMsg)
   324  	outC.SetHdr(wCtx.cfg.predicateKey, req.Predicate)
   325  	if tp := wCtx.cfg.tracePropagator; tp != nil {
   326  		tp.Inject(ctx, containerTraceCarrier{out: outC})
   327  	}
   328  	for k, v := range req.ReqHdr {
   329  		outC.SetHdr(k, v)
   330  	}
   331  	outC.SetID(req.ID)
   332  
   333  	reqData, err := outC.Marshal()
   334  	if err != nil {
   335  		return err
   336  	}
   337  
   338  	wCtx.cMtx.Lock()
   339  	wCtx.writeBytesTotal += uint64(len(reqData))
   340  	wCtx.writeBytes += uint64(len(reqData))
   341  	err = wCtx.c.WriteMessage(req.MessageType, reqData)
   342  	wCtx.cMtx.Unlock()
   343  	if err != nil {
   344  		return err
   345  	}
   346  
   347  	outC.Release()
   348  
   349  	if req.Callback != nil {
   350  		go wCtx.waitForMessage(ctx, req.ID, req.ResMsg, req.Callback)
   351  	}
   352  
   353  	return nil
   354  }
   355  
   356  func (wCtx *WebsocketCtx) waitForMessage(
   357  	ctx context.Context, id string, res kit.Message, cb RPCMessageHandler,
   358  ) {
   359  	resCh := make(chan kit.IncomingRPCContainer, 1)
   360  	wCtx.pendingMtx.Lock()
   361  	wCtx.pending[id] = resCh
   362  	wCtx.pendingMtx.Unlock()
   363  
   364  	select {
   365  	case c := <-resCh:
   366  		err := c.ExtractMessage(res)
   367  		cb(ctx, res, c.GetHdrMap(), err)
   368  
   369  	case <-ctx.Done():
   370  	}
   371  
   372  	wCtx.pendingMtx.Lock()
   373  	delete(wCtx.pending, id)
   374  	wCtx.pendingMtx.Unlock()
   375  }
   376  
   377  type containerTraceCarrier struct {
   378  	out kit.OutgoingRPCContainer
   379  	in  kit.IncomingRPCContainer
   380  }
   381  
   382  func (c containerTraceCarrier) Get(key string) string {
   383  	return c.in.GetHdr(key)
   384  }
   385  
   386  func (c containerTraceCarrier) Set(key string, value string) {
   387  	c.out.SetHdr(key, value)
   388  }
   389  
   390  var (
   391  	ErrBadHandshake = websocket.ErrBadHandshake
   392  	_               = ErrBadHandshake
   393  )