github.com/metaworking/channeld@v0.7.3/pkg/channeld/connection.go (about) 1 package channeld 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "net" 8 "os" 9 "path/filepath" 10 "sync/atomic" 11 "time" 12 13 "github.com/golang/snappy" 14 "github.com/gorilla/websocket" 15 "github.com/metaworking/channeld/pkg/channeldpb" 16 "github.com/metaworking/channeld/pkg/common" 17 "github.com/metaworking/channeld/pkg/fsm" 18 "github.com/metaworking/channeld/pkg/replaypb" 19 "github.com/puzpuzpuz/xsync/v2" 20 "github.com/xtaci/kcp-go" 21 "go.uber.org/zap" 22 "google.golang.org/protobuf/proto" 23 ) 24 25 type ConnectionId uint32 26 27 const MaxPacketSize int = 0x00ffff 28 const PacketHeaderSize int = 5 29 30 //type ConnectionState int32 31 32 const ( 33 ConnectionState_UNAUTHENTICATED int32 = 0 34 ConnectionState_AUTHENTICATED int32 = 1 35 ConnectionState_CLOSING int32 = 2 36 ) 37 38 // Add an interface before the underlying network layer for the test purpose. 39 type MessageSender interface { 40 Send(c *Connection, ctx MessageContext) //(c *Connection, channelId ChannelId, msgType channeldpb.MessageType, msg Message) 41 } 42 43 /* 44 type queuedMessageCtxSender struct { 45 MessageSender 46 } 47 48 func (s *queuedMessageCtxSender) Send(c *Connection, ctx MessageContext) { 49 c.sendQueue <- ctx 50 } 51 */ 52 53 type queuedMessagePackSender struct { 54 MessageSender 55 } 56 57 func (s *queuedMessagePackSender) Send(c *Connection, ctx MessageContext) { 58 msgBody, err := proto.Marshal(ctx.Msg) 59 if err != nil { 60 c.logger.Error("failed to marshal message", zap.Error(err), zap.Uint32("msgType", uint32(ctx.MsgType))) 61 return 62 } 63 64 mp := &channeldpb.MessagePack{ 65 ChannelId: ctx.ChannelId, 66 Broadcast: ctx.Broadcast, 67 StubId: ctx.StubId, 68 MsgType: uint32(ctx.MsgType), 69 MsgBody: msgBody, 70 } 71 72 // Check the message pack size before adding to the queue 73 size := proto.Size(mp) 74 if size >= MaxPacketSize-PacketHeaderSize { 75 c.logger.Warn("failed to send the message and its size exceeds the limit", zap.Int("size", size)) 76 return 77 } 78 79 // Double check 80 if !c.IsClosing() { 81 c.sendQueue <- mp 82 } 83 } 84 85 type Connection struct { 86 ConnectionInChannel 87 id ConnectionId 88 connectionType channeldpb.ConnectionType 89 compressionType channeldpb.CompressionType 90 conn net.Conn 91 readBuffer []byte 92 readPos int 93 // reader *bufio.Reader 94 // writer *bufio.Writer 95 sender MessageSender 96 sendQueue chan *channeldpb.MessagePack //MessageContext 97 oversizedMsgPack *channeldpb.MessagePack 98 pit string 99 fsm *fsm.FiniteStateMachine 100 fsmDisallowedCounter int 101 logger *Logger 102 state int32 // Don't put the connection state into the FSM as 1) the FSM's states are user-defined. 2) the FSM is not goroutine-safe. 103 connTime time.Time 104 closeHandlers []func() 105 replaySession *replaypb.ReplaySession 106 spatialSubscriptions *xsync.MapOf[common.ChannelId, *channeldpb.ChannelSubscriptionOptions] 107 } 108 109 var allConnections *xsync.MapOf[ConnectionId, *Connection] 110 var nextConnectionId uint32 = 0 111 var serverFsm *fsm.FiniteStateMachine 112 var clientFsm *fsm.FiniteStateMachine 113 114 func InitConnections(serverFsmPath string, clientFsmPath string) { 115 if allConnections != nil { 116 return 117 } 118 119 allConnections = xsync.NewTypedMapOf[ConnectionId, *Connection](UintIdHasher[ConnectionId]()) 120 121 bytes, err := os.ReadFile(serverFsmPath) 122 if err == nil { 123 serverFsm, err = fsm.Load(bytes) 124 } 125 if err != nil { 126 rootLogger.Panic("failed to read server FSM", 127 zap.Error(err), 128 ) 129 } else { 130 rootLogger.Info("loaded server FSM", 131 zap.String("path", serverFsmPath), 132 zap.String("currentState", serverFsm.CurrentState().Name), 133 ) 134 } 135 136 bytes, err = os.ReadFile(clientFsmPath) 137 if err == nil { 138 clientFsm, err = fsm.Load(bytes) 139 } 140 if err != nil { 141 rootLogger.Panic("failed to read client FSM", zap.Error(err)) 142 } else { 143 rootLogger.Info("loaded client FSM", 144 zap.String("path", clientFsmPath), 145 zap.String("currentState", clientFsm.CurrentState().Name), 146 ) 147 } 148 } 149 150 func GetConnection(id ConnectionId) *Connection { 151 c, ok := allConnections.Load(id) 152 if ok { 153 if c.IsClosing() { 154 return nil 155 } 156 return c 157 } else { 158 return nil 159 } 160 } 161 162 func startGoroutines(connection *Connection) { 163 // receive goroutine 164 go func() { 165 for !connection.IsClosing() { 166 connection.receive() 167 } 168 }() 169 170 // tick & flush goroutine 171 go func() { 172 for !connection.IsClosing() { 173 connection.flush() 174 time.Sleep(time.Millisecond) 175 } 176 }() 177 } 178 179 func StartListening(t channeldpb.ConnectionType, network string, address string) { 180 rootLogger.Info("start listenning", 181 zap.String("connType", t.String()), 182 zap.String("network", network), 183 zap.String("address", address), 184 ) 185 186 var listener net.Listener 187 var err error 188 switch network { 189 case "ws", "websocket": 190 startWebSocketServer(t, address) 191 return 192 case "kcp": 193 listener, err = kcp.Listen(address) 194 default: 195 listener, err = net.Listen(network, address) 196 } 197 198 if err != nil { 199 rootLogger.Panic("failed to listen", zap.Error(err)) 200 return 201 } 202 203 defer listener.Close() 204 205 for { 206 conn, err := listener.Accept() 207 if err != nil { 208 rootLogger.Error("failed to accept connection", zap.Error(err)) 209 } else { 210 if network == "tcp" { 211 tcpConn := conn.(*net.TCPConn) 212 if err := tcpConn.SetReadBuffer(0x0fffff); err != nil { 213 rootLogger.Error("failed to set read buffer size", zap.Error(err)) 214 } 215 if err := tcpConn.SetWriteBuffer(0x0fffff); err != nil { 216 rootLogger.Error("failed to set write buffer size", zap.Error(err)) 217 } 218 tcpConn.SetNoDelay(true) 219 } 220 221 // Check if the IP address is banned. 222 ip := GetIP(conn.RemoteAddr()) 223 _, banned := ipBlacklist[ip] 224 if banned { 225 securityLogger.Info("refused connection of banned IP address", zap.String("ip", ip)) 226 conn.Close() 227 continue 228 } 229 230 connection := AddConnection(conn, t) 231 connection.Logger().Debug("accepted connection") 232 startGoroutines(connection) 233 } 234 } 235 } 236 237 func generateNextConnId(c net.Conn, maxConnId uint32) { 238 if GlobalSettings.Development { 239 atomic.AddUint32(&nextConnectionId, 1) 240 if nextConnectionId >= maxConnId { 241 // For now, we don't consider re-using the ConnectionId. Even if there are 100 incoming connections per sec, channeld can run over a year. 242 rootLogger.Panic("connectionId reached the limit", zap.Uint32("maxConnId", maxConnId)) 243 } 244 } else { 245 // In non-dev mode, hash the (remote address + timestamp) to get a less guessable ID 246 hash := HashString(c.RemoteAddr().String()) 247 hash = hash ^ uint32(time.Now().UnixNano()) 248 nextConnectionId = hash & maxConnId 249 } 250 } 251 252 // NOT goroutine-safe. NEVER call AddConnection in different goroutines. 253 func AddConnection(c net.Conn, t channeldpb.ConnectionType) *Connection { 254 var readerSize int 255 // var writerSize int 256 if t == channeldpb.ConnectionType_SERVER { 257 readerSize = GlobalSettings.ServerReadBufferSize 258 // writerSize = GlobalSettings.ServerWriteBufferSize 259 } else if t == channeldpb.ConnectionType_CLIENT { 260 readerSize = GlobalSettings.ClientReadBufferSize 261 // writerSize = GlobalSettings.ClientWriteBufferSize 262 } else { 263 rootLogger.Panic("invalid connection type", zap.Int32("connType", int32(t))) 264 } 265 if readerSize < MaxPacketSize+PacketHeaderSize { 266 readerSize = MaxPacketSize + PacketHeaderSize 267 } 268 maxConnId := uint32(1)<<GlobalSettings.MaxConnectionIdBits - 1 269 270 for tries := 0; ; tries++ { 271 generateNextConnId(c, maxConnId) 272 if _, exists := allConnections.Load(ConnectionId(nextConnectionId)); !exists { 273 break 274 } 275 276 rootLogger.Warn("there's a same connId existing, will try to generate a new one", zap.Uint32("connId", nextConnectionId)) 277 if tries >= 100 { 278 rootLogger.Panic("could not find non-duplicate connId") 279 } 280 } 281 282 connection := &Connection{ 283 id: ConnectionId(nextConnectionId), 284 connectionType: t, 285 compressionType: channeldpb.CompressionType_NO_COMPRESSION, 286 conn: c, 287 readBuffer: make([]byte, readerSize), 288 readPos: 0, 289 // reader: bufio.NewReaderSize(c, readerSize), 290 // writer: bufio.NewWriterSize(c, writerSize), 291 sender: &queuedMessagePackSender{}, 292 sendQueue: make(chan *channeldpb.MessagePack, 128), 293 fsmDisallowedCounter: 0, 294 logger: &Logger{rootLogger.With( 295 zap.String("connType", t.String()), 296 zap.Uint32("connId", nextConnectionId), 297 )}, 298 state: ConnectionState_UNAUTHENTICATED, 299 connTime: time.Now(), 300 closeHandlers: make([]func(), 0), 301 spatialSubscriptions: xsync.NewTypedMapOf[common.ChannelId, *channeldpb.ChannelSubscriptionOptions](UintIdHasher[common.ChannelId]()), 302 } 303 304 if connection.isPacketRecordingEnabled() { 305 connection.replaySession = &replaypb.ReplaySession{ 306 Packets: make([]*replaypb.ReplayPacket, 0, 1024), 307 } 308 } 309 310 switch t { 311 case channeldpb.ConnectionType_SERVER: 312 if serverFsm != nil { 313 // IMPORTANT: always make a value copy 314 fsm := *serverFsm 315 connection.fsm = &fsm 316 } 317 case channeldpb.ConnectionType_CLIENT: 318 if clientFsm != nil { 319 // IMPORTANT: always make a value copy 320 fsm := *clientFsm 321 connection.fsm = &fsm 322 } 323 } 324 325 if connection.fsm == nil { 326 rootLogger.Panic("cannot set the FSM for connection", zap.String("connType", t.String())) 327 } 328 329 allConnections.Store(connection.id, connection) 330 331 if GlobalSettings.ConnectionAuthTimeoutMs > 0 { 332 unauthenticatedConnections.Store(connection.id, connection) 333 } 334 335 connectionNum.WithLabelValues(t.String()).Inc() 336 337 return connection 338 } 339 340 func (c *Connection) AddCloseHandler(handlerFunc func()) { 341 c.closeHandlers = append(c.closeHandlers, handlerFunc) 342 } 343 344 func (c *Connection) Close() { 345 defer func() { 346 recover() 347 }() 348 if c.IsClosing() { 349 c.Logger().Debug("connection is already closed") 350 return 351 } 352 353 if c.isPacketRecordingEnabled() { 354 c.persistReplaySession() 355 } 356 357 for _, handlerFunc := range c.closeHandlers { 358 handlerFunc() 359 } 360 361 atomic.StoreInt32(&c.state, ConnectionState_CLOSING) 362 c.conn.Close() 363 close(c.sendQueue) 364 allConnections.Delete(c.id) 365 unauthenticatedConnections.Delete(c.id) 366 367 c.Logger().Info("closed connection") 368 connectionNum.WithLabelValues(c.connectionType.String()).Dec() 369 } 370 371 func (c *Connection) IsClosing() bool { 372 return c.state > ConnectionState_AUTHENTICATED 373 } 374 375 func (c *Connection) receive() { 376 // Read all bytes into the buffer at once 377 readPtr := c.readBuffer[c.readPos:] 378 bytesRead, err := c.conn.Read(readPtr) 379 if err != nil { 380 switch err := err.(type) { 381 case *net.OpError: 382 c.Logger().Info("net op error", 383 zap.String("op", err.Op), 384 zap.String("remoteAddr", c.conn.RemoteAddr().String()), 385 zap.Error(err), 386 ) 387 case *websocket.CloseError: 388 c.Logger().Info("disconnected", 389 zap.String("remoteAddr", c.conn.RemoteAddr().String()), 390 ) 391 } 392 393 if err == io.EOF { 394 c.Logger().Info("disconnected", 395 zap.String("remoteAddr", c.conn.RemoteAddr().String()), 396 ) 397 } 398 c.Close() 399 return 400 } 401 c.readPos += bytesRead 402 if c.readPos < PacketHeaderSize { 403 // Unfinished header 404 fragmentedPacketCount.WithLabelValues(c.connectionType.String()).Inc() 405 return 406 } 407 408 bufPos := 0 409 for bufPos = 0; bufPos < c.readPos; { 410 packet, err := c.readPacket(&bufPos) 411 // there's a wire format error, close the connection to give a quick feedback to the other end. 412 if err != nil { 413 c.Close() 414 return 415 416 } 417 // all fully received packets are handled 418 if packet == nil { 419 break 420 } 421 422 combinedPacketCount.WithLabelValues(c.connectionType.String()).Inc() 423 } 424 425 if bufPos < c.readPos { 426 // Move unhandled content to the front 427 copy(c.readBuffer, c.readBuffer[bufPos:c.readPos]) 428 } 429 430 // Move read position 431 c.readPos -= bufPos 432 } 433 434 func readSize(tag []byte) int { 435 if tag[0] != 67 || tag[1] != 72 { 436 return 0 437 } 438 439 size := int(tag[3]) | int(tag[2])<<8 440 441 return size 442 } 443 444 func (c *Connection) readPacket(bufPos *int) (*channeldpb.Packet, error) { 445 if c.readPos-*bufPos < PacketHeaderSize { 446 // Unfinished header 447 fragmentedPacketCount.WithLabelValues(c.connectionType.String()).Inc() 448 return nil, nil 449 } 450 451 tag := c.readBuffer[*bufPos : *bufPos+PacketHeaderSize] 452 453 packetSize := readSize(tag) 454 if packetSize == 0 { 455 c.readPos = 0 456 connectionClosed.WithLabelValues(c.connectionType.String()).Inc() 457 c.Logger().Warn("invalid tag, the connection will be closed", 458 zap.Binary("tag", tag), 459 ) 460 return nil, errors.New("invlaid tag") 461 } 462 463 if packetSize > MaxPacketSize { 464 c.readPos = 0 465 connectionClosed.WithLabelValues(c.connectionType.String()).Inc() 466 c.Logger().Warn("packet size exceeds the limit, the connection will be closed", zap.Int("packetSize", packetSize), zap.Int("bufferSize", len(c.readBuffer))) 467 return nil, errors.New("packetSize too large") 468 } 469 470 fullSize := PacketHeaderSize + packetSize 471 472 if c.readPos < *bufPos+fullSize { 473 // Unfinished packet 474 475 fragmentedPacketCount.WithLabelValues(c.connectionType.String()).Inc() 476 // this is a normal case, turn off the logs 477 //c.Logger().Info("read part of package", zap.Int("readpos", c.readPos), zap.Int("full size", fullSize)) 478 return nil, nil 479 } 480 481 bytes := c.readBuffer[*bufPos+PacketHeaderSize : *bufPos+fullSize] 482 483 bytesReceived.WithLabelValues(c.connectionType.String()).Add(float64(fullSize)) 484 485 // Apply the decompression from the 5th byte in the header 486 ct := tag[4] 487 _, valid := channeldpb.CompressionType_name[int32(ct)] 488 if valid && ct != 0 { 489 c.compressionType = channeldpb.CompressionType(ct) 490 if c.compressionType == channeldpb.CompressionType_SNAPPY { 491 len, err := snappy.DecodedLen(bytes) 492 if err != nil { 493 c.Logger().Error("snappy.DecodedLen", zap.Error(err)) 494 return nil, err 495 496 } 497 dst := make([]byte, len) 498 bytes, err = snappy.Decode(dst, bytes) 499 if err != nil { 500 c.Logger().Error("snappy.Decode", zap.Error(err)) 501 return nil, err 502 503 } 504 } 505 } 506 507 var p channeldpb.Packet 508 if err := proto.Unmarshal(bytes, &p); err != nil { 509 c.Logger().Error("failed to unmarshall packet, the connection will be closed", zap.Error(err), 510 zap.Uint32("size", uint32(packetSize)), 511 zap.Binary("tag", tag), 512 ) 513 //if c.connectionType == channeldpb.ConnectionType_CLIENT { 514 connectionClosed.WithLabelValues(c.connectionType.String()).Inc() 515 return nil, nil 516 } 517 518 packetReceived.WithLabelValues(c.connectionType.String()).Inc() 519 520 if c.isPacketRecordingEnabled() { 521 c.recordPacket(&p) 522 } 523 524 for _, mp := range p.Messages { 525 c.receiveMessage(mp) 526 } 527 528 *bufPos += fullSize 529 return &p, nil 530 } 531 532 func (c *Connection) isPacketRecordingEnabled() bool { 533 return c.connectionType == channeldpb.ConnectionType_CLIENT && GlobalSettings.EnableRecordPacket 534 } 535 536 func (c *Connection) receiveMessage(mp *channeldpb.MessagePack) { 537 channel := GetChannel(common.ChannelId(mp.ChannelId)) 538 if channel == nil { 539 // Sub to/unsub from a removed channel is allowed 540 if mp.MsgType != uint32(channeldpb.MessageType_SUB_TO_CHANNEL) && mp.MsgType != uint32(channeldpb.MessageType_UNSUB_FROM_CHANNEL) { 541 c.Logger().Warn("can't find channel", 542 zap.Uint32("channelId", mp.ChannelId), 543 zap.Uint32("msgType", mp.MsgType), 544 ) 545 } 546 return 547 } 548 549 entry := MessageMap[channeldpb.MessageType(mp.MsgType)] 550 if entry == nil && mp.MsgType < uint32(channeldpb.MessageType_USER_SPACE_START) { 551 c.Logger().Error("undefined message type", zap.Uint32("msgType", mp.MsgType)) 552 return 553 } 554 555 if !c.fsm.IsAllowed(mp.MsgType) { 556 Event_FsmDisallowed.Broadcast(c) 557 c.Logger().Warn("message is not allowed for current state", 558 zap.Uint32("msgType", mp.MsgType), 559 zap.String("connState", c.fsm.CurrentState().Name), 560 ) 561 return 562 } 563 564 var msg common.Message 565 var handler MessageHandlerFunc 566 if mp.MsgType >= uint32(channeldpb.MessageType_USER_SPACE_START) && entry == nil { 567 // client -> channeld -> server 568 if c.connectionType == channeldpb.ConnectionType_CLIENT { 569 // User-space message without handler won't be deserialized. 570 msg = &channeldpb.ServerForwardMessage{ClientConnId: uint32(c.id), Payload: mp.MsgBody} 571 handler = handleClientToServerUserMessage 572 } else { 573 // server -> channeld -> client/server 574 msg = &channeldpb.ServerForwardMessage{} 575 err := proto.Unmarshal(mp.MsgBody, msg) 576 if err != nil { 577 c.Logger().Error("unmarshalling ServerForwardMessage", zap.Error(err)) 578 return 579 } 580 handler = HandleServerToClientUserMessage 581 } 582 } else { 583 handler = entry.handler 584 // Always make a clone! 585 msg = proto.Clone(entry.msg) 586 err := proto.Unmarshal(mp.MsgBody, msg) 587 if err != nil { 588 c.Logger().Error("unmarshalling message", zap.Error(err)) 589 return 590 } 591 } 592 593 c.fsm.OnReceived(mp.MsgType) 594 595 channel.PutMessage(msg, handler, c, mp) 596 597 c.Logger().VeryVerbose("received message", zap.Uint32("msgType", mp.MsgType), zap.Int("size", len(mp.MsgBody))) 598 //c.Logger().Debug("received message", zap.Uint32("msgType", mp.MsgType), zap.Int("size", len(mp.MsgBody))) 599 600 msgReceived.WithLabelValues(c.connectionType.String()).Inc() /*.WithLabelValues( 601 strconv.FormatUint(uint64(p.ChannelId), 10), 602 strconv.FormatUint(uint64(p.MsgType), 10), 603 )*/ 604 } 605 606 func (c *Connection) Send(ctx MessageContext) { 607 if c.IsClosing() { 608 return 609 } 610 611 c.sender.Send(c, ctx) 612 } 613 614 // Should NOT be called outside the flush goroutine! 615 func (c *Connection) flush() { 616 if len(c.sendQueue) == 0 { 617 return 618 } 619 620 p := channeldpb.Packet{Messages: make([]*channeldpb.MessagePack, 0, len(c.sendQueue))} 621 size := 0 622 623 // Add the oversided message pack first if any 624 if c.oversizedMsgPack != nil { 625 p.Messages = append(p.Messages, c.oversizedMsgPack) 626 c.oversizedMsgPack = nil 627 // No need to check the packet size now, as each message pack is already checked before adding to the queue. 628 } 629 630 // For now we don't limit the message numbers per packet 631 for len(c.sendQueue) > 0 { 632 mp := <-c.sendQueue 633 p.Messages = append(p.Messages, mp) 634 size = proto.Size(&p) 635 if size > MaxPacketSize { 636 c.Logger().Info("packet is going to be oversized", 637 zap.Int("packetSize", size), 638 zap.Uint32("msgType", uint32(mp.MsgType)), 639 zap.Int("msgSize", len(mp.MsgBody)), 640 zap.Int("msgNum", len(p.Messages)), 641 zap.Int("msgInQueue", len(c.sendQueue)), 642 ) 643 644 // Revert adding the message that causes the oversize 645 p.Messages = p.Messages[:len(p.Messages)-1] 646 647 // Store the message pack that causes the overside 648 c.oversizedMsgPack = mp 649 break 650 } 651 652 c.Logger().VeryVerbose("sent message", zap.Uint32("msgType", uint32(mp.MsgType)), zap.Int("size", len(mp.MsgBody))) 653 654 msgSent.WithLabelValues(c.connectionType.String()).Inc() /*.WithLabelValues( 655 strconv.FormatUint(uint64(e.Channel.id), 10), 656 strconv.FormatUint(uint64(e.MsgType), 10), 657 )*/ 658 } 659 660 bytes, err := proto.Marshal(&p) 661 if err != nil { 662 c.Logger().Error("failed to marshal packet", zap.Error(err)) 663 return 664 } 665 666 // Apply the compression 667 if c.compressionType == channeldpb.CompressionType_SNAPPY { 668 dst := make([]byte, snappy.MaxEncodedLen(len(bytes))) 669 bytes = snappy.Encode(dst, bytes) 670 } 671 672 // 'CHNL' in ASCII 673 tag := []byte{67, 72, 78, 76, byte(c.compressionType)} 674 len := len(bytes) 675 tag[3] = byte(len & 0xff) 676 tag[2] = byte((len >> 8) & 0xff) 677 if len > MaxPacketSize { 678 // Should never happen, but log it just in case 679 c.Logger().Error("packet is oversized", zap.Int("size", len)) 680 return 681 } 682 683 /* Avoid writing multple times. With WebSocket, every Write() sends a message. 684 writer.Write(tag) 685 */ 686 bytes = append(tag, bytes...) 687 /* 688 _, err = c.writer.Write(bytes) 689 if err != nil { 690 c.Logger().Error("error writing packet", zap.Error(err)) 691 return 692 } 693 694 c.writer.Flush() 695 */ 696 len, err = c.conn.Write(bytes) 697 if err != nil { 698 c.Logger().Error("error writing packet", zap.Error(err)) 699 } 700 701 packetSent.WithLabelValues(c.connectionType.String()).Inc() 702 bytesSent.WithLabelValues(c.connectionType.String()).Add(float64(len)) 703 } 704 705 func (c *Connection) Disconnect() error { 706 return c.conn.Close() 707 } 708 709 func (c *Connection) Id() ConnectionId { 710 return c.id 711 } 712 713 func (c *Connection) GetConnectionType() channeldpb.ConnectionType { 714 return c.connectionType 715 } 716 717 func (c *Connection) OnAuthenticated(pit string) { 718 if c.IsClosing() { 719 return 720 } 721 722 atomic.StoreInt32(&c.state, ConnectionState_AUTHENTICATED) 723 724 unauthenticatedConnections.Delete(c.id) 725 726 c.pit = pit 727 728 if !c.fsm.MoveToNextState() { 729 c.Logger().Error("no state found after the authenticated state") 730 } 731 } 732 733 func (c *Connection) String() string { 734 return fmt.Sprintf("Connection(%s %d %s)", c.connectionType, c.id, c.fsm.CurrentState().Name) 735 } 736 737 func (c *Connection) Logger() *Logger { 738 return c.logger 739 } 740 741 func (c *Connection) RemoteAddr() net.Addr { 742 /* The address should still be available even after the connection is closed. 743 * In this way, the anit-DDoS can save the address to the blacklist. 744 if c.IsClosing() { 745 return nil 746 } 747 */ 748 return c.conn.RemoteAddr() 749 } 750 751 func (c *Connection) recordPacket(p *channeldpb.Packet) { 752 753 recordedPacket := &channeldpb.Packet{ 754 Messages: make([]*channeldpb.MessagePack, 0, len(p.Messages)), 755 } 756 proto.Merge(recordedPacket, p) 757 758 c.replaySession.Packets = append(c.replaySession.Packets, &replaypb.ReplayPacket{ 759 OffsetTime: time.Now().UnixNano(), 760 Packet: recordedPacket, 761 }) 762 } 763 764 func (c *Connection) persistReplaySession() { 765 766 var prevPacketTime int64 767 if len(c.replaySession.Packets) > 0 { 768 prevPacketTime = c.replaySession.Packets[0].OffsetTime 769 } else { 770 c.Logger().Error("replay session is empty") 771 return 772 } 773 774 for _, packet := range c.replaySession.Packets { 775 t := packet.OffsetTime 776 packet.OffsetTime -= prevPacketTime 777 prevPacketTime = t 778 } 779 780 data, err := proto.Marshal(c.replaySession) 781 if err != nil { 782 c.Logger().Error("failed to marshal replay session", zap.Error(err)) 783 return 784 } 785 786 var dir string 787 if GlobalSettings.ReplaySessionPersistenceDir != "" { 788 dir = GlobalSettings.ReplaySessionPersistenceDir 789 } else { 790 dir = "replays" 791 } 792 793 _, err = os.Stat(dir) 794 if err == nil || !os.IsExist(err) { 795 os.MkdirAll(dir, 0777) 796 } 797 798 path := filepath.Join(dir, fmt.Sprintf("session_%d_%s.cpr", c.id, time.Now().Local().Format("06-01-02_15-04-03"))) 799 err = os.WriteFile(path, data, 0777) 800 if err != nil { 801 c.Logger().Error("failed to write replay session to location", zap.Error(err)) 802 } 803 804 }