github.com/metaworking/channeld@v0.7.3/pkg/channeld/connection.go (about)

     1  package channeld
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"os"
     9  	"path/filepath"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/golang/snappy"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/metaworking/channeld/pkg/channeldpb"
    16  	"github.com/metaworking/channeld/pkg/common"
    17  	"github.com/metaworking/channeld/pkg/fsm"
    18  	"github.com/metaworking/channeld/pkg/replaypb"
    19  	"github.com/puzpuzpuz/xsync/v2"
    20  	"github.com/xtaci/kcp-go"
    21  	"go.uber.org/zap"
    22  	"google.golang.org/protobuf/proto"
    23  )
    24  
    25  type ConnectionId uint32
    26  
    27  const MaxPacketSize int = 0x00ffff
    28  const PacketHeaderSize int = 5
    29  
    30  //type ConnectionState int32
    31  
    32  const (
    33  	ConnectionState_UNAUTHENTICATED int32 = 0
    34  	ConnectionState_AUTHENTICATED   int32 = 1
    35  	ConnectionState_CLOSING         int32 = 2
    36  )
    37  
    38  // Add an interface before the underlying network layer for the test purpose.
    39  type MessageSender interface {
    40  	Send(c *Connection, ctx MessageContext) //(c *Connection, channelId ChannelId, msgType channeldpb.MessageType, msg Message)
    41  }
    42  
    43  /*
    44  type queuedMessageCtxSender struct {
    45  	MessageSender
    46  }
    47  
    48  func (s *queuedMessageCtxSender) Send(c *Connection, ctx MessageContext) {
    49  	c.sendQueue <- ctx
    50  }
    51  */
    52  
    53  type queuedMessagePackSender struct {
    54  	MessageSender
    55  }
    56  
    57  func (s *queuedMessagePackSender) Send(c *Connection, ctx MessageContext) {
    58  	msgBody, err := proto.Marshal(ctx.Msg)
    59  	if err != nil {
    60  		c.logger.Error("failed to marshal message", zap.Error(err), zap.Uint32("msgType", uint32(ctx.MsgType)))
    61  		return
    62  	}
    63  
    64  	mp := &channeldpb.MessagePack{
    65  		ChannelId: ctx.ChannelId,
    66  		Broadcast: ctx.Broadcast,
    67  		StubId:    ctx.StubId,
    68  		MsgType:   uint32(ctx.MsgType),
    69  		MsgBody:   msgBody,
    70  	}
    71  
    72  	// Check the message pack size before adding to the queue
    73  	size := proto.Size(mp)
    74  	if size >= MaxPacketSize-PacketHeaderSize {
    75  		c.logger.Warn("failed to send the message and its size exceeds the limit", zap.Int("size", size))
    76  		return
    77  	}
    78  
    79  	// Double check
    80  	if !c.IsClosing() {
    81  		c.sendQueue <- mp
    82  	}
    83  }
    84  
