github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_session.go (about) 1 package service 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "net" 8 "net/netip" 9 "os" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/database64128/shadowsocks-go/conn" 15 "github.com/database64128/shadowsocks-go/router" 16 "github.com/database64128/shadowsocks-go/stats" 17 "github.com/database64128/shadowsocks-go/zerocopy" 18 "go.uber.org/zap" 19 ) 20 21 // sessionQueuedPacket is the structure used by send channels to queue packets for sending. 22 type sessionQueuedPacket struct { 23 buf []byte 24 start int 25 length int 26 targetAddr conn.Addr 27 clientAddrPort netip.AddrPort 28 } 29 30 // sessionClientAddrInfo stores a session's client address information. 31 type sessionClientAddrInfo struct { 32 addrPort netip.AddrPort 33 pktinfo []byte 34 } 35 36 // session keeps track of a UDP session. 37 type session struct { 38 // state synchronizes session initialization and shutdown. 39 // 40 // - Swap the natConn in to signal initialization completion. 41 // - Swap the serverConn in to signal shutdown. 42 // 43 // Callers must check the swapped-out value to determine the next action. 44 // 45 // - During initialization, if the swapped-out value is non-nil, 46 // initialization must not proceed. 47 // - During shutdown, if the swapped-out value is nil, preceed to the next entry. 48 state atomic.Pointer[net.UDPConn] 49 clientAddrInfo atomic.Pointer[sessionClientAddrInfo] 50 clientAddrPortCache netip.AddrPort 51 clientPktinfoCache []byte 52 natConnSendCh chan<- *sessionQueuedPacket 53 serverConn *net.UDPConn 54 serverConnUnpacker zerocopy.ServerUnpacker 55 username string 56 logger *zap.Logger 57 } 58 59 // sessionUplinkGeneric is used for passing information about relay uplink to the relay goroutine. 60 type sessionUplinkGeneric struct { 61 csid uint64 62 clientName string 63 natConn *net.UDPConn 64 natConnSendCh <-chan *sessionQueuedPacket 65 natConnPacker zerocopy.ClientPacker 66 natTimeout time.Duration 67 username string 68 logger *zap.Logger 69 } 70 71 // sessionDownlinkGeneric is used for passing information about relay downlink to the relay goroutine. 72 type sessionDownlinkGeneric struct { 73 csid uint64 74 clientName string 75 clientAddrInfop *sessionClientAddrInfo 76 clientAddrInfo *atomic.Pointer[sessionClientAddrInfo] 77 natConn *net.UDPConn 78 natConnRecvBufSize int 79 natConnUnpacker zerocopy.ClientUnpacker 80 serverConn *net.UDPConn 81 serverConnPacker zerocopy.ServerPacker 82 username string 83 logger *zap.Logger 84 } 85 86 // UDPSessionRelay is a session-based UDP relay service. 87 // 88 // Incoming UDP packets are dispatched to NAT sessions based on the client session ID. 89 type UDPSessionRelay struct { 90 serverName string 91 serverIndex int 92 mtu int 93 packetBufFrontHeadroom int 94 packetBufRecvSize int 95 listeners []udpRelayServerConn 96 server zerocopy.UDPSessionServer 97 collector stats.Collector 98 router *router.Router 99 logger *zap.Logger 100 queuedPacketPool sync.Pool 101 wg sync.WaitGroup 102 mwg sync.WaitGroup 103 table map[uint64]*session 104 } 105 106 func NewUDPSessionRelay( 107 serverName string, 108 serverIndex, mtu, packetBufFrontHeadroom, packetBufRecvSize, packetBufSize int, 109 listeners []udpRelayServerConn, 110 server zerocopy.UDPSessionServer, 111 collector stats.Collector, 112 router *router.Router, 113 logger *zap.Logger, 114 ) *UDPSessionRelay { 115 return &UDPSessionRelay{ 116 serverName: serverName, 117 serverIndex: serverIndex, 118 mtu: mtu, 119 packetBufFrontHeadroom: packetBufFrontHeadroom, 120 packetBufRecvSize: packetBufRecvSize, 121 listeners: listeners, 122 server: server, 123 collector: collector, 124 router: router, 125 logger: logger, 126 queuedPacketPool: sync.Pool{ 127 New: func() any { 128 return &sessionQueuedPacket{ 129 buf: make([]byte, packetBufSize), 130 } 131 }, 132 }, 133 table: make(map[uint64]*session), 134 } 135 } 136 137 // String implements the Service String method. 138 func (s *UDPSessionRelay) String() string { 139 return "UDP session relay service for " + s.serverName 140 } 141 142 // Start implements the Service Start method. 143 func (s *UDPSessionRelay) Start(ctx context.Context) error { 144 for i := range s.listeners { 145 if err := s.start(ctx, i, &s.listeners[i]); err != nil { 146 return err 147 } 148 } 149 return nil 150 } 151 152 func (s *UDPSessionRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) { 153 lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address) 154 if err != nil { 155 return 156 } 157 lnc.address = lnc.serverConn.LocalAddr().String() 158 lnc.logger = s.logger.With( 159 zap.String("server", s.serverName), 160 zap.Int("listener", index), 161 zap.String("listenAddress", lnc.address), 162 ) 163 164 s.mwg.Add(1) 165 166 go func() { 167 s.recvFromServerConnGeneric(ctx, lnc) 168 s.mwg.Done() 169 }() 170 171 lnc.logger.Info("Started UDP session relay service listener") 172 return 173 } 174 175 func (s *UDPSessionRelay) recvFromServerConnGeneric(ctx context.Context, lnc *udpRelayServerConn) { 176 cmsgBuf := make([]byte, conn.SocketControlMessageBufferSize) 177 178 var ( 179 n int 180 cmsgn int 181 flags int 182 err error 183 packetsReceived uint64 184 payloadBytesReceived uint64 185 ) 186 187 for { 188 queuedPacket := s.getQueuedPacket() 189 recvBuf := queuedPacket.buf[s.packetBufFrontHeadroom : s.packetBufFrontHeadroom+s.packetBufRecvSize] 190 191 n, cmsgn, flags, queuedPacket.clientAddrPort, err = lnc.serverConn.ReadMsgUDPAddrPort(recvBuf, cmsgBuf) 192 if err != nil { 193 if errors.Is(err, os.ErrDeadlineExceeded) { 194 s.putQueuedPacket(queuedPacket) 195 break 196 } 197 198 lnc.logger.Warn("Failed to read packet from serverConn", 199 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 200 zap.Int("packetLength", n), 201 zap.Error(err), 202 ) 203 204 s.putQueuedPacket(queuedPacket) 205 continue 206 } 207 err = conn.ParseFlagsForError(flags) 208 if err != nil { 209 lnc.logger.Warn("Failed to read packet from serverConn", 210 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 211 zap.Int("packetLength", n), 212 zap.Error(err), 213 ) 214 215 s.putQueuedPacket(queuedPacket) 216 continue 217 } 218 219 packet := recvBuf[:n] 220 221 csid, err := s.server.SessionInfo(packet) 222 if err != nil { 223 lnc.logger.Warn("Failed to extract session info from packet", 224 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 225 zap.Int("packetLength", n), 226 zap.Error(err), 227 ) 228 229 s.putQueuedPacket(queuedPacket) 230 continue 231 } 232 233 s.server.Lock() 234 235 entry, ok := s.table[csid] 236 if !ok { 237 entry = &session{ 238 serverConn: lnc.serverConn, 239 logger: lnc.logger, 240 } 241 242 entry.serverConnUnpacker, entry.username, err = s.server.NewUnpacker(packet, csid) 243 if err != nil { 244 lnc.logger.Warn("Failed to create unpacker for client session", 245 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 246 zap.Uint64("clientSessionID", csid), 247 zap.Int("packetLength", n), 248 zap.Error(err), 249 ) 250 251 s.putQueuedPacket(queuedPacket) 252 s.server.Unlock() 253 continue 254 } 255 } 256 257 queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length, err = entry.serverConnUnpacker.UnpackInPlace(queuedPacket.buf, queuedPacket.clientAddrPort, s.packetBufFrontHeadroom, n) 258 if err != nil { 259 lnc.logger.Warn("Failed to unpack packet", 260 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 261 zap.String("username", entry.username), 262 zap.Uint64("clientSessionID", csid), 263 zap.Int("packetLength", n), 264 zap.Error(err), 265 ) 266 267 s.putQueuedPacket(queuedPacket) 268 s.server.Unlock() 269 continue 270 } 271 272 packetsReceived++ 273 payloadBytesReceived += uint64(queuedPacket.length) 274 275 var clientAddrInfop *sessionClientAddrInfo 276 cmsg := cmsgBuf[:cmsgn] 277 278 updateClientAddrPort := entry.clientAddrPortCache != queuedPacket.clientAddrPort 279 updateClientPktinfo := !bytes.Equal(entry.clientPktinfoCache, cmsg) 280 281 if updateClientAddrPort { 282 entry.clientAddrPortCache = queuedPacket.clientAddrPort 283 } 284 285 if updateClientPktinfo { 286 entry.clientPktinfoCache = make([]byte, len(cmsg)) 287 copy(entry.clientPktinfoCache, cmsg) 288 } 289 290 if updateClientAddrPort || updateClientPktinfo { 291 clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) 292 if err != nil { 293 lnc.logger.Warn("Failed to parse pktinfo control message from serverConn", 294 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 295 zap.String("username", entry.username), 296 zap.Uint64("clientSessionID", csid), 297 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 298 zap.Error(err), 299 ) 300 301 s.putQueuedPacket(queuedPacket) 302 s.server.Unlock() 303 continue 304 } 305 306 clientAddrInfop = &sessionClientAddrInfo{entry.clientAddrPortCache, entry.clientPktinfoCache} 307 entry.clientAddrInfo.Store(clientAddrInfop) 308 309 if ce := lnc.logger.Check(zap.DebugLevel, "Updated client address info"); ce != nil { 310 ce.Write( 311 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 312 zap.String("username", entry.username), 313 zap.Uint64("clientSessionID", csid), 314 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 315 zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), 316 zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), 317 ) 318 } 319 } 320 321 if !ok { 322 natConnSendCh := make(chan *sessionQueuedPacket, lnc.sendChannelCapacity) 323 entry.natConnSendCh = natConnSendCh 324 s.table[csid] = entry 325 s.wg.Add(1) 326 327 go func() { 328 var sendChClean bool 329 330 defer func() { 331 s.server.Lock() 332 close(natConnSendCh) 333 delete(s.table, csid) 334 s.server.Unlock() 335 336 if !sendChClean { 337 for queuedPacket := range natConnSendCh { 338 s.putQueuedPacket(queuedPacket) 339 } 340 } 341 342 s.wg.Done() 343 }() 344 345 c, err := s.router.GetUDPClient(ctx, router.RequestInfo{ 346 ServerIndex: s.serverIndex, 347 Username: entry.username, 348 SourceAddrPort: queuedPacket.clientAddrPort, 349 TargetAddr: queuedPacket.targetAddr, 350 }) 351 if err != nil { 352 lnc.logger.Warn("Failed to get UDP client for new NAT session", 353 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 354 zap.String("username", entry.username), 355 zap.Uint64("clientSessionID", csid), 356 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 357 zap.Error(err), 358 ) 359 return 360 } 361 362 clientInfo, clientSession, err := c.NewSession(ctx) 363 if err != nil { 364 lnc.logger.Warn("Failed to create new UDP client session", 365 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 366 zap.String("username", entry.username), 367 zap.Uint64("clientSessionID", csid), 368 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 369 zap.String("client", clientInfo.Name), 370 zap.Error(err), 371 ) 372 return 373 } 374 375 natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") 376 if err != nil { 377 lnc.logger.Warn("Failed to create UDP socket for new NAT session", 378 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 379 zap.String("username", entry.username), 380 zap.Uint64("clientSessionID", csid), 381 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 382 zap.String("client", clientInfo.Name), 383 zap.Error(err), 384 ) 385 clientSession.Close() 386 return 387 } 388 389 err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout)) 390 if err != nil { 391 lnc.logger.Warn("Failed to set read deadline on natConn", 392 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 393 zap.String("username", entry.username), 394 zap.Uint64("clientSessionID", csid), 395 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 396 zap.String("client", clientInfo.Name), 397 zap.Duration("natTimeout", lnc.natTimeout), 398 zap.Error(err), 399 ) 400 natConn.Close() 401 clientSession.Close() 402 return 403 } 404 405 serverConnPacker, err := entry.serverConnUnpacker.NewPacker() 406 if err != nil { 407 lnc.logger.Warn("Failed to create packer for client session", 408 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 409 zap.String("username", entry.username), 410 zap.Uint64("clientSessionID", csid), 411 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 412 zap.Error(err), 413 ) 414 natConn.Close() 415 clientSession.Close() 416 return 417 } 418 419 oldState := entry.state.Swap(natConn) 420 if oldState != nil { 421 natConn.Close() 422 clientSession.Close() 423 return 424 } 425 426 // No more early returns! 427 sendChClean = true 428 429 lnc.logger.Info("UDP session relay started", 430 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 431 zap.String("username", entry.username), 432 zap.Uint64("clientSessionID", csid), 433 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 434 zap.String("client", clientInfo.Name), 435 ) 436 437 s.wg.Add(1) 438 439 go func() { 440 s.relayServerConnToNatConnGeneric(ctx, sessionUplinkGeneric{ 441 csid: csid, 442 clientName: clientInfo.Name, 443 natConn: natConn, 444 natConnSendCh: natConnSendCh, 445 natConnPacker: clientSession.Packer, 446 natTimeout: lnc.natTimeout, 447 username: entry.username, 448 logger: lnc.logger, 449 }) 450 natConn.Close() 451 clientSession.Close() 452 s.wg.Done() 453 }() 454 455 s.relayNatConnToServerConnGeneric(sessionDownlinkGeneric{ 456 csid: csid, 457 clientName: clientInfo.Name, 458 clientAddrInfop: clientAddrInfop, 459 clientAddrInfo: &entry.clientAddrInfo, 460 natConn: natConn, 461 natConnRecvBufSize: clientSession.MaxPacketSize, 462 natConnUnpacker: clientSession.Unpacker, 463 serverConn: lnc.serverConn, 464 serverConnPacker: serverConnPacker, 465 username: entry.username, 466 logger: lnc.logger, 467 }) 468 }() 469 470 if ce := lnc.logger.Check(zap.DebugLevel, "New UDP session"); ce != nil { 471 ce.Write( 472 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 473 zap.String("username", entry.username), 474 zap.Uint64("clientSessionID", csid), 475 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 476 ) 477 } 478 } 479 480 select { 481 case entry.natConnSendCh <- queuedPacket: 482 default: 483 if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil { 484 ce.Write( 485 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 486 zap.String("username", entry.username), 487 zap.Uint64("clientSessionID", csid), 488 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 489 ) 490 } 491 492 s.putQueuedPacket(queuedPacket) 493 } 494 495 s.server.Unlock() 496 } 497 498 lnc.logger.Info("Finished receiving from serverConn", 499 zap.Uint64("packetsReceived", packetsReceived), 500 zap.Uint64("payloadBytesReceived", payloadBytesReceived), 501 ) 502 } 503 504 func (s *UDPSessionRelay) relayServerConnToNatConnGeneric(ctx context.Context, uplink sessionUplinkGeneric) { 505 var ( 506 destAddrPort netip.AddrPort 507 packetStart int 508 packetLength int 509 err error 510 packetsSent uint64 511 payloadBytesSent uint64 512 ) 513 514 for queuedPacket := range uplink.natConnSendCh { 515 destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length) 516 if err != nil { 517 uplink.logger.Warn("Failed to pack packet", 518 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 519 zap.String("username", uplink.username), 520 zap.Uint64("clientSessionID", uplink.csid), 521 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 522 zap.String("client", uplink.clientName), 523 zap.Int("payloadLength", queuedPacket.length), 524 zap.Error(err), 525 ) 526 527 s.putQueuedPacket(queuedPacket) 528 continue 529 } 530 531 _, err = uplink.natConn.WriteToUDPAddrPort(queuedPacket.buf[packetStart:packetStart+packetLength], destAddrPort) 532 if err != nil { 533 uplink.logger.Warn("Failed to write packet to natConn", 534 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 535 zap.String("username", uplink.username), 536 zap.Uint64("clientSessionID", uplink.csid), 537 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 538 zap.String("client", uplink.clientName), 539 zap.Stringer("writeDestAddress", destAddrPort), 540 zap.Int("packetLength", packetLength), 541 zap.Error(err), 542 ) 543 } 544 545 err = uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)) 546 if err != nil { 547 uplink.logger.Warn("Failed to set read deadline on natConn", 548 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 549 zap.String("username", uplink.username), 550 zap.Uint64("clientSessionID", uplink.csid), 551 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 552 zap.String("client", uplink.clientName), 553 zap.Stringer("writeDestAddress", destAddrPort), 554 zap.Duration("natTimeout", uplink.natTimeout), 555 zap.Error(err), 556 ) 557 } 558 559 s.putQueuedPacket(queuedPacket) 560 packetsSent++ 561 payloadBytesSent += uint64(queuedPacket.length) 562 } 563 564 uplink.logger.Info("Finished relay serverConn -> natConn", 565 zap.String("username", uplink.username), 566 zap.Uint64("clientSessionID", uplink.csid), 567 zap.String("client", uplink.clientName), 568 zap.Stringer("lastWriteDestAddress", destAddrPort), 569 zap.Uint64("packetsSent", packetsSent), 570 zap.Uint64("payloadBytesSent", payloadBytesSent), 571 ) 572 573 s.collector.CollectUDPSessionUplink(uplink.username, packetsSent, payloadBytesSent) 574 } 575 576 func (s *UDPSessionRelay) relayNatConnToServerConnGeneric(downlink sessionDownlinkGeneric) { 577 clientAddrInfop := downlink.clientAddrInfop 578 clientAddrPort := clientAddrInfop.addrPort 579 clientPktinfo := clientAddrInfop.pktinfo 580 maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr()) 581 582 serverConnPackerInfo := downlink.serverConnPacker.ServerPackerInfo() 583 natConnUnpackerInfo := downlink.natConnUnpacker.ClientUnpackerInfo() 584 headroom := zerocopy.UDPRelayHeadroom(serverConnPackerInfo.Headroom, natConnUnpackerInfo.Headroom) 585 586 var ( 587 packetsSent uint64 588 payloadBytesSent uint64 589 ) 590 591 packetBuf := make([]byte, headroom.Front+downlink.natConnRecvBufSize+headroom.Rear) 592 recvBuf := packetBuf[headroom.Front : headroom.Front+downlink.natConnRecvBufSize] 593 594 for { 595 n, _, flags, packetSourceAddrPort, err := downlink.natConn.ReadMsgUDPAddrPort(recvBuf, nil) 596 if err != nil { 597 if errors.Is(err, os.ErrDeadlineExceeded) { 598 break 599 } 600 601 downlink.logger.Warn("Failed to read packet from natConn", 602 zap.Stringer("clientAddress", clientAddrPort), 603 zap.String("username", downlink.username), 604 zap.Uint64("clientSessionID", downlink.csid), 605 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 606 zap.String("client", downlink.clientName), 607 zap.Int("packetLength", n), 608 zap.Error(err), 609 ) 610 continue 611 } 612 err = conn.ParseFlagsForError(flags) 613 if err != nil { 614 downlink.logger.Warn("Failed to read packet from natConn", 615 zap.Stringer("clientAddress", clientAddrPort), 616 zap.String("username", downlink.username), 617 zap.Uint64("clientSessionID", downlink.csid), 618 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 619 zap.String("client", downlink.clientName), 620 zap.Int("packetLength", n), 621 zap.Error(err), 622 ) 623 continue 624 } 625 626 payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, headroom.Front, n) 627 if err != nil { 628 downlink.logger.Warn("Failed to unpack packet", 629 zap.Stringer("clientAddress", clientAddrPort), 630 zap.String("username", downlink.username), 631 zap.Uint64("clientSessionID", downlink.csid), 632 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 633 zap.String("client", downlink.clientName), 634 zap.Int("packetLength", n), 635 zap.Error(err), 636 ) 637 continue 638 } 639 640 if caip := downlink.clientAddrInfo.Load(); caip != clientAddrInfop { 641 clientAddrInfop = caip 642 clientAddrPort = caip.addrPort 643 clientPktinfo = caip.pktinfo 644 maxClientPacketSize = zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr()) 645 } 646 647 packetStart, packetLength, err := downlink.serverConnPacker.PackInPlace(packetBuf, payloadSourceAddrPort, payloadStart, payloadLength, maxClientPacketSize) 648 if err != nil { 649 downlink.logger.Warn("Failed to pack packet", 650 zap.Stringer("clientAddress", clientAddrPort), 651 zap.String("username", downlink.username), 652 zap.Uint64("clientSessionID", downlink.csid), 653 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 654 zap.String("client", downlink.clientName), 655 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 656 zap.Int("payloadLength", payloadLength), 657 zap.Int("maxClientPacketSize", maxClientPacketSize), 658 zap.Error(err), 659 ) 660 continue 661 } 662 663 _, _, err = downlink.serverConn.WriteMsgUDPAddrPort(packetBuf[packetStart:packetStart+packetLength], clientPktinfo, clientAddrPort) 664 if err != nil { 665 downlink.logger.Warn("Failed to write packet to serverConn", 666 zap.Stringer("clientAddress", clientAddrPort), 667 zap.String("username", downlink.username), 668 zap.Uint64("clientSessionID", downlink.csid), 669 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 670 zap.String("client", downlink.clientName), 671 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 672 zap.Int("packetLength", packetLength), 673 zap.Error(err), 674 ) 675 } 676 677 packetsSent++ 678 payloadBytesSent += uint64(payloadLength) 679 } 680 681 downlink.logger.Info("Finished relay serverConn <- natConn", 682 zap.Stringer("clientAddress", clientAddrPort), 683 zap.String("username", downlink.username), 684 zap.Uint64("clientSessionID", downlink.csid), 685 zap.Uint64("packetsSent", packetsSent), 686 zap.Uint64("payloadBytesSent", payloadBytesSent), 687 ) 688 689 s.collector.CollectUDPSessionDownlink(downlink.username, packetsSent, payloadBytesSent) 690 } 691 692 // getQueuedPacket retrieves a queued packet from the pool. 693 func (s *UDPSessionRelay) getQueuedPacket() *sessionQueuedPacket { 694 return s.queuedPacketPool.Get().(*sessionQueuedPacket) 695 } 696 697 // putQueuedPacket puts the queued packet back into the pool. 698 func (s *UDPSessionRelay) putQueuedPacket(queuedPacket *sessionQueuedPacket) { 699 s.queuedPacketPool.Put(queuedPacket) 700 } 701 702 // Stop implements the Service Stop method. 703 func (s *UDPSessionRelay) Stop() error { 704 for i := range s.listeners { 705 lnc := &s.listeners[i] 706 if err := lnc.serverConn.SetReadDeadline(conn.ALongTimeAgo); err != nil { 707 lnc.logger.Warn("Failed to set read deadline on serverConn", zap.Error(err)) 708 } 709 } 710 711 // Wait for serverConn receive goroutines to exit, 712 // so there won't be any new sessions added to the table. 713 s.mwg.Wait() 714 715 s.server.Lock() 716 for csid, entry := range s.table { 717 natConn := entry.state.Swap(entry.serverConn) 718 if natConn == nil { 719 continue 720 } 721 722 if err := natConn.SetReadDeadline(conn.ALongTimeAgo); err != nil { 723 entry.logger.Warn("Failed to set read deadline on natConn", 724 zap.Uint64("clientSessionID", csid), 725 zap.Error(err), 726 ) 727 } 728 } 729 s.server.Unlock() 730 731 // Wait for all relay goroutines to exit before closing serverConn, 732 // so in-flight packets can be written out. 733 s.wg.Wait() 734 735 for i := range s.listeners { 736 lnc := &s.listeners[i] 737 if err := lnc.serverConn.Close(); err != nil { 738 lnc.logger.Warn("Failed to close serverConn", zap.Error(err)) 739 } 740 } 741 742 return nil 743 }