github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/websocket/connection.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"strconv"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/gin-gonic/gin"
    13  	"github.com/gorilla/websocket"
    14  )
    15  
    16  type connectionValue struct {
    17  	key   []byte
    18  	value any
    19  }
    20  
    21  type ConnectionValues []connectionValue
    22  
    23  func (r *ConnectionValues) Set(key string, value any) {
    24  	args := *r
    25  	n := len(args)
    26  	for i := 0; i < n; i++ {
    27  		kv := &args[i]
    28  		if string(kv.key) == key {
    29  			kv.value = value
    30  			return
    31  		}
    32  	}
    33  
    34  	c := cap(args)
    35  	if c > n {
    36  		args = args[:n+1]
    37  		kv := &args[n]
    38  		kv.key = append(kv.key[:0], key...)
    39  		kv.value = value
    40  		*r = args
    41  		return
    42  	}
    43  
    44  	kv := connectionValue{}
    45  	kv.key = append(kv.key[:0], key...)
    46  	kv.value = value
    47  	*r = append(args, kv)
    48  }
    49  
    50  func (r *ConnectionValues) Get(key string) any {
    51  	args := *r
    52  	n := len(args)
    53  	for i := 0; i < n; i++ {
    54  		kv := &args[i]
    55  		if string(kv.key) == key {
    56  			return kv.value
    57  		}
    58  	}
    59  	return nil
    60  }
    61  
    62  func (r *ConnectionValues) Reset() {
    63  	*r = (*r)[:0]
    64  }
    65  
    66  type UnderlineConnection interface {
    67  	SetWriteDeadline(t time.Time) error
    68  	SetReadDeadline(t time.Time) error
    69  	SetReadLimit(limit int64)
    70  	SetPongHandler(h func(appData string) error)
    71  	SetPingHandler(h func(appData string) error)
    72  	WriteControl(messageType int, data []byte, deadline time.Time) error
    73  	WriteMessage(messageType int, data []byte) error
    74  	ReadMessage() (messageType int, p []byte, err error)
    75  	NextWriter(messageType int) (io.WriteCloser, error)
    76  	Close() error
    77  }
    78  
    79  type DisconnectFunc func()
    80  type LeaveRoomFunc func(roomName string)
    81  type ErrorFunc func(error)
    82  type NativeMessageFunc func([]byte)
    83  type MessageFunc any
    84  type PingFunc func()
    85  type PongFunc func()
    86  
    87  // Connection 接口
    88  type Connection interface {
    89  	Emitter
    90  	Err() error
    91  	ID() string
    92  	Server() *Server
    93  	Write(websocketMessageType int, data []byte) error
    94  	Context() *gin.Context
    95  	OnDisconnect(DisconnectFunc)
    96  	OnError(ErrorFunc)
    97  	OnPing(PingFunc)
    98  	OnPong(PongFunc)
    99  	FireOnError(err error)
   100  	To(string) Emitter
   101  	OnMessage(NativeMessageFunc)
   102  	On(string, MessageFunc)
   103  	Join(string)
   104  	IsJoined(roomName string) bool
   105  	Leave(string) bool
   106  	OnLeave(roomLeaveCb LeaveRoomFunc)
   107  	Wait()
   108  	Disconnect() error
   109  	SetValue(key string, value any)
   110  	GetValue(key string) any
   111  	GetValueArrString(key string) []string
   112  	GetValueString(key string) string
   113  	GetValueInt(key string) int
   114  }
   115  
