github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/websocket/server.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "log" 6 "sync" 7 8 "github.com/gin-gonic/gin" 9 w0 "github.com/gorilla/websocket" 10 ) 11 12 var ClientSource []byte 13 14 type ConnectionFunc func(Connection) 15 16 type websocketRoomPayload struct { 17 roomName string 18 connectionID string 19 } 20 21 type websocketMessagePayload struct { 22 from string 23 to string 24 data []byte 25 } 26 27 type Server struct { 28 config Config 29 ClientSource []byte 30 messageSerializer *messageSerializer 31 connections sync.Map 32 rooms map[string][]string 33 mu sync.RWMutex 34 onConnectionListeners []ConnectionFunc 35 upgrader w0.Upgrader 36 } 37 38 func NewWSServer(cfg Config) *Server { 39 cfg = cfg.Validate() 40 return &Server{ 41 config: cfg, 42 ClientSource: bytes.Replace(ClientSource, []byte(DefaultEvtMessageKey), cfg.EvtMessagePrefix, -1), 43 messageSerializer: newMessageSerializer(cfg.EvtMessagePrefix), 44 connections: sync.Map{}, // ready-to-use, this is not necessary. 45 rooms: make(map[string][]string), 46 onConnectionListeners: make([]ConnectionFunc, 0), 47 upgrader: w0.Upgrader{ 48 HandshakeTimeout: cfg.HandshakeTimeout, 49 ReadBufferSize: cfg.ReadBufferSize, 50 WriteBufferSize: cfg.WriteBufferSize, 51 Error: cfg.Error, 52 CheckOrigin: cfg.CheckOrigin, 53 Subprotocols: cfg.Subprotocols, 54 EnableCompression: cfg.EnableCompression, 55 }, 56 } 57 } 58 59 func (s *Server) Handler() func(ctx *gin.Context) { 60 return func(ctx *gin.Context) { 61 c := s.Upgrade(ctx) 62 if c.Err() != nil { 63 return 64 } 65 for i := range s.onConnectionListeners { 66 s.onConnectionListeners[i](c) 67 } 68 c.Wait() 69 } 70 } 71 72 func (s *Server) Upgrade(ctx *gin.Context) Connection { 73 conn, err := s.upgrader.Upgrade(ctx.Writer, ctx.Request, nil) 74 if err != nil { 75 log.Printf("websocket error: %v\n", err) 76 ctx.AbortWithStatus(503) 77 return &connection{err: err} 78 } 79 80 return s.handleConnection(ctx, conn) 81 } 82 83 func (s *Server) addConnection(c *connection) { 84 s.connections.Store(c.id, c) 85 } 86 87 func (s *Server) getConnection(connID string) (*connection, bool) { 88 if cValue, ok := s.connections.Load(connID); ok { 89 if conn, ok := cValue.(*connection); ok { 90 return conn, ok 91 } 92 } 93 94 return nil, false 95 } 96 97 func (s *Server) handleConnection(ctx *gin.Context, websocketConn UnderlineConnection) *connection { 98 cid := s.config.IDGenerator(ctx) 99 c := newConnection(ctx, s, websocketConn, cid) 100 s.addConnection(c) 101 s.Join(c.id, c.id) 102 return c 103 } 104 105 func (s *Server) OnConnection(cb ConnectionFunc) { 106 s.onConnectionListeners = append(s.onConnectionListeners, cb) 107 } 108 109 func (s *Server) IsConnected(connID string) bool { 110 _, found := s.getConnection(connID) 111 return found 112 } 113 114 func (s *Server) Join(roomName string, connID string) { 115 s.mu.Lock() 116 s.join(roomName, connID) 117 s.mu.Unlock() 118 } 119 120 func (s *Server) join(roomName string, connID string) { 121 if s.rooms[roomName] == nil { 122 s.rooms[roomName] = make([]string, 0) 123 } 124 s.rooms[roomName] = append(s.rooms[roomName], connID) 125 } 126 127 func (s *Server) IsJoined(roomName string, connID string) bool { 128 s.mu.RLock() 129 room := s.rooms[roomName] 130 s.mu.RUnlock() 131 132 if room == nil { 133 return false 134 } 135 136 for _, connid := range room { 137 if connID == connid { 138 return true 139 } 140 } 141 142 return false 143 } 144 145 func (s *Server) LeaveAll(connID string) { 146 s.mu.Lock() 147 for name := range s.rooms { 148 s.leave(name, connID) 149 } 150 s.mu.Unlock() 151 } 152 153 func (s *Server) Leave(roomName string, connID string) bool { 154 s.mu.Lock() 155 left := s.leave(roomName, connID) 156 s.mu.Unlock() 157 return left 158 } 159 160 func (s *Server) leave(roomName string, connID string) (left bool) { 161 if s.rooms[roomName] != nil { 162 for i := range s.rooms[roomName] { 163 if s.rooms[roomName][i] == connID { 164 s.rooms[roomName] = append(s.rooms[roomName][:i], s.rooms[roomName][i+1:]...) 165 left = true 166 break 167 } 168 } 169 if len(s.rooms[roomName]) == 0 { 170 delete(s.rooms, roomName) 171 } 172 } 173 174 if left { 175 if c, ok := s.getConnection(connID); ok { 176 c.fireOnLeave(roomName) 177 } 178 } 179 return 180 } 181 182 func (s *Server) GetTotalConnections() (n int) { 183 s.connections.Range(func(k, v any) bool { 184 n++ 185 return true 186 }) 187 188 return n 189 } 190 191 func (s *Server) GetConnections() []Connection { 192 length := s.GetTotalConnections() 193 conns := make([]Connection, length, length) 194 i := 0 195 s.connections.Range(func(k, v any) bool { 196 conn, ok := v.(*connection) 197 if !ok { 198 return false 199 } 200 conns[i] = conn 201 i++ 202 return true 203 }) 204 205 return conns 206 } 207 208 func (s *Server) GetConnection(connID string) Connection { 209 conn, ok := s.getConnection(connID) 210 if !ok { 211 return nil 212 } 213 214 return conn 215 } 216 217 func (s *Server) GetConnectionsByRoom(roomName string) []Connection { 218 var conns []Connection 219 s.mu.RLock() 220 if connIDs, found := s.rooms[roomName]; found { 221 for _, connID := range connIDs { 222 if cValue, ok := s.connections.Load(connID); ok { 223 if conn, ok := cValue.(*connection); ok { 224 conns = append(conns, conn) 225 } 226 } 227 } 228 } 229 230 s.mu.RUnlock() 231 232 return conns 233 } 234 235 func (s *Server) emitMessage(from, to string, data []byte) { 236 if to != All && to != Broadcast { 237 s.mu.RLock() 238 room := s.rooms[to] 239 s.mu.RUnlock() 240 if room != nil { 241 for _, connectionIDInsideRoom := range room { 242 if c, ok := s.getConnection(connectionIDInsideRoom); ok { 243 c.writeDefault(data) 244 } else { 245 cid := connectionIDInsideRoom 246 if c != nil { 247 cid = c.id 248 } 249 s.Leave(cid, to) 250 } 251 } 252 } 253 } else { 254 s.connections.Range(func(k, v any) bool { 255 connID, ok := k.(string) 256 if !ok { 257 return true 258 } 259 260 if to != All && to != connID { 261 if to == Broadcast && from == connID { 262 return true 263 } 264 265 } 266 267 conn, ok := v.(*connection) 268 if ok { 269 conn.writeDefault(data) 270 } 271 272 return ok 273 }) 274 } 275 } 276 277 func (s *Server) Disconnect(connID string) (err error) { 278 s.LeaveAll(connID) 279 if conn, ok := s.getConnection(connID); ok { 280 conn.disconnected = true 281 conn.fireDisconnect() 282 err = conn.underline.Close() 283 s.connections.Delete(connID) 284 } 285 return 286 }