github.com/oarkflow/sio@v0.0.6/server.go (about)

     1  package sio
     2  
import (
	"errors"
	"io"
	"log/slog"
	"net"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/oarkflow/frame"
	"github.com/oarkflow/frame/pkg/websocket"
)
    16  
    17  const ( //                        ASCII chars
    18  	startOfHeaderByte uint8 = 1 // SOH
    19  	startOfDataByte         = 2 // STX
    20  
    21  	// SubProtocol is the official sacrificial-socket sub protocol
    22  	SubProtocol string = "sac-sock"
    23  )
    24  
    25  type event struct {
    26  	eventName    string
    27  	eventHandler func(*Socket, []byte)
    28  }
    29  
    30  // Config specifies parameters for upgrading an HTTP connection to a
    31  // WebSocket connection.
    32  //
    33  // It is safe to call Config's methods concurrently.
    34  type Config struct {
    35  	HandshakeTimeout                time.Duration
    36  	ReadBufferSize, WriteBufferSize int
    37  	WriteBufferPool                 websocket.BufferPool
    38  	Subprotocols                    []string
    39  	Error                           func(ctx *frame.Context, status int, reason error)
    40  	CheckOrigin                     func(r *frame.Context) bool
    41  	EnableCompression               bool
    42  	Mutable                         bool
    43  }
    44  
    45  // Server manages the coordination between
    46  // sockets, rooms, events and the socket hub
    47  // add my own custom field
    48  type Server struct {
    49  	hub              *hub
    50  	events           map[string]*event
    51  	onConnectFunc    func(*Socket) error
    52  	onDisconnectFunc func(*Socket) error
    53  	onError          func(*Socket, error)
    54  	l                *sync.RWMutex
    55  	upgrader         websocket.Upgrader
    56  }
    57  
    58  // New creates a new instance of Server
    59  func New(cfg ...Config) *Server {
    60  	var config Config
    61  	upgrader := DefaultUpgrader()
    62  	if len(cfg) > 0 {
    63  		config = cfg[0]
    64  	}
    65  	if config.CheckOrigin != nil {
    66  		upgrader.CheckOrigin = config.CheckOrigin
    67  	}
    68  	if config.HandshakeTimeout != 0 {
    69  		upgrader.HandshakeTimeout = config.HandshakeTimeout
    70  	}
    71  	if config.ReadBufferSize != 0 {
    72  		upgrader.ReadBufferSize = config.ReadBufferSize
    73  	}
    74  	if config.WriteBufferSize != 0 {
    75  		upgrader.WriteBufferSize = config.WriteBufferSize
    76  	}
    77  	if len(config.Subprotocols) > 0 {
    78  		upgrader.Subprotocols = config.Subprotocols
    79  	} else {
    80  		upgrader.Subprotocols = []string{SubProtocol}
    81  	}
    82  	if config.Error != nil {
    83  		upgrader.Error = config.Error
    84  	}
    85  	upgrader.EnableCompression = config.EnableCompression
    86  	s := &Server{
    87  		hub:      newHub(),
    88  		events:   make(map[string]*event),
    89  		l:        &sync.RWMutex{},
    90  		upgrader: upgrader,
    91  	}
    92  
    93  	return s
    94  }
    95  
    96  func (serv *Server) ShutdownWithSignal() {
    97  	c := make(chan bool)
    98  	serv.EnableSignalShutdown(c)
    99  	go func() {
   100  		<-c
   101  		os.Exit(0)
   102  	}()
   103  }
   104  
   105  // EnableSignalShutdown listens for linux syscalls SIGHUP, SIGINT, SIGTERM, SIGQUIT, SIGKILL and
   106  // calls the Server.Shutdown() to perform a clean shutdown. true will be passed into complete
   107  // after the Shutdown proccess is finished
   108  func (serv *Server) EnableSignalShutdown(complete chan<- bool) {
   109  	c := make(chan os.Signal, 1)
   110  	signal.Notify(c,
   111  		syscall.SIGHUP,
   112  		syscall.SIGINT,
   113  		syscall.SIGTERM,
   114  		syscall.SIGQUIT,
   115  		syscall.SIGKILL)
   116  
   117  	go func() {
   118  		<-c
   119  		complete <- serv.Shutdown()
   120  	}()
   121  }
   122  
   123  func (serv *Server) Lock() {
   124  	serv.l.Lock()
   125  }
   126  
   127  func (serv *Server) Unlock() {
   128  	serv.l.Unlock()
   129  }
   130  
   131  func (serv *Server) RoomSocketList(id string) map[string]*Socket {
   132  	sockets := make(map[string]*Socket)
   133  	if room, exists := serv.hub.rooms[id]; exists {
   134  		room.l.Lock()
   135  		for id, socket := range room.sockets {
   136  			sockets[id] = socket
   137  		}
   138  		room.l.Unlock()
   139  	}
   140  	return sockets
   141  }
   142  
   143  func (serv *Server) SocketList() map[string]*Socket {
   144  	sockets := make(map[string]*Socket)
   145  	serv.l.Lock()
   146  	for id, socket := range serv.hub.sockets {
   147  		sockets[id] = socket
   148  	}
   149  	serv.l.Unlock()
   150  	return sockets
   151  }
   152  
   153  // Shutdown closes all active sockets and triggers the Shutdown()
   154  // method on any Adapter that is currently set.
   155  func (serv *Server) Shutdown() bool {
   156  	slog.Info("shutting down")
   157  	// complete := serv.hub.shutdown()
   158  
   159  	serv.hub.shutdownCh <- true
   160  	socketList := <-serv.hub.socketList
   161  
   162  	for _, s := range socketList {
   163  		s.Close()
   164  	}
   165  
   166  	if serv.hub.multihomeEnabled {
   167  		slog.Info("shutting down multihome backend")
   168  		serv.hub.multihomeBackend.Shutdown()
   169  		slog.Info("backend shutdown")
   170  	}
   171  
   172  	slog.Info("shutdown")
   173  	return true
   174  }
   175  
   176  // EventHandler is an interface for registering events using SockerServer.OnEvent
   177  type EventHandler interface {
   178  	HandleEvent(*Socket, []byte)
   179  	EventName() string
   180  }
   181  
   182  // On registers event functions to be called on individual Socket connections
   183  // when the server's socket receives an Emit from the client's socket.
   184  //
   185  // Any event functions registered with On, must be safe for concurrent use by multiple
   186  // go routines
   187  func (serv *Server) On(eventName string, handleFunc func(*Socket, []byte)) {
   188  	serv.l.Lock()
   189  	defer serv.l.Unlock()
   190  	serv.events[eventName] = &event{eventName, handleFunc} // you think you can handle the func?
   191  }
   192  
   193  func (serv *Server) Off(eventName string) {
   194  	serv.l.Lock()
   195  	defer serv.l.Unlock()
   196  	delete(serv.events, eventName)
   197  }
   198  
   199  // OnEvent has the same functionality as On, but accepts
   200  // an EventHandler interface instead of a handler function.
   201  func (serv *Server) OnEvent(h EventHandler) {
   202  	serv.On(h.EventName(), h.HandleEvent)
   203  }
   204  
   205  // OnConnect registers an event function to be called whenever a new Socket connection
   206  // is created
   207  func (serv *Server) OnConnect(handleFunc func(*Socket) error) {
   208  	serv.onConnectFunc = handleFunc
   209  }
   210  
   211  // OnError registers an event function to be called whenever a new Socket connection
   212  // is created
   213  func (serv *Server) OnError(handleFunc func(*Socket, error)) {
   214  	serv.onError = handleFunc
   215  }
   216  
   217  // OnDisconnect registers an event function to be called as soon as a Socket connection
   218  // is closed
   219  func (serv *Server) OnDisconnect(handleFunc func(*Socket) error) {
   220  	serv.onDisconnectFunc = handleFunc
   221  }
   222  
   223  // Handle will upgrade a http request to a websocket using the sac-sock subprotocol
   224  func (serv *Server) Handle(ctx *frame.Context) {
   225  	err := serv.upgrader.Upgrade(ctx, func(ws *websocket.Conn) {
   226  		serv.loop(ctx, ws)
   227  	})
   228  	if err != nil {
   229  		slog.Error(err.Error())
   230  		return
   231  	}
   232  }
   233  
   234  // DefaultUpgrader returns a websocket upgrader suitable for creating sacrificial-socket websockets.
   235  func DefaultUpgrader() websocket.Upgrader {
   236  	return websocket.Upgrader{
   237  		ReadBufferSize:  1024,
   238  		WriteBufferSize: 1024,
   239  		CheckOrigin: func(ctx *frame.Context) bool {
   240  			return true
   241  		},
   242  		Subprotocols: []string{SubProtocol},
   243  	}
   244  }
   245  
   246  // SetUpgrader sets the websocket.Upgrader used by the Server.
   247  func (serv *Server) SetUpgrader(u websocket.Upgrader) {
   248  	serv.upgrader = u
   249  }
   250  
   251  // SetMultihomeBackend registers an Adapter interface and calls its Init() method
   252  func (serv *Server) SetMultihomeBackend(b Adapter) {
   253  	serv.hub.setMultihomeBackend(b)
   254  }
   255  
   256  // ToRoom dispatches an event to all Sockets in the specified room.
   257  func (serv *Server) ToRoom(roomName, eventName string, data any) {
   258  	serv.hub.toRoom(&RoomMsg{RoomName: roomName, EventName: eventName, Data: data})
   259  }
   260  
   261  // ToRoomExcept dispatches an event to all Sockets in the specified room.
   262  func (serv *Server) ToRoomExcept(roomName string, except []string, eventName string, data any) {
   263  	serv.hub.toRoom(&RoomMsg{RoomName: roomName, EventName: eventName, Data: data, Except: except})
   264  }
   265  
   266  // Broadcast dispatches an event to all Sockets on the Server.
   267  func (serv *Server) Broadcast(eventName string, data any) {
   268  	serv.hub.broadcast(&BroadcastMsg{EventName: eventName, Data: data})
   269  }
   270  
   271  // BroadcastExcept dispatches an event to all Sockets on the Server.
   272  func (serv *Server) BroadcastExcept(except []string, eventName string, data any) {
   273  	serv.hub.broadcast(&BroadcastMsg{EventName: eventName, Except: except, Data: data})
   274  }
   275  
   276  // ToSocket dispatches an event to the specified socket ID.
   277  func (serv *Server) ToSocket(socketID, eventName string, data any) {
   278  	serv.ToRoom("__socket_id:"+socketID, eventName, data)
   279  }
   280  
   281  // loop handles all the coordination between new sockets
   282  // reading frames and dispatching events
   283  func (serv *Server) loop(ctx *frame.Context, ws *websocket.Conn) {
   284  	s := newSocket(serv, ctx, ws)
   285  	slog.Info("connected", "id", s.ID())
   286  
   287  	defer s.Close()
   288  
   289  	s.Join("__socket_id:" + s.ID())
   290  	serv.l.RLock()
   291  	e := serv.onConnectFunc
   292  	serv.l.RUnlock()
   293  
   294  	if e != nil {
   295  		err := e(s)
   296  		if err != nil && serv.onError != nil {
   297  			serv.onError(s, err)
   298  		}
   299  	}
   300  	// ws.SetReadLimit(512)
   301  	// ws.SetReadDeadline(time.Now().Add(60 * time.Second))
   302  	// ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(60 * time.Second)); return nil })
   303  	for {
   304  		msg, err := s.receive()
   305  		if ignorableError(err) {
   306  			return
   307  		}
   308  		if err != nil {
   309  			slog.Error(err.Error())
   310  			return
   311  		}
   312  
   313  		eventName := ""
   314  		contentIdx := 0
   315  
   316  		for idx, chr := range msg {
   317  			if chr == startOfDataByte {
   318  				eventName = string(msg[:idx])
   319  				contentIdx = idx + 1
   320  				break
   321  			}
   322  		}
   323  		if eventName == "" {
   324  			slog.Warn("no event to dispatch")
   325  			continue
   326  		}
   327  
   328  		serv.l.RLock()
   329  		e, exists := serv.events[eventName]
   330  		serv.l.RUnlock()
   331  
   332  		if exists {
   333  			go e.eventHandler(s, msg[contentIdx:])
   334  		}
   335  	}
   336  }
   337  
   338  func ignorableError(err error) bool {
   339  	// not an error
   340  	if err == nil {
   341  		return false
   342  	}
   343  
   344  	return err == io.EOF ||
   345  		websocket.IsCloseError(err, 1000) ||
   346  		websocket.IsCloseError(err, 1001) ||
   347  		strings.HasSuffix(err.Error(), "use of closed network connection")
   348  }