// connection is the Connection implementation.
type connection struct {
	// err is returned by Err.
	err error
	// underline is the wrapped websocket connection all I/O goes through.
	underline UnderlineConnection
	// id identifies this connection to the server and its emitters.
	id string
	// messageType is the default type used by writeDefault: TextMessage, or
	// BinaryMessage when the server config sets BinaryMessages.
	messageType int
	// disconnected is checked by Disconnect and the pinger loop.
	// NOTE(review): it is not set in this file and is read without
	// synchronization — confirm where it is written.
	disconnected bool

	// Listener lists, appended to by the On*/OnLeave registration methods.
	onDisconnectListeners    []DisconnectFunc
	onRoomLeaveListeners     []LeaveRoomFunc
	onErrorListeners         []ErrorFunc
	onPingListeners          []PingFunc
	onPongListeners          []PongFunc
	onNativeMessageListeners []NativeMessageFunc
	// onEventListeners maps a custom event name to its registered callbacks.
	onEventListeners map[string][]MessageFunc

	// started makes Wait run the pinger/reader only once.
	started bool

	// Emitters pre-built by newConnection and returned by To.
	self      Emitter
	broadcast Emitter
	all       Emitter

	// ctx is the gin request context the connection was created from.
	ctx *gin.Context
	// values backs SetValue/GetValue*.
	values ConnectionValues
	// server is the Server this connection was created with.
	server *Server
	// writerMu serializes Write calls on underline.
	writerMu sync.Mutex
}
   139  
   140  var _ Connection = &connection{}
   141  
   142  const CloseMessage = websocket.CloseMessage
   143  
   144  func newConnection(ctx *gin.Context, s *Server, underlineConn UnderlineConnection, id string) *connection {
   145  	c := &connection{
   146  		underline:                underlineConn,
   147  		id:                       id,
   148  		messageType:              websocket.TextMessage,
   149  		onDisconnectListeners:    make([]DisconnectFunc, 0),
   150  		onRoomLeaveListeners:     make([]LeaveRoomFunc, 0),
   151  		onErrorListeners:         make([]ErrorFunc, 0),
   152  		onNativeMessageListeners: make([]NativeMessageFunc, 0),
   153  		onEventListeners:         make(map[string][]MessageFunc, 0),
   154  		onPongListeners:          make([]PongFunc, 0),
   155  		started:                  false,
   156  		ctx:                      ctx,
   157  		server:                   s,
   158  	}
   159  
   160  	if s.config.BinaryMessages {
   161  		c.messageType = websocket.BinaryMessage
   162  	}
   163  
   164  	c.self = newEmitter(c, c.id)
   165  	c.broadcast = newEmitter(c, Broadcast)
   166  	c.all = newEmitter(c, All)
   167  
   168  	return c
   169  }
   170  
   171  func (c *connection) Err() error {
   172  	return c.err
   173  }
   174  
   175  func (c *connection) Write(websocketMessageType int, data []byte) error {
   176  	c.writerMu.Lock()
   177  	if writeTimeout := c.server.config.WriteTimeout; writeTimeout > 0 {
   178  		_ = c.underline.SetWriteDeadline(time.Now().Add(writeTimeout))
   179  	}
   180  
   181  	err := c.underline.WriteMessage(websocketMessageType, data)
   182  	c.writerMu.Unlock()
   183  	if err != nil {
   184  		_ = c.Disconnect()
   185  	}
   186  	return err
   187  }
   188  
   189  func (c *connection) writeDefault(data []byte) {
   190  	_ = c.Write(c.messageType, data)
   191  }
   192  
   193  const WriteWait = 1 * time.Second
   194  
   195  func (c *connection) startPinger() {
   196  	pingHandler := func(message string) error {
   197  		err := c.underline.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(WriteWait))
   198  		if err == websocket.ErrCloseSent {
   199  			return nil
   200  		} else if _, ok := err.(net.Error); ok {
   201  			return nil
   202  		}
   203  		return err
   204  	}
   205  
   206  	c.underline.SetPingHandler(pingHandler)
   207  
   208  	go func() {
   209  		for {
   210  			time.Sleep(c.server.config.PingPeriod)
   211  			if c.disconnected {
   212  				break
   213  			}
   214  			c.fireOnPing()
   215  			err := c.Write(websocket.PingMessage, []byte{})
   216  			if err != nil {
   217  				break
   218  			}
   219  		}
   220  	}()
   221  }
   222  
   223  func (c *connection) fireOnPing() {
   224  	for i := range c.onPingListeners {
   225  		c.onPingListeners[i]()
   226  	}
   227  }
   228  
   229  func (c *connection) fireOnPong() {
   230  	for i := range c.onPongListeners {
   231  		c.onPongListeners[i]()
   232  	}
   233  }
   234  
   235  func (c *connection) startReader() {
   236  	conn := c.underline
   237  	hasReadTimeout := c.server.config.ReadTimeout > 0
   238  
   239  	conn.SetReadLimit(c.server.config.MaxMessageSize)
   240  	conn.SetPongHandler(func(s string) error {
   241  		if hasReadTimeout {
   242  			_ = conn.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout))
   243  		}
   244  		go c.fireOnPong()
   245  		return nil
   246  	})
   247  
   248  	defer func() { _ = c.Disconnect() }()
   249  
   250  	for {
   251  		if hasReadTimeout {
   252  			_ = conn.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout))
   253  		}
   254  		_, data, err := conn.ReadMessage()
   255  		if err != nil {
   256  			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
   257  				c.FireOnError(err)
   258  			}
   259  			break
   260  		} else {
   261  			c.messageReceived(data)
   262  		}
   263  	}
   264  }
   265  
   266  func (c *connection) messageReceived(data []byte) {
   267  
   268  	if bytes.HasPrefix(data, c.server.config.EvtMessagePrefix) {
   269  		receivedEvt := c.server.messageSerializer.getWebsocketCustomEvent(data)
   270  		listeners, ok := c.onEventListeners[string(receivedEvt)]
   271  		if !ok || len(listeners) == 0 {
   272  			return
   273  		}
   274  
   275  		customMessage, err := c.server.messageSerializer.deserialize(receivedEvt, data)
   276  		if customMessage == nil || err != nil {
   277  			return
   278  		}
   279  
   280  		for i := range listeners {
   281  			if fn, ok := listeners[i].(func()); ok {
   282  				fn()
   283  			} else if fnString, ok := listeners[i].(func(string)); ok {
   284  
   285  				if msgString, is := customMessage.(string); is {
   286  					fnString(msgString)
   287  				} else if msgInt, is := customMessage.(int); is {
   288  					fnString(strconv.Itoa(msgInt))
   289  				}
   290  
   291  			} else if fnInt, ok := listeners[i].(func(int)); ok {
   292  				fnInt(customMessage.(int))
   293  			} else if fnBool, ok := listeners[i].(func(bool)); ok {
   294  				fnBool(customMessage.(bool))
   295  			} else if fnBytes, ok := listeners[i].(func([]byte)); ok {
   296  				fnBytes(customMessage.([]byte))
   297  			} else {
   298  				listeners[i].(func(any))(customMessage)
   299  			}
   300  
   301  		}
   302  	} else {
   303  		for i := range c.onNativeMessageListeners {
   304  			c.onNativeMessageListeners[i](data)
   305  		}
   306  	}
   307  
   308  }
   309  
// ID returns the identifier the connection was created with.
func (c *connection) ID() string {
	return c.id
}

// Server returns the Server passed to newConnection.
func (c *connection) Server() *Server {
	return c.server
}

// Context returns the gin request context the connection was created from.
func (c *connection) Context() *gin.Context {
	return c.ctx
}

// Values returns the connection's key/value store. The returned slice shares
// its backing array with the connection's own store.
func (c *connection) Values() ConnectionValues {
	return c.values
}
   325  
   326  func (c *connection) fireDisconnect() {
   327  	for i := range c.onDisconnectListeners {
   328  		c.onDisconnectListeners[i]()
   329  	}
   330  }
   331  
   332  func (c *connection) OnDisconnect(cb DisconnectFunc) {
   333  	c.onDisconnectListeners = append(c.onDisconnectListeners, cb)
   334  }
   335  
   336  func (c *connection) OnError(cb ErrorFunc) {
   337  	c.onErrorListeners = append(c.onErrorListeners, cb)
   338  }
   339  
   340  func (c *connection) OnPing(cb PingFunc) {
   341  	c.onPingListeners = append(c.onPingListeners, cb)
   342  }
   343  
   344  func (c *connection) OnPong(cb PongFunc) {
   345  	c.onPongListeners = append(c.onPongListeners, cb)
   346  }
   347  
   348  func (c *connection) FireOnError(err error) {
   349  	for _, cb := range c.onErrorListeners {
   350  		cb(err)
   351  	}
   352  }
   353  
   354  func (c *connection) To(to string) Emitter {
   355  	if to == Broadcast {
   356  		return c.broadcast
   357  	} else if to == All {
   358  		return c.all
   359  	} else if to == c.id {
   360  		return c.self
   361  	}
   362  
   363  	return newEmitter(c, to)
   364  }
   365  
// EmitMessage sends nativeMessage through the connection's self emitter (the
// one created for this connection's ID).
func (c *connection) EmitMessage(nativeMessage []byte) error {
	return c.self.EmitMessage(nativeMessage)
}

// Emit sends the custom event with message through the connection's self
// emitter (the one created for this connection's ID).
func (c *connection) Emit(event string, message any) error {
	return c.self.Emit(event, message)
}

// OnMessage registers cb for raw messages, i.e. messages that do not start
// with the server's custom-event prefix.
func (c *connection) OnMessage(cb NativeMessageFunc) {
	c.onNativeMessageListeners = append(c.onNativeMessageListeners, cb)
}
   377  
   378  func (c *connection) On(event string, cb MessageFunc) {
   379  	if c.onEventListeners[event] == nil {
   380  		c.onEventListeners[event] = make([]MessageFunc, 0)
   381  	}
   382  
   383  	c.onEventListeners[event] = append(c.onEventListeners[event], cb)
   384  }
   385  
// Join adds this connection to roomName by delegating to Server.Join.
func (c *connection) Join(roomName string) {
	c.server.Join(roomName, c.id)
}

// IsJoined reports, via Server.IsJoined, whether this connection is in roomName.
func (c *connection) IsJoined(roomName string) bool {
	return c.server.IsJoined(roomName, c.id)
}

// Leave removes this connection from roomName by delegating to Server.Leave
// and returns its result.
func (c *connection) Leave(roomName string) bool {
	return c.server.Leave(roomName, c.id)
}

// OnLeave registers roomLeaveCb to be called by fireOnLeave with the room name.
func (c *connection) OnLeave(roomLeaveCb LeaveRoomFunc) {
	c.onRoomLeaveListeners = append(c.onRoomLeaveListeners, roomLeaveCb)
}
   401  
   402  func (c *connection) fireOnLeave(roomName string) {
   403  	if c == nil {
   404  		return
   405  	}
   406  	for i := range c.onRoomLeaveListeners {
   407  		c.onRoomLeaveListeners[i](roomName)
   408  	}
   409  }
   410  
   411  func (c *connection) Wait() {
   412  	if c.started {
   413  		return
   414  	}
   415  	c.started = true
   416  	c.startPinger()
   417  	c.startReader()
   418  }
   419  
   420  var ErrAlreadyDisconnected = errors.New("already disconnected")
   421  
   422  func (c *connection) Disconnect() error {
   423  	if c == nil || c.disconnected {
   424  		return ErrAlreadyDisconnected
   425  	}
   426  	return c.server.Disconnect(c.ID())
   427  }
   428  
   429  func (c *connection) SetValue(key string, value any) {
   430  	c.values.Set(key, value)
   431  }
   432  
   433  func (c *connection) GetValue(key string) any {
   434  	return c.values.Get(key)
   435  }
   436  
   437  func (c *connection) GetValueArrString(key string) []string {
   438  	if v := c.values.Get(key); v != nil {
   439  		if arrString, ok := v.([]string); ok {
   440  			return arrString
   441  		}
   442  	}
   443  	return nil
   444  }
   445  
   446  func (c *connection) GetValueString(key string) string {
   447  	if v := c.values.Get(key); v != nil {
   448  		if s, ok := v.(string); ok {
   449  			return s
   450  		}
   451  	}
   452  	return ""
   453  }
   454  
   455  func (c *connection) GetValueInt(key string) int {
   456  	if v := c.values.Get(key); v != nil {
   457  		if i, ok := v.(int); ok {
   458  			return i
   459  		} else if s, ok := v.(string); ok {
   460  			if iv, err := strconv.Atoi(s); err == nil {
   461  				return iv
   462  			}
   463  		}
   464  	}
   465  	return 0
   466  }