github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/stub/stub_ws.go (about) 1 package stub 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "strings" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/clubpay/ronykit/kit" 13 "github.com/clubpay/ronykit/kit/utils" 14 "github.com/clubpay/ronykit/kit/utils/reflector" 15 "github.com/fasthttp/websocket" 16 ) 17 18 type ( 19 Header map[string]string 20 RPCContainerHandler func(ctx context.Context, c kit.IncomingRPCContainer) 21 RPCMessageHandler func(ctx context.Context, msg kit.Message, hdr Header, err error) 22 ) 23 24 type RPCPreflightHandler func(req *WebsocketRequest) 25 26 type WebsocketCtx struct { 27 cfg wsConfig 28 r *reflector.Reflector 29 l kit.Logger 30 31 pendingMtx sync.Mutex 32 pending map[string]chan kit.IncomingRPCContainer 33 lastActivity uint32 34 disconnect bool 35 36 // fasthttp entities 37 url string 38 cMtx sync.Mutex 39 c *websocket.Conn 40 41 // stats 42 writeBytesTotal uint64 43 writeBytes uint64 44 readBytesTotal uint64 45 readBytes uint64 46 } 47 48 func (wCtx *WebsocketCtx) Connect(ctx context.Context, path string) error { 49 path = strings.TrimLeft(path, "/") 50 if path != "" { 51 wCtx.url = fmt.Sprintf("%s/%s", wCtx.url, path) 52 } 53 54 return wCtx.connect(ctx) 55 } 56 57 func (wCtx *WebsocketCtx) connect(ctx context.Context) error { 58 wCtx.l.Debugf("connect: %s", wCtx.url) 59 60 d := wCtx.cfg.dialerBuilder() 61 if f := wCtx.cfg.preDial; f != nil { 62 f(d) 63 } 64 c, rsp, err := d.DialContext(ctx, wCtx.url, wCtx.cfg.upgradeHdr) 65 if err != nil { 66 return err 67 } 68 _ = rsp.Body.Close() 69 70 wCtx.setActivity() 71 c.SetPongHandler( 72 func(appData string) error { 73 wCtx.l.Debugf("websocket pong received") 74 wCtx.setActivity() 75 76 return nil 77 }, 78 ) 79 _ = c.SetCompressionLevel(wCtx.cfg.compressLevel) 80 81 wCtx.c = c 82 wCtx.writeBytes = 0 83 wCtx.readBytes = 0 84 85 // run receiver & watchdog in the background 86 go wCtx.receiver(c) //nolint:contextcheck 87 go wCtx.watchdog(c) //nolint:contextcheck 88 89 if f := wCtx.cfg.onConnect; f != nil { 90 f(wCtx) 91 } 92 93 return nil 94 } 95 96 func (wCtx *WebsocketCtx) Disconnect() error { 97 wCtx.disconnect = true 98 99 return wCtx.c.Close() 100 } 101 102 func (wCtx *WebsocketCtx) setActivity() { 103 atomic.StoreUint32(&wCtx.lastActivity, uint32(utils.TimeUnix())) 104 } 105 106 func (wCtx *WebsocketCtx) getActivity() int64 { 107 return int64(atomic.LoadUint32(&wCtx.lastActivity)) 108 } 109 110 func (wCtx *WebsocketCtx) watchdog(c *websocket.Conn) { 111 wCtx.l.Debugf("watchdog started: %s", c.LocalAddr().String()) 112 113 t := time.NewTicker(wCtx.cfg.pingTime) 114 d := int64(wCtx.cfg.pingTime/time.Second) * 2 115 for range t.C { 116 if wCtx.disconnect { 117 wCtx.l.Debugf("going to disconnect: %s", c.LocalAddr().String()) 118 119 _ = c.Close() 120 121 return 122 } 123 124 if utils.TimeUnix()-wCtx.getActivity() <= d { 125 wCtx.cMtx.Lock() 126 _ = c.WriteControl(websocket.PingMessage, nil, time.Now().Add(wCtx.cfg.writeTimeout)) 127 wCtx.cMtx.Unlock() 128 wCtx.l.Debugf("websocket ping sent") 129 130 continue 131 } 132 133 if !wCtx.cfg.autoReconnect { 134 return 135 } 136 137 wCtx.l.Errorf("inactivity detected, reconnecting: %s", c.LocalAddr().String()) 138 _ = c.Close() 139 140 ctx, cf := context.WithTimeout(context.Background(), wCtx.cfg.dialTimeout) 141 err := wCtx.connect(ctx) 142 cf() 143 if err != nil { 144 wCtx.l.Errorf("failed to reconnect: %s", err) 145 146 continue 147 } 148 149 return 150 } 151 } 152 153 func (wCtx *WebsocketCtx) receiver(c *websocket.Conn) { 154 for { 155 _, p, err := c.ReadMessage() 156 if err != nil || len(p) == 0 { 157 wCtx.l.Debugf("receiver shutdown: %s: %v", c.LocalAddr().String(), err) 158 159 return 160 } 161 162 wCtx.readBytesTotal += uint64(len(p)) 163 wCtx.readBytes += uint64(len(p)) 164 wCtx.setActivity() 165 166 rpcIn := wCtx.cfg.rpcInFactory() 167 err = rpcIn.Unmarshal(p) 168 if err != nil { 169 wCtx.l.Debugf("received unexpected message: %v", err) 170 171 continue 172 } 173 174 // if this is a reply message we return it to the pending channel 175 wCtx.pendingMtx.Lock() 176 ch, ok := wCtx.pending[rpcIn.GetID()] 177 wCtx.pendingMtx.Unlock() 178 179 if ok { 180 ch <- rpcIn 181 182 continue 183 } 184 185 ctx := context.Background() 186 if tp := wCtx.cfg.tracePropagator; tp != nil { 187 ctx = tp.Extract(ctx, containerTraceCarrier{in: rpcIn}) 188 } 189 190 h, ok := wCtx.cfg.handlers[rpcIn.GetHdr(wCtx.cfg.predicateKey)] 191 if !ok { 192 h = wCtx.cfg.defaultHandler 193 } 194 195 if h == nil { 196 rpcIn.Release() 197 198 continue 199 } 200 201 select { 202 default: 203 wCtx.l.Errorf("ratelimit reached, packet dropped") 204 case wCtx.cfg.ratelimitChan <- struct{}{}: 205 wCtx.cfg.handlersWG.Add(1) 206 go func(ctx context.Context, rpcIn kit.IncomingRPCContainer) { 207 defer wCtx.recoverPanic() 208 209 h(ctx, rpcIn) 210 <-wCtx.cfg.ratelimitChan 211 wCtx.cfg.handlersWG.Done() 212 rpcIn.Release() 213 }(ctx, rpcIn) 214 } 215 } 216 } 217 218 func (wCtx *WebsocketCtx) recoverPanic() { 219 if r := recover(); r != nil { 220 wCtx.l.Errorf("panic recovered: %v", r) 221 222 if wCtx.cfg.panicRecoverFunc != nil { 223 wCtx.cfg.panicRecoverFunc(r) 224 } 225 } 226 } 227 228 func (wCtx *WebsocketCtx) TextMessage( 229 ctx context.Context, predicate string, req, res kit.Message, 230 cb RPCMessageHandler, 231 ) error { 232 return wCtx.Do( 233 ctx, 234 WebsocketRequest{ 235 Predicate: predicate, 236 MessageType: websocket.TextMessage, 237 ReqMsg: req, 238 ResMsg: res, 239 ReqHdr: nil, 240 Callback: cb, 241 }, 242 ) 243 } 244 245 func (wCtx *WebsocketCtx) BinaryMessage( 246 ctx context.Context, predicate string, req, res kit.Message, 247 cb RPCMessageHandler, 248 ) error { 249 return wCtx.Do( 250 ctx, 251 WebsocketRequest{ 252 Predicate: predicate, 253 MessageType: websocket.BinaryMessage, 254 ReqMsg: req, 255 ResMsg: res, 256 ReqHdr: nil, 257 Callback: cb, 258 }, 259 ) 260 } 261 262 // NetConn returns the underlying net.Conn, ONLY for advanced use cases 263 func (wCtx *WebsocketCtx) NetConn() net.Conn { 264 return wCtx.c.NetConn() 265 } 266 267 type WebsocketStats struct { 268 // ReadBytes is the total number of bytes read from the current websocket connection 269 ReadBytes uint64 270 // ReadBytesTotal is the total number of bytes read since WebsocketCtx creation 271 ReadBytesTotal uint64 272 // WriteBytes is the total number of bytes written to the current websocket connection 273 WriteBytes uint64 274 // WriteBytesTotal is the total number of bytes written since WebsocketCtx creation 275 WriteBytesTotal uint64 276 } 277 278 func (wCtx *WebsocketCtx) Stats() WebsocketStats { 279 wCtx.cMtx.Lock() 280 defer wCtx.cMtx.Unlock() 281 282 return WebsocketStats{ 283 ReadBytes: wCtx.readBytes, 284 ReadBytesTotal: wCtx.readBytesTotal, 285 WriteBytes: wCtx.writeBytes, 286 WriteBytesTotal: wCtx.writeBytesTotal, 287 } 288 } 289 290 type WebsocketRequest struct { 291 // ID is optional, if you don't set it, a random string will be generated 292 ID string 293 // Predicate is the routing key for the message, which will be added to the kit.OutgoingRPCContainer 294 Predicate string 295 // MessageType is the type of the message, either websocket.TextMessage or websocket.BinaryMessage 296 MessageType int 297 ReqMsg kit.Message 298 // ResMsg is the message that will be used to unmarshal the response. 299 // You should pass a pointer to the struct that you want to unmarshal the response into. 300 // If Callback is nil, then this field will be ignored. 301 ResMsg kit.Message 302 // ReqHdr is the headers that will be added to the kit.OutgoingRPCContainer 303 ReqHdr Header 304 // Callback is the callback that will be called when the response is received. 305 // If this is nil, the response will be ignored. However, the response will be caught by 306 // the default handler if it is set. 307 Callback RPCMessageHandler 308 } 309 310 // Do send a message to the websocket server and waits for the response. If the callback 311 // is not nil, then make sure you provide a context with deadline or timeout, otherwise 312 // you will leak goroutines. 313 func (wCtx *WebsocketCtx) Do(ctx context.Context, req WebsocketRequest) error { 314 // run preflights 315 for _, pre := range wCtx.cfg.preflights { 316 pre(&req) 317 } 318 319 outC := wCtx.cfg.rpcOutFactory() 320 if req.ID == "" { 321 req.ID = utils.RandomDigit(10) 322 } 323 outC.InjectMessage(req.ReqMsg) 324 outC.SetHdr(wCtx.cfg.predicateKey, req.Predicate) 325 if tp := wCtx.cfg.tracePropagator; tp != nil { 326 tp.Inject(ctx, containerTraceCarrier{out: outC}) 327 } 328 for k, v := range req.ReqHdr { 329 outC.SetHdr(k, v) 330 } 331 outC.SetID(req.ID) 332 333 reqData, err := outC.Marshal() 334 if err != nil { 335 return err 336 } 337 338 wCtx.cMtx.Lock() 339 wCtx.writeBytesTotal += uint64(len(reqData)) 340 wCtx.writeBytes += uint64(len(reqData)) 341 err = wCtx.c.WriteMessage(req.MessageType, reqData) 342 wCtx.cMtx.Unlock() 343 if err != nil { 344 return err 345 } 346 347 outC.Release() 348 349 if req.Callback != nil { 350 go wCtx.waitForMessage(ctx, req.ID, req.ResMsg, req.Callback) 351 } 352 353 return nil 354 } 355 356 func (wCtx *WebsocketCtx) waitForMessage( 357 ctx context.Context, id string, res kit.Message, cb RPCMessageHandler, 358 ) { 359 resCh := make(chan kit.IncomingRPCContainer, 1) 360 wCtx.pendingMtx.Lock() 361 wCtx.pending[id] = resCh 362 wCtx.pendingMtx.Unlock() 363 364 select { 365 case c := <-resCh: 366 err := c.ExtractMessage(res) 367 cb(ctx, res, c.GetHdrMap(), err) 368 369 case <-ctx.Done(): 370 } 371 372 wCtx.pendingMtx.Lock() 373 delete(wCtx.pending, id) 374 wCtx.pendingMtx.Unlock() 375 } 376 377 type containerTraceCarrier struct { 378 out kit.OutgoingRPCContainer 379 in kit.IncomingRPCContainer 380 } 381 382 func (c containerTraceCarrier) Get(key string) string { 383 return c.in.GetHdr(key) 384 } 385 386 func (c containerTraceCarrier) Set(key string, value string) { 387 c.out.SetHdr(key, value) 388 } 389 390 var ( 391 ErrBadHandshake = websocket.ErrBadHandshake 392 _ = ErrBadHandshake 393 )