// Connection is one accepted network connection (server or client) together
// with the per-connection state channeld keeps for it: read buffer, outgoing
// queue, FSM copy, lifecycle state and close handlers.
type Connection struct {
	// NOTE(review): embedded type declared elsewhere in the package; presumably
	// the channel-facing view of a connection — confirm.
	ConnectionInChannel
	id              ConnectionId
	connectionType  channeldpb.ConnectionType
	// Compression used for outgoing packets; readPacket overwrites it with the
	// compression type of each compressed incoming packet.
	compressionType channeldpb.CompressionType
	conn            net.Conn
	// readBuffer holds bytes read from conn that have not yet been consumed as
	// complete packets; readPos is the number of valid bytes at its front.
	readBuffer      []byte
	readPos         int
	// reader          *bufio.Reader
	// writer          *bufio.Writer
	sender               MessageSender
	// MessagePacks waiting for flush to pack them into a Packet.
	sendQueue            chan *channeldpb.MessagePack //MessageContext
	// A dequeued pack that would have pushed the current packet over
	// MaxPacketSize; flush puts it first in its next packet.
	oversizedMsgPack     *channeldpb.MessagePack
	// Value passed to OnAuthenticated. NOTE(review): presumably the player
	// identity token — confirm.
	pit                  string
	// Per-connection copy of serverFsm or clientFsm, made in AddConnection.
	fsm                  *fsm.FiniteStateMachine
	// NOTE(review): not updated in this file; presumably counts messages the
	// FSM rejected — confirm.
	fsmDisallowedCounter int
	logger               *Logger
	// One of the ConnectionState_* constants; written with atomic.StoreInt32.
	state                int32 // Don't put the connection state into the FSM as 1) the FSM's states are user-defined. 2) the FSM is not goroutine-safe.
	// Time at which AddConnection created the connection.
	connTime             time.Time
	// Run in registration order by Close.
	closeHandlers        []func()
	// Created by AddConnection only when packet recording is enabled.
	replaySession        *replaypb.ReplaySession
	// NOTE(review): not used in this file; presumably spatial-channel
	// subscription options keyed by channel id — confirm.
	spatialSubscriptions *xsync.MapOf[common.ChannelId, *channeldpb.ChannelSubscriptionOptions]
}
   108  
   109  var allConnections *xsync.MapOf[ConnectionId, *Connection]
   110  var nextConnectionId uint32 = 0
   111  var serverFsm *fsm.FiniteStateMachine
   112  var clientFsm *fsm.FiniteStateMachine
   113  
   114  func InitConnections(serverFsmPath string, clientFsmPath string) {
   115  	if allConnections != nil {
   116  		return
   117  	}
   118  
   119  	allConnections = xsync.NewTypedMapOf[ConnectionId, *Connection](UintIdHasher[ConnectionId]())
   120  
   121  	bytes, err := os.ReadFile(serverFsmPath)
   122  	if err == nil {
   123  		serverFsm, err = fsm.Load(bytes)
   124  	}
   125  	if err != nil {
   126  		rootLogger.Panic("failed to read server FSM",
   127  			zap.Error(err),
   128  		)
   129  	} else {
   130  		rootLogger.Info("loaded server FSM",
   131  			zap.String("path", serverFsmPath),
   132  			zap.String("currentState", serverFsm.CurrentState().Name),
   133  		)
   134  	}
   135  
   136  	bytes, err = os.ReadFile(clientFsmPath)
   137  	if err == nil {
   138  		clientFsm, err = fsm.Load(bytes)
   139  	}
   140  	if err != nil {
   141  		rootLogger.Panic("failed to read client FSM", zap.Error(err))
   142  	} else {
   143  		rootLogger.Info("loaded client FSM",
   144  			zap.String("path", clientFsmPath),
   145  			zap.String("currentState", clientFsm.CurrentState().Name),
   146  		)
   147  	}
   148  }
   149  
   150  func GetConnection(id ConnectionId) *Connection {
   151  	c, ok := allConnections.Load(id)
   152  	if ok {
   153  		if c.IsClosing() {
   154  			return nil
   155  		}
   156  		return c
   157  	} else {
   158  		return nil
   159  	}
   160  }
   161  
   162  func startGoroutines(connection *Connection) {
   163  	// receive goroutine
   164  	go func() {
   165  		for !connection.IsClosing() {
   166  			connection.receive()
   167  		}
   168  	}()
   169  
   170  	// tick & flush goroutine
   171  	go func() {
   172  		for !connection.IsClosing() {
   173  			connection.flush()
   174  			time.Sleep(time.Millisecond)
   175  		}
   176  	}()
   177  }
   178  
   179  func StartListening(t channeldpb.ConnectionType, network string, address string) {
   180  	rootLogger.Info("start listenning",
   181  		zap.String("connType", t.String()),
   182  		zap.String("network", network),
   183  		zap.String("address", address),
   184  	)
   185  
   186  	var listener net.Listener
   187  	var err error
   188  	switch network {
   189  	case "ws", "websocket":
   190  		startWebSocketServer(t, address)
   191  		return
   192  	case "kcp":
   193  		listener, err = kcp.Listen(address)
   194  	default:
   195  		listener, err = net.Listen(network, address)
   196  	}
   197  
   198  	if err != nil {
   199  		rootLogger.Panic("failed to listen", zap.Error(err))
   200  		return
   201  	}
   202  
   203  	defer listener.Close()
   204  
   205  	for {
   206  		conn, err := listener.Accept()
   207  		if err != nil {
   208  			rootLogger.Error("failed to accept connection", zap.Error(err))
   209  		} else {
   210  			if network == "tcp" {
   211  				tcpConn := conn.(*net.TCPConn)
   212  				if err := tcpConn.SetReadBuffer(0x0fffff); err != nil {
   213  					rootLogger.Error("failed to set read buffer size", zap.Error(err))
   214  				}
   215  				if err := tcpConn.SetWriteBuffer(0x0fffff); err != nil {
   216  					rootLogger.Error("failed to set write buffer size", zap.Error(err))
   217  				}
   218  				tcpConn.SetNoDelay(true)
   219  			}
   220  
   221  			// Check if the IP address is banned.
   222  			ip := GetIP(conn.RemoteAddr())
   223  			_, banned := ipBlacklist[ip]
   224  			if banned {
   225  				securityLogger.Info("refused connection of banned IP address", zap.String("ip", ip))
   226  				conn.Close()
   227  				continue
   228  			}
   229  
   230  			connection := AddConnection(conn, t)
   231  			connection.Logger().Debug("accepted connection")
   232  			startGoroutines(connection)
   233  		}
   234  	}
   235  }
   236  
   237  func generateNextConnId(c net.Conn, maxConnId uint32) {
   238  	if GlobalSettings.Development {
   239  		atomic.AddUint32(&nextConnectionId, 1)
   240  		if nextConnectionId >= maxConnId {
   241  			// For now, we don't consider re-using the ConnectionId. Even if there are 100 incoming connections per sec, channeld can run over a year.
   242  			rootLogger.Panic("connectionId reached the limit", zap.Uint32("maxConnId", maxConnId))
   243  		}
   244  	} else {
   245  		// In non-dev mode, hash the (remote address + timestamp) to get a less guessable ID
   246  		hash := HashString(c.RemoteAddr().String())
   247  		hash = hash ^ uint32(time.Now().UnixNano())
   248  		nextConnectionId = hash & maxConnId
   249  	}
   250  }
   251  
// AddConnection registers the newly accepted network connection c of type t
// and returns its Connection.
//
// Steps, in order:
//  1. Size the read buffer from GlobalSettings, never below one maximum-size
//     packet plus its header.
//  2. Pick an unused connection id.
//  3. Give the connection its own copy of the server or client FSM template.
//  4. Store it in allConnections, and in unauthenticatedConnections when
//     GlobalSettings.ConnectionAuthTimeoutMs > 0.
//  5. Increment the connectionNum metric.
//
// An invalid t, a missing FSM template, or failing to find a free id is
// reported through rootLogger.Panic.
//
// NOT goroutine-safe. NEVER call AddConnection in different goroutines.
func AddConnection(c net.Conn, t channeldpb.ConnectionType) *Connection {
	var readerSize int
	// var writerSize int
	if t == channeldpb.ConnectionType_SERVER {
		readerSize = GlobalSettings.ServerReadBufferSize
		// writerSize = GlobalSettings.ServerWriteBufferSize
	} else if t == channeldpb.ConnectionType_CLIENT {
		readerSize = GlobalSettings.ClientReadBufferSize
		// writerSize = GlobalSettings.ClientWriteBufferSize
	} else {
		rootLogger.Panic("invalid connection type", zap.Int32("connType", int32(t)))
	}
	// The buffer must hold at least one complete maximum-size packet.
	if readerSize < MaxPacketSize+PacketHeaderSize {
		readerSize = MaxPacketSize + PacketHeaderSize
	}
	// (1 << MaxConnectionIdBits) - 1: the largest id that fits in the configured bits.
	maxConnId := uint32(1)<<GlobalSettings.MaxConnectionIdBits - 1

	// Regenerate on collision with a registered id; panics after 101 colliding candidates.
	for tries := 0; ; tries++ {
		generateNextConnId(c, maxConnId)
		if _, exists := allConnections.Load(ConnectionId(nextConnectionId)); !exists {
			break
		}

		rootLogger.Warn("there's a same connId existing, will try to generate a new one", zap.Uint32("connId", nextConnectionId))
		if tries >= 100 {
			rootLogger.Panic("could not find non-duplicate connId")
		}
	}

	connection := &Connection{
		id:              ConnectionId(nextConnectionId),
		connectionType:  t,
		compressionType: channeldpb.CompressionType_NO_COMPRESSION,
		conn:            c,
		readBuffer:      make([]byte, readerSize),
		readPos:         0,
		// reader:    bufio.NewReaderSize(c, readerSize),
		// writer:    bufio.NewWriterSize(c, writerSize),
		sender:               &queuedMessagePackSender{},
		sendQueue:            make(chan *channeldpb.MessagePack, 128),
		fsmDisallowedCounter: 0,
		logger: &Logger{rootLogger.With(
			zap.String("connType", t.String()),
			zap.Uint32("connId", nextConnectionId),
		)},
		state:                ConnectionState_UNAUTHENTICATED,
		connTime:             time.Now(),
		closeHandlers:        make([]func(), 0),
		spatialSubscriptions: xsync.NewTypedMapOf[common.ChannelId, *channeldpb.ChannelSubscriptionOptions](UintIdHasher[common.ChannelId]()),
	}

	if connection.isPacketRecordingEnabled() {
		connection.replaySession = &replaypb.ReplaySession{
			Packets: make([]*replaypb.ReplayPacket, 0, 1024),
		}
	}

	// NOTE(review): the value copies below are shallow; any maps, slices or
	// pointers inside fsm.FiniteStateMachine are still shared with the template.
	// Confirm the FSM keeps its per-instance state in plain value fields.
	switch t {
	case channeldpb.ConnectionType_SERVER:
		if serverFsm != nil {
			// IMPORTANT: always make a value copy
			fsm := *serverFsm
			connection.fsm = &fsm
		}
	case channeldpb.ConnectionType_CLIENT:
		if clientFsm != nil {
			// IMPORTANT: always make a value copy
			fsm := *clientFsm
			connection.fsm = &fsm
		}
	}

	if connection.fsm == nil {
		rootLogger.Panic("cannot set the FSM for connection", zap.String("connType", t.String()))
	}

	allConnections.Store(connection.id, connection)

	// Only track unauthenticated connections when an auth timeout is configured.
	if GlobalSettings.ConnectionAuthTimeoutMs > 0 {
		unauthenticatedConnections.Store(connection.id, connection)
	}

	connectionNum.WithLabelValues(t.String()).Inc()

	return connection
}
   339  
   340  func (c *Connection) AddCloseHandler(handlerFunc func()) {
   341  	c.closeHandlers = append(c.closeHandlers, handlerFunc)
   342  }
   343  
// Close tears the connection down. In order, it:
//  1. Persists the replay session, when packet recording is enabled.
//  2. Runs every registered close handler.
//  3. Marks the state as ConnectionState_CLOSING.
//  4. Closes the network connection and the send queue.
//  5. Removes the connection from allConnections and unauthenticatedConnections.
//  6. Decrements the connectionNum metric.
//
// Calling Close on a connection that is already closing only logs a debug
// message.
//
// NOTE(review): the deferred recover silently swallows any panic in Close,
// including one raised by a close handler. If that happens before the state
// store, the connection is left open and not marked as closing. The IsClosing
// check and the state store are also separate steps, so two concurrent callers
// could both get past the check; the second close(c.sendQueue) would panic and
// be swallowed the same way.
func (c *Connection) Close() {
	defer func() {
		recover()
	}()
	if c.IsClosing() {
		c.Logger().Debug("connection is already closed")
		return
	}

	if c.isPacketRecordingEnabled() {
		c.persistReplaySession()
	}

	// Handlers run before the state flips to CLOSING, so IsClosing() is still
	// false and GetConnection still returns c while they execute.
	for _, handlerFunc := range c.closeHandlers {
		handlerFunc()
	}

	atomic.StoreInt32(&c.state, ConnectionState_CLOSING)
	c.conn.Close()
	close(c.sendQueue)
	allConnections.Delete(c.id)
	unauthenticatedConnections.Delete(c.id)

	c.Logger().Info("closed connection")
	connectionNum.WithLabelValues(c.connectionType.String()).Dec()
}
   370  
   371  func (c *Connection) IsClosing() bool {
   372  	return c.state > ConnectionState_AUTHENTICATED
   373  }
   374  
// receive performs one Read from the network connection into the unused tail
// of readBuffer, then decodes and dispatches every complete packet now
// buffered.
//
// Bytes of an incomplete trailing packet are moved to the front of readBuffer
// so the next call appends to them. A read error, or an error returned by
// readPacket, closes the connection.
func (c *Connection) receive() {
	// Read all bytes into the buffer at once
	readPtr := c.readBuffer[c.readPos:]
	bytesRead, err := c.conn.Read(readPtr)
	if err != nil {
		// Log the reason for the error types recognized here; every read error closes the connection.
		switch err := err.(type) {
		case *net.OpError:
			c.Logger().Info("net op error",
				zap.String("op", err.Op),
				zap.String("remoteAddr", c.conn.RemoteAddr().String()),
				zap.Error(err),
			)
		case *websocket.CloseError:
			c.Logger().Info("disconnected",
				zap.String("remoteAddr", c.conn.RemoteAddr().String()),
			)
		}

		// This compares the outer err; the switch's err is scoped to the switch.
		if err == io.EOF {
			c.Logger().Info("disconnected",
				zap.String("remoteAddr", c.conn.RemoteAddr().String()),
			)
		}
		c.Close()
		return
	}
	c.readPos += bytesRead
	if c.readPos < PacketHeaderSize {
		// Unfinished header
		fragmentedPacketCount.WithLabelValues(c.connectionType.String()).Inc()
		return
	}

	// bufPos is the offset of the next undecoded packet; readPacket advances it
	// past every packet it fully decodes.
	bufPos := 0
	for bufPos = 0; bufPos < c.readPos; {
		packet, err := c.readPacket(&bufPos)
		// there's a wire format error, close the connection to give a quick feedback to the other end.
		if err != nil {
			c.Close()
			return

		}
		// all fully received packets are handled
		if packet == nil {
			break
		}

		combinedPacketCount.WithLabelValues(c.connectionType.String()).Inc()
	}

	if bufPos < c.readPos {
		// Move unhandled content to the front
		copy(c.readBuffer, c.readBuffer[bufPos:c.readPos])
	}

	// Move read position
	c.readPos -= bufPos
}
   433  
   434  func readSize(tag []byte) int {
   435  	if tag[0] != 67 || tag[1] != 72 {
   436  		return 0
   437  	}
   438  
   439  	size := int(tag[3]) | int(tag[2])<<8
   440  
   441  	return size
   442  }
   443  
   444  func (c *Connection) readPacket(bufPos *int) (*channeldpb.Packet, error) {
   445  	if c.readPos-*bufPos < PacketHeaderSize {
   446  		// Unfinished header
   447  		fragmentedPacketCount.WithLabelValues(c.connectionType.String()).Inc()
   448  		return nil, nil
   449  	}
   450  
   451  	tag := c.readBuffer[*bufPos : *bufPos+PacketHeaderSize]
   452  
   453  	packetSize := readSize(tag)
   454  	if packetSize == 0 {
   455  		c.readPos = 0
   456  		connectionClosed.WithLabelValues(c.connectionType.String()).Inc()
   457  		c.Logger().Warn("invalid tag, the connection will be closed",
   458  			zap.Binary("tag", tag),
   459  		)
   460  		return nil, errors.New("invlaid tag")
   461  	}
   462  
   463  	if packetSize > MaxPacketSize {
   464  		c.readPos = 0
   465  		connectionClosed.WithLabelValues(c.connectionType.String()).Inc()
   466  		c.Logger().Warn("packet size exceeds the limit, the connection will be closed", zap.Int("packetSize", packetSize), zap.Int("bufferSize", len(c.readBuffer)))
   467  		return nil, errors.New("packetSize too large")
   468  	}
   469  
   470  	fullSize := PacketHeaderSize + packetSize
   471  
   472  	if c.readPos < *bufPos+fullSize {
   473  		// Unfinished packet
   474  
   475  		fragmentedPacketCount.WithLabelValues(c.connectionType.String()).Inc()
   476  		// this is a normal case, turn off the logs
   477  		//c.Logger().Info("read part of package", zap.Int("readpos", c.readPos), zap.Int("full size", fullSize))
   478  		return nil, nil
   479  	}
   480  
   481  	bytes := c.readBuffer[*bufPos+PacketHeaderSize : *bufPos+fullSize]
   482  
   483  	bytesReceived.WithLabelValues(c.connectionType.String()).Add(float64(fullSize))
   484  
   485  	// Apply the decompression from the 5th byte in the header
   486  	ct := tag[4]
   487  	_, valid := channeldpb.CompressionType_name[int32(ct)]
   488  	if valid && ct != 0 {
   489  		c.compressionType = channeldpb.CompressionType(ct)
   490  		if c.compressionType == channeldpb.CompressionType_SNAPPY {
   491  			len, err := snappy.DecodedLen(bytes)
   492  			if err != nil {
   493  				c.Logger().Error("snappy.DecodedLen", zap.Error(err))
   494  				return nil, err
   495  
   496  			}
   497  			dst := make([]byte, len)
   498  			bytes, err = snappy.Decode(dst, bytes)
   499  			if err != nil {
   500  				c.Logger().Error("snappy.Decode", zap.Error(err))
   501  				return nil, err
   502  
   503  			}
   504  		}
   505  	}
   506  
   507  	var p channeldpb.Packet
   508  	if err := proto.Unmarshal(bytes, &p); err != nil {
   509  		c.Logger().Error("failed to unmarshall packet, the connection will be closed", zap.Error(err),
   510  			zap.Uint32("size", uint32(packetSize)),
   511  			zap.Binary("tag", tag),
   512  		)
   513  		//if c.connectionType == channeldpb.ConnectionType_CLIENT {
   514  		connectionClosed.WithLabelValues(c.connectionType.String()).Inc()
   515  		return nil, nil
   516  	}
   517  
   518  	packetReceived.WithLabelValues(c.connectionType.String()).Inc()
   519  
   520  	if c.isPacketRecordingEnabled() {
   521  		c.recordPacket(&p)
   522  	}
   523  
   524  	for _, mp := range p.Messages {
   525  		c.receiveMessage(mp)
   526  	}
   527  
   528  	*bufPos += fullSize
   529  	return &p, nil
   530  }
   531  
   532  func (c *Connection) isPacketRecordingEnabled() bool {
   533  	return c.connectionType == channeldpb.ConnectionType_CLIENT && GlobalSettings.EnableRecordPacket
   534  }
   535  
// receiveMessage validates one incoming MessagePack and queues it on its target
// channel via channel.PutMessage.
//
// The pack is dropped when:
//   - the target channel does not exist (logged, except for sub/unsub messages);
//   - the type is below MessageType_USER_SPACE_START and has no MessageMap entry;
//   - the connection's FSM does not allow the type (Event_FsmDisallowed is broadcast);
//   - the body cannot be unmarshalled.
//
// User-space messages without a MessageMap entry are forwarded. A client's
// body is wrapped, undecoded, in a ServerForwardMessage; a server's body is
// parsed as a ServerForwardMessage.
func (c *Connection) receiveMessage(mp *channeldpb.MessagePack) {
	channel := GetChannel(common.ChannelId(mp.ChannelId))
	if channel == nil {
		// Sub to/unsub from a removed channel is allowed
		if mp.MsgType != uint32(channeldpb.MessageType_SUB_TO_CHANNEL) && mp.MsgType != uint32(channeldpb.MessageType_UNSUB_FROM_CHANNEL) {
			c.Logger().Warn("can't find channel",
				zap.Uint32("channelId", mp.ChannelId),
				zap.Uint32("msgType", mp.MsgType),
			)
		}
		return
	}

	// System (non-user-space) message types must be registered in MessageMap.
	entry := MessageMap[channeldpb.MessageType(mp.MsgType)]
	if entry == nil && mp.MsgType < uint32(channeldpb.MessageType_USER_SPACE_START) {
		c.Logger().Error("undefined message type", zap.Uint32("msgType", mp.MsgType))
		return
	}

	if !c.fsm.IsAllowed(mp.MsgType) {
		Event_FsmDisallowed.Broadcast(c)
		c.Logger().Warn("message is not allowed for current state",
			zap.Uint32("msgType", mp.MsgType),
			zap.String("connState", c.fsm.CurrentState().Name),
		)
		return
	}

	var msg common.Message
	var handler MessageHandlerFunc
	if mp.MsgType >= uint32(channeldpb.MessageType_USER_SPACE_START) && entry == nil {
		// client -> channeld -> server
		if c.connectionType == channeldpb.ConnectionType_CLIENT {
			// User-space message without handler won't be deserialized.
			msg = &channeldpb.ServerForwardMessage{ClientConnId: uint32(c.id), Payload: mp.MsgBody}
			handler = handleClientToServerUserMessage
		} else {
			// server -> channeld -> client/server
			msg = &channeldpb.ServerForwardMessage{}
			err := proto.Unmarshal(mp.MsgBody, msg)
			if err != nil {
				c.Logger().Error("unmarshalling ServerForwardMessage", zap.Error(err))
				return
			}
			handler = HandleServerToClientUserMessage
		}
	} else {
		// entry is non-nil here: nil entries below USER_SPACE_START returned above.
		handler = entry.handler
		// Always make a clone!
		msg = proto.Clone(entry.msg)
		err := proto.Unmarshal(mp.MsgBody, msg)
		if err != nil {
			c.Logger().Error("unmarshalling message", zap.Error(err))
			return
		}
	}

	c.fsm.OnReceived(mp.MsgType)

	channel.PutMessage(msg, handler, c, mp)

	c.Logger().VeryVerbose("received message", zap.Uint32("msgType", mp.MsgType), zap.Int("size", len(mp.MsgBody)))
	//c.Logger().Debug("received message", zap.Uint32("msgType", mp.MsgType), zap.Int("size", len(mp.MsgBody)))

	msgReceived.WithLabelValues(c.connectionType.String()).Inc() /*.WithLabelValues(
		strconv.FormatUint(uint64(p.ChannelId), 10),
		strconv.FormatUint(uint64(p.MsgType), 10),
	)*/
}
   605  
   606  func (c *Connection) Send(ctx MessageContext) {
   607  	if c.IsClosing() {
   608  		return
   609  	}
   610  
   611  	c.sender.Send(c, ctx)
   612  }
   613  
   614  // Should NOT be called outside the flush goroutine!
   615  func (c *Connection) flush() {
   616  	if len(c.sendQueue) == 0 {
   617  		return
   618  	}
   619  
   620  	p := channeldpb.Packet{Messages: make([]*channeldpb.MessagePack, 0, len(c.sendQueue))}
   621  	size := 0
   622  
   623  	// Add the oversided message pack first if any
   624  	if c.oversizedMsgPack != nil {
   625  		p.Messages = append(p.Messages, c.oversizedMsgPack)
   626  		c.oversizedMsgPack = nil
   627  		// No need to check the packet size now, as each message pack is already checked before adding to the queue.
   628  	}
   629  
   630  	// For now we don't limit the message numbers per packet
   631  	for len(c.sendQueue) > 0 {
   632  		mp := <-c.sendQueue
   633  		p.Messages = append(p.Messages, mp)
   634  		size = proto.Size(&p)
   635  		if size > MaxPacketSize {
   636  			c.Logger().Info("packet is going to be oversized",
   637  				zap.Int("packetSize", size),
   638  				zap.Uint32("msgType", uint32(mp.MsgType)),
   639  				zap.Int("msgSize", len(mp.MsgBody)),
   640  				zap.Int("msgNum", len(p.Messages)),
   641  				zap.Int("msgInQueue", len(c.sendQueue)),
   642  			)
   643  
   644  			// Revert adding the message that causes the oversize
   645  			p.Messages = p.Messages[:len(p.Messages)-1]
   646  
   647  			// Store the message pack that causes the overside
   648  			c.oversizedMsgPack = mp
   649  			break
   650  		}
   651  
   652  		c.Logger().VeryVerbose("sent message", zap.Uint32("msgType", uint32(mp.MsgType)), zap.Int("size", len(mp.MsgBody)))
   653  
   654  		msgSent.WithLabelValues(c.connectionType.String()).Inc() /*.WithLabelValues(
   655  			strconv.FormatUint(uint64(e.Channel.id), 10),
   656  			strconv.FormatUint(uint64(e.MsgType), 10),
   657  		)*/
   658  	}
   659  
   660  	bytes, err := proto.Marshal(&p)
   661  	if err != nil {
   662  		c.Logger().Error("failed to marshal packet", zap.Error(err))
   663  		return
   664  	}
   665  
   666  	// Apply the compression
   667  	if c.compressionType == channeldpb.CompressionType_SNAPPY {
   668  		dst := make([]byte, snappy.MaxEncodedLen(len(bytes)))
   669  		bytes = snappy.Encode(dst, bytes)
   670  	}
   671  
   672  	// 'CHNL' in ASCII
   673  	tag := []byte{67, 72, 78, 76, byte(c.compressionType)}
   674  	len := len(bytes)
   675  	tag[3] = byte(len & 0xff)
   676  	tag[2] = byte((len >> 8) & 0xff)
   677  	if len > MaxPacketSize {
   678  		// Should never happen, but log it just in case
   679  		c.Logger().Error("packet is oversized", zap.Int("size", len))
   680  		return
   681  	}
   682  
   683  	/* Avoid writing multple times. With WebSocket, every Write() sends a message.
   684  	writer.Write(tag)
   685  	*/
   686  	bytes = append(tag, bytes...)
   687  	/*
   688  		_, err = c.writer.Write(bytes)
   689  		if err != nil {
   690  			c.Logger().Error("error writing packet", zap.Error(err))
   691  			return
   692  		}
   693  
   694  		c.writer.Flush()
   695  	*/
   696  	len, err = c.conn.Write(bytes)
   697  	if err != nil {
   698  		c.Logger().Error("error writing packet", zap.Error(err))
   699  	}
   700  
   701  	packetSent.WithLabelValues(c.connectionType.String()).Inc()
   702  	bytesSent.WithLabelValues(c.connectionType.String()).Add(float64(len))
   703  }
   704  
   705  func (c *Connection) Disconnect() error {
   706  	return c.conn.Close()
   707  }
   708  
   709  func (c *Connection) Id() ConnectionId {
   710  	return c.id
   711  }
   712  
   713  func (c *Connection) GetConnectionType() channeldpb.ConnectionType {
   714  	return c.connectionType
   715  }
   716  
   717  func (c *Connection) OnAuthenticated(pit string) {
   718  	if c.IsClosing() {
   719  		return
   720  	}
   721  
   722  	atomic.StoreInt32(&c.state, ConnectionState_AUTHENTICATED)
   723  
   724  	unauthenticatedConnections.Delete(c.id)
   725  
   726  	c.pit = pit
   727  
   728  	if !c.fsm.MoveToNextState() {
   729  		c.Logger().Error("no state found after the authenticated state")
   730  	}
   731  }
   732  
   733  func (c *Connection) String() string {
   734  	return fmt.Sprintf("Connection(%s %d %s)", c.connectionType, c.id, c.fsm.CurrentState().Name)
   735  }
   736  
   737  func (c *Connection) Logger() *Logger {
   738  	return c.logger
   739  }
   740  
   741  func (c *Connection) RemoteAddr() net.Addr {
   742  	/* The address should still be available even after the connection is closed.
   743  	 * In this way, the anit-DDoS can save the address to the blacklist.
   744  	if c.IsClosing() {
   745  		return nil
   746  	}
   747  	*/
   748  	return c.conn.RemoteAddr()
   749  }
   750  
   751  func (c *Connection) recordPacket(p *channeldpb.Packet) {
   752  
   753  	recordedPacket := &channeldpb.Packet{
   754  		Messages: make([]*channeldpb.MessagePack, 0, len(p.Messages)),
   755  	}
   756  	proto.Merge(recordedPacket, p)
   757  
   758  	c.replaySession.Packets = append(c.replaySession.Packets, &replaypb.ReplayPacket{
   759  		OffsetTime: time.Now().UnixNano(),
   760  		Packet:     recordedPacket,
   761  	})
   762  }
   763  
   764  func (c *Connection) persistReplaySession() {
   765  
   766  	var prevPacketTime int64
   767  	if len(c.replaySession.Packets) > 0 {
   768  		prevPacketTime = c.replaySession.Packets[0].OffsetTime
   769  	} else {
   770  		c.Logger().Error("replay session is empty")
   771  		return
   772  	}
   773  
   774  	for _, packet := range c.replaySession.Packets {
   775  		t := packet.OffsetTime
   776  		packet.OffsetTime -= prevPacketTime
   777  		prevPacketTime = t
   778  	}
   779  
   780  	data, err := proto.Marshal(c.replaySession)
   781  	if err != nil {
   782  		c.Logger().Error("failed to marshal replay session", zap.Error(err))
   783  		return
   784  	}
   785  
   786  	var dir string
   787  	if GlobalSettings.ReplaySessionPersistenceDir != "" {
   788  		dir = GlobalSettings.ReplaySessionPersistenceDir
   789  	} else {
   790  		dir = "replays"
   791  	}
   792  
   793  	_, err = os.Stat(dir)
   794  	if err == nil || !os.IsExist(err) {
   795  		os.MkdirAll(dir, 0777)
   796  	}
   797  
   798  	path := filepath.Join(dir, fmt.Sprintf("session_%d_%s.cpr", c.id, time.Now().Local().Format("06-01-02_15-04-03")))
   799  	err = os.WriteFile(path, data, 0777)
   800  	if err != nil {
   801  		c.Logger().Error("failed to write replay session to location", zap.Error(err))
   802  	}
   803  
   804  }