github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/websocket/server.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bytes"
     5  	"log"
     6  	"sync"
     7  
     8  	"github.com/gin-gonic/gin"
     9  	w0 "github.com/gorilla/websocket"
    10  )
    11  
    12  var ClientSource []byte
    13  
    14  type ConnectionFunc func(Connection)
    15  
    16  type websocketRoomPayload struct {
    17  	roomName     string
    18  	connectionID string
    19  }
    20  
    21  type websocketMessagePayload struct {
    22  	from string
    23  	to   string
    24  	data []byte
    25  }
    26  
    27  type Server struct {
    28  	config                Config
    29  	ClientSource          []byte
    30  	messageSerializer     *messageSerializer
    31  	connections           sync.Map
    32  	rooms                 map[string][]string
    33  	mu                    sync.RWMutex
    34  	onConnectionListeners []ConnectionFunc
    35  	upgrader              w0.Upgrader
    36  }
    37  
    38  func NewWSServer(cfg Config) *Server {
    39  	cfg = cfg.Validate()
    40  	return &Server{
    41  		config:                cfg,
    42  		ClientSource:          bytes.Replace(ClientSource, []byte(DefaultEvtMessageKey), cfg.EvtMessagePrefix, -1),
    43  		messageSerializer:     newMessageSerializer(cfg.EvtMessagePrefix),
    44  		connections:           sync.Map{}, // ready-to-use, this is not necessary.
    45  		rooms:                 make(map[string][]string),
    46  		onConnectionListeners: make([]ConnectionFunc, 0),
    47  		upgrader: w0.Upgrader{
    48  			HandshakeTimeout:  cfg.HandshakeTimeout,
    49  			ReadBufferSize:    cfg.ReadBufferSize,
    50  			WriteBufferSize:   cfg.WriteBufferSize,
    51  			Error:             cfg.Error,
    52  			CheckOrigin:       cfg.CheckOrigin,
    53  			Subprotocols:      cfg.Subprotocols,
    54  			EnableCompression: cfg.EnableCompression,
    55  		},
    56  	}
    57  }
    58  
    59  func (s *Server) Handler() func(ctx *gin.Context) {
    60  	return func(ctx *gin.Context) {
    61  		c := s.Upgrade(ctx)
    62  		if c.Err() != nil {
    63  			return
    64  		}
    65  		for i := range s.onConnectionListeners {
    66  			s.onConnectionListeners[i](c)
    67  		}
    68  		c.Wait()
    69  	}
    70  }
    71  
    72  func (s *Server) Upgrade(ctx *gin.Context) Connection {
    73  	conn, err := s.upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
    74  	if err != nil {
    75  		log.Printf("websocket error: %v\n", err)
    76  		ctx.AbortWithStatus(503)
    77  		return &connection{err: err}
    78  	}
    79  
    80  	return s.handleConnection(ctx, conn)
    81  }
    82  
    83  func (s *Server) addConnection(c *connection) {
    84  	s.connections.Store(c.id, c)
    85  }
    86  
    87  func (s *Server) getConnection(connID string) (*connection, bool) {
    88  	if cValue, ok := s.connections.Load(connID); ok {
    89  		if conn, ok := cValue.(*connection); ok {
    90  			return conn, ok
    91  		}
    92  	}
    93  
    94  	return nil, false
    95  }
    96  
    97  func (s *Server) handleConnection(ctx *gin.Context, websocketConn UnderlineConnection) *connection {
    98  	cid := s.config.IDGenerator(ctx)
    99  	c := newConnection(ctx, s, websocketConn, cid)
   100  	s.addConnection(c)
   101  	s.Join(c.id, c.id)
   102  	return c
   103  }
   104  
   105  func (s *Server) OnConnection(cb ConnectionFunc) {
   106  	s.onConnectionListeners = append(s.onConnectionListeners, cb)
   107  }
   108  
   109  func (s *Server) IsConnected(connID string) bool {
   110  	_, found := s.getConnection(connID)
   111  	return found
   112  }
   113  
   114  func (s *Server) Join(roomName string, connID string) {
   115  	s.mu.Lock()
   116  	s.join(roomName, connID)
   117  	s.mu.Unlock()
   118  }
   119  
   120  func (s *Server) join(roomName string, connID string) {
   121  	if s.rooms[roomName] == nil {
   122  		s.rooms[roomName] = make([]string, 0)
   123  	}
   124  	s.rooms[roomName] = append(s.rooms[roomName], connID)
   125  }
   126  
   127  func (s *Server) IsJoined(roomName string, connID string) bool {
   128  	s.mu.RLock()
   129  	room := s.rooms[roomName]
   130  	s.mu.RUnlock()
   131  
   132  	if room == nil {
   133  		return false
   134  	}
   135  
   136  	for _, connid := range room {
   137  		if connID == connid {
   138  			return true
   139  		}
   140  	}
   141  
   142  	return false
   143  }
   144  
   145  func (s *Server) LeaveAll(connID string) {
   146  	s.mu.Lock()
   147  	for name := range s.rooms {
   148  		s.leave(name, connID)
   149  	}
   150  	s.mu.Unlock()
   151  }
   152  
   153  func (s *Server) Leave(roomName string, connID string) bool {
   154  	s.mu.Lock()
   155  	left := s.leave(roomName, connID)
   156  	s.mu.Unlock()
   157  	return left
   158  }
   159  
   160  func (s *Server) leave(roomName string, connID string) (left bool) {
   161  	if s.rooms[roomName] != nil {
   162  		for i := range s.rooms[roomName] {
   163  			if s.rooms[roomName][i] == connID {
   164  				s.rooms[roomName] = append(s.rooms[roomName][:i], s.rooms[roomName][i+1:]...)
   165  				left = true
   166  				break
   167  			}
   168  		}
   169  		if len(s.rooms[roomName]) == 0 {
   170  			delete(s.rooms, roomName)
   171  		}
   172  	}
   173  
   174  	if left {
   175  		if c, ok := s.getConnection(connID); ok {
   176  			c.fireOnLeave(roomName)
   177  		}
   178  	}
   179  	return
   180  }
   181  
   182  func (s *Server) GetTotalConnections() (n int) {
   183  	s.connections.Range(func(k, v any) bool {
   184  		n++
   185  		return true
   186  	})
   187  
   188  	return n
   189  }
   190  
   191  func (s *Server) GetConnections() []Connection {
   192  	length := s.GetTotalConnections()
   193  	conns := make([]Connection, length, length)
   194  	i := 0
   195  	s.connections.Range(func(k, v any) bool {
   196  		conn, ok := v.(*connection)
   197  		if !ok {
   198  			return false
   199  		}
   200  		conns[i] = conn
   201  		i++
   202  		return true
   203  	})
   204  
   205  	return conns
   206  }
   207  
   208  func (s *Server) GetConnection(connID string) Connection {
   209  	conn, ok := s.getConnection(connID)
   210  	if !ok {
   211  		return nil
   212  	}
   213  
   214  	return conn
   215  }
   216  
   217  func (s *Server) GetConnectionsByRoom(roomName string) []Connection {
   218  	var conns []Connection
   219  	s.mu.RLock()
   220  	if connIDs, found := s.rooms[roomName]; found {
   221  		for _, connID := range connIDs {
   222  			if cValue, ok := s.connections.Load(connID); ok {
   223  				if conn, ok := cValue.(*connection); ok {
   224  					conns = append(conns, conn)
   225  				}
   226  			}
   227  		}
   228  	}
   229  
   230  	s.mu.RUnlock()
   231  
   232  	return conns
   233  }
   234  
   235  func (s *Server) emitMessage(from, to string, data []byte) {
   236  	if to != All && to != Broadcast {
   237  		s.mu.RLock()
   238  		room := s.rooms[to]
   239  		s.mu.RUnlock()
   240  		if room != nil {
   241  			for _, connectionIDInsideRoom := range room {
   242  				if c, ok := s.getConnection(connectionIDInsideRoom); ok {
   243  					c.writeDefault(data)
   244  				} else {
   245  					cid := connectionIDInsideRoom
   246  					if c != nil {
   247  						cid = c.id
   248  					}
   249  					s.Leave(cid, to)
   250  				}
   251  			}
   252  		}
   253  	} else {
   254  		s.connections.Range(func(k, v any) bool {
   255  			connID, ok := k.(string)
   256  			if !ok {
   257  				return true
   258  			}
   259  
   260  			if to != All && to != connID {
   261  				if to == Broadcast && from == connID {
   262  					return true
   263  				}
   264  
   265  			}
   266  
   267  			conn, ok := v.(*connection)
   268  			if ok {
   269  				conn.writeDefault(data)
   270  			}
   271  
   272  			return ok
   273  		})
   274  	}
   275  }
   276  
   277  func (s *Server) Disconnect(connID string) (err error) {
   278  	s.LeaveAll(connID)
   279  	if conn, ok := s.getConnection(connID); ok {
   280  		conn.disconnected = true
   281  		conn.fireDisconnect()
   282  		err = conn.underline.Close()
   283  		s.connections.Delete(connID)
   284  	}
   285  	return
   286  }