github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/websocket/connection.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "net" 8 "strconv" 9 "sync" 10 "time" 11 12 "github.com/gin-gonic/gin" 13 "github.com/gorilla/websocket" 14 ) 15 16 type connectionValue struct { 17 key []byte 18 value any 19 } 20 21 type ConnectionValues []connectionValue 22 23 func (r *ConnectionValues) Set(key string, value any) { 24 args := *r 25 n := len(args) 26 for i := 0; i < n; i++ { 27 kv := &args[i] 28 if string(kv.key) == key { 29 kv.value = value 30 return 31 } 32 } 33 34 c := cap(args) 35 if c > n { 36 args = args[:n+1] 37 kv := &args[n] 38 kv.key = append(kv.key[:0], key...) 39 kv.value = value 40 *r = args 41 return 42 } 43 44 kv := connectionValue{} 45 kv.key = append(kv.key[:0], key...) 46 kv.value = value 47 *r = append(args, kv) 48 } 49 50 func (r *ConnectionValues) Get(key string) any { 51 args := *r 52 n := len(args) 53 for i := 0; i < n; i++ { 54 kv := &args[i] 55 if string(kv.key) == key { 56 return kv.value 57 } 58 } 59 return nil 60 } 61 62 func (r *ConnectionValues) Reset() { 63 *r = (*r)[:0] 64 } 65 66 type UnderlineConnection interface { 67 SetWriteDeadline(t time.Time) error 68 SetReadDeadline(t time.Time) error 69 SetReadLimit(limit int64) 70 SetPongHandler(h func(appData string) error) 71 SetPingHandler(h func(appData string) error) 72 WriteControl(messageType int, data []byte, deadline time.Time) error 73 WriteMessage(messageType int, data []byte) error 74 ReadMessage() (messageType int, p []byte, err error) 75 NextWriter(messageType int) (io.WriteCloser, error) 76 Close() error 77 } 78 79 type DisconnectFunc func() 80 type LeaveRoomFunc func(roomName string) 81 type ErrorFunc func(error) 82 type NativeMessageFunc func([]byte) 83 type MessageFunc any 84 type PingFunc func() 85 type PongFunc func() 86 87 // Connection 接口 88 type Connection interface { 89 Emitter 90 Err() error 91 ID() string 92 Server() *Server 93 Write(websocketMessageType int, data []byte) error 94 Context() *gin.Context 95 OnDisconnect(DisconnectFunc) 96 OnError(ErrorFunc) 97 OnPing(PingFunc) 98 OnPong(PongFunc) 99 FireOnError(err error) 100 To(string) Emitter 101 OnMessage(NativeMessageFunc) 102 On(string, MessageFunc) 103 Join(string) 104 IsJoined(roomName string) bool 105 Leave(string) bool 106 OnLeave(roomLeaveCb LeaveRoomFunc) 107 Wait() 108 Disconnect() error 109 SetValue(key string, value any) 110 GetValue(key string) any 111 GetValueArrString(key string) []string 112 GetValueString(key string) string 113 GetValueInt(key string) int 114 } 115 116 // Connection 实现 117 type connection struct { 118 err error 119 underline UnderlineConnection 120 id string 121 messageType int 122 disconnected bool 123 onDisconnectListeners []DisconnectFunc 124 onRoomLeaveListeners []LeaveRoomFunc 125 onErrorListeners []ErrorFunc 126 onPingListeners []PingFunc 127 onPongListeners []PongFunc 128 onNativeMessageListeners []NativeMessageFunc 129 onEventListeners map[string][]MessageFunc 130 started bool 131 self Emitter 132 broadcast Emitter 133 all Emitter 134 ctx *gin.Context 135 values ConnectionValues 136 server *Server 137 writerMu sync.Mutex 138 } 139 140 var _ Connection = &connection{} 141 142 const CloseMessage = websocket.CloseMessage 143 144 func newConnection(ctx *gin.Context, s *Server, underlineConn UnderlineConnection, id string) *connection { 145 c := &connection{ 146 underline: underlineConn, 147 id: id, 148 messageType: websocket.TextMessage, 149 onDisconnectListeners: make([]DisconnectFunc, 0), 150 onRoomLeaveListeners: make([]LeaveRoomFunc, 0), 151 onErrorListeners: make([]ErrorFunc, 0), 152 onNativeMessageListeners: make([]NativeMessageFunc, 0), 153 onEventListeners: make(map[string][]MessageFunc, 0), 154 onPongListeners: make([]PongFunc, 0), 155 started: false, 156 ctx: ctx, 157 server: s, 158 } 159 160 if s.config.BinaryMessages { 161 c.messageType = websocket.BinaryMessage 162 } 163 164 c.self = newEmitter(c, c.id) 165 c.broadcast = newEmitter(c, Broadcast) 166 c.all = newEmitter(c, All) 167 168 return c 169 } 170 171 func (c *connection) Err() error { 172 return c.err 173 } 174 175 func (c *connection) Write(websocketMessageType int, data []byte) error { 176 c.writerMu.Lock() 177 if writeTimeout := c.server.config.WriteTimeout; writeTimeout > 0 { 178 _ = c.underline.SetWriteDeadline(time.Now().Add(writeTimeout)) 179 } 180 181 err := c.underline.WriteMessage(websocketMessageType, data) 182 c.writerMu.Unlock() 183 if err != nil { 184 _ = c.Disconnect() 185 } 186 return err 187 } 188 189 func (c *connection) writeDefault(data []byte) { 190 _ = c.Write(c.messageType, data) 191 } 192 193 const WriteWait = 1 * time.Second 194 195 func (c *connection) startPinger() { 196 pingHandler := func(message string) error { 197 err := c.underline.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(WriteWait)) 198 if err == websocket.ErrCloseSent { 199 return nil 200 } else if _, ok := err.(net.Error); ok { 201 return nil 202 } 203 return err 204 } 205 206 c.underline.SetPingHandler(pingHandler) 207 208 go func() { 209 for { 210 time.Sleep(c.server.config.PingPeriod) 211 if c.disconnected { 212 break 213 } 214 c.fireOnPing() 215 err := c.Write(websocket.PingMessage, []byte{}) 216 if err != nil { 217 break 218 } 219 } 220 }() 221 } 222 223 func (c *connection) fireOnPing() { 224 for i := range c.onPingListeners { 225 c.onPingListeners[i]() 226 } 227 } 228 229 func (c *connection) fireOnPong() { 230 for i := range c.onPongListeners { 231 c.onPongListeners[i]() 232 } 233 } 234 235 func (c *connection) startReader() { 236 conn := c.underline 237 hasReadTimeout := c.server.config.ReadTimeout > 0 238 239 conn.SetReadLimit(c.server.config.MaxMessageSize) 240 conn.SetPongHandler(func(s string) error { 241 if hasReadTimeout { 242 _ = conn.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout)) 243 } 244 go c.fireOnPong() 245 return nil 246 }) 247 248 defer func() { _ = c.Disconnect() }() 249 250 for { 251 if hasReadTimeout { 252 _ = conn.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout)) 253 } 254 _, data, err := conn.ReadMessage() 255 if err != nil { 256 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { 257 c.FireOnError(err) 258 } 259 break 260 } else { 261 c.messageReceived(data) 262 } 263 } 264 } 265 266 func (c *connection) messageReceived(data []byte) { 267 268 if bytes.HasPrefix(data, c.server.config.EvtMessagePrefix) { 269 receivedEvt := c.server.messageSerializer.getWebsocketCustomEvent(data) 270 listeners, ok := c.onEventListeners[string(receivedEvt)] 271 if !ok || len(listeners) == 0 { 272 return 273 } 274 275 customMessage, err := c.server.messageSerializer.deserialize(receivedEvt, data) 276 if customMessage == nil || err != nil { 277 return 278 } 279 280 for i := range listeners { 281 if fn, ok := listeners[i].(func()); ok { 282 fn() 283 } else if fnString, ok := listeners[i].(func(string)); ok { 284 285 if msgString, is := customMessage.(string); is { 286 fnString(msgString) 287 } else if msgInt, is := customMessage.(int); is { 288 fnString(strconv.Itoa(msgInt)) 289 } 290 291 } else if fnInt, ok := listeners[i].(func(int)); ok { 292 fnInt(customMessage.(int)) 293 } else if fnBool, ok := listeners[i].(func(bool)); ok { 294 fnBool(customMessage.(bool)) 295 } else if fnBytes, ok := listeners[i].(func([]byte)); ok { 296 fnBytes(customMessage.([]byte)) 297 } else { 298 listeners[i].(func(any))(customMessage) 299 } 300 301 } 302 } else { 303 for i := range c.onNativeMessageListeners { 304 c.onNativeMessageListeners[i](data) 305 } 306 } 307 308 } 309 310 func (c *connection) ID() string { 311 return c.id 312 } 313 314 func (c *connection) Server() *Server { 315 return c.server 316 } 317 318 func (c *connection) Context() *gin.Context { 319 return c.ctx 320 } 321 322 func (c *connection) Values() ConnectionValues { 323 return c.values 324 } 325 326 func (c *connection) fireDisconnect() { 327 for i := range c.onDisconnectListeners { 328 c.onDisconnectListeners[i]() 329 } 330 } 331 332 func (c *connection) OnDisconnect(cb DisconnectFunc) { 333 c.onDisconnectListeners = append(c.onDisconnectListeners, cb) 334 } 335 336 func (c *connection) OnError(cb ErrorFunc) { 337 c.onErrorListeners = append(c.onErrorListeners, cb) 338 } 339 340 func (c *connection) OnPing(cb PingFunc) { 341 c.onPingListeners = append(c.onPingListeners, cb) 342 } 343 344 func (c *connection) OnPong(cb PongFunc) { 345 c.onPongListeners = append(c.onPongListeners, cb) 346 } 347 348 func (c *connection) FireOnError(err error) { 349 for _, cb := range c.onErrorListeners { 350 cb(err) 351 } 352 } 353 354 func (c *connection) To(to string) Emitter { 355 if to == Broadcast { 356 return c.broadcast 357 } else if to == All { 358 return c.all 359 } else if to == c.id { 360 return c.self 361 } 362 363 return newEmitter(c, to) 364 } 365 366 func (c *connection) EmitMessage(nativeMessage []byte) error { 367 return c.self.EmitMessage(nativeMessage) 368 } 369 370 func (c *connection) Emit(event string, message any) error { 371 return c.self.Emit(event, message) 372 } 373 374 func (c *connection) OnMessage(cb NativeMessageFunc) { 375 c.onNativeMessageListeners = append(c.onNativeMessageListeners, cb) 376 } 377 378 func (c *connection) On(event string, cb MessageFunc) { 379 if c.onEventListeners[event] == nil { 380 c.onEventListeners[event] = make([]MessageFunc, 0) 381 } 382 383 c.onEventListeners[event] = append(c.onEventListeners[event], cb) 384 } 385 386 func (c *connection) Join(roomName string) { 387 c.server.Join(roomName, c.id) 388 } 389 390 func (c *connection) IsJoined(roomName string) bool { 391 return c.server.IsJoined(roomName, c.id) 392 } 393 394 func (c *connection) Leave(roomName string) bool { 395 return c.server.Leave(roomName, c.id) 396 } 397 398 func (c *connection) OnLeave(roomLeaveCb LeaveRoomFunc) { 399 c.onRoomLeaveListeners = append(c.onRoomLeaveListeners, roomLeaveCb) 400 } 401 402 func (c *connection) fireOnLeave(roomName string) { 403 if c == nil { 404 return 405 } 406 for i := range c.onRoomLeaveListeners { 407 c.onRoomLeaveListeners[i](roomName) 408 } 409 } 410 411 func (c *connection) Wait() { 412 if c.started { 413 return 414 } 415 c.started = true 416 c.startPinger() 417 c.startReader() 418 } 419 420 var ErrAlreadyDisconnected = errors.New("already disconnected") 421 422 func (c *connection) Disconnect() error { 423 if c == nil || c.disconnected { 424 return ErrAlreadyDisconnected 425 } 426 return c.server.Disconnect(c.ID()) 427 } 428 429 func (c *connection) SetValue(key string, value any) { 430 c.values.Set(key, value) 431 } 432 433 func (c *connection) GetValue(key string) any { 434 return c.values.Get(key) 435 } 436 437 func (c *connection) GetValueArrString(key string) []string { 438 if v := c.values.Get(key); v != nil { 439 if arrString, ok := v.([]string); ok { 440 return arrString 441 } 442 } 443 return nil 444 } 445 446 func (c *connection) GetValueString(key string) string { 447 if v := c.values.Get(key); v != nil { 448 if s, ok := v.(string); ok { 449 return s 450 } 451 } 452 return "" 453 } 454 455 func (c *connection) GetValueInt(key string) int { 456 if v := c.values.Get(key); v != nil { 457 if i, ok := v.(int); ok { 458 return i 459 } else if s, ok := v.(string); ok { 460 if iv, err := strconv.Atoi(s); err == nil { 461 return iv 462 } 463 } 464 } 465 return 0 466 }