// Source: github.com/database64128/shadowsocks-go@v1.7.0/service/udp_session.go

package service

import (
	"bytes"
	"errors"
	"net"
	"net/netip"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"github.com/database64128/shadowsocks-go/conn"
	"github.com/database64128/shadowsocks-go/router"
	"github.com/database64128/shadowsocks-go/stats"
	"github.com/database64128/shadowsocks-go/zerocopy"
	"github.com/database64128/tfo-go/v2"
	"go.uber.org/zap"
)

// sessionQueuedPacket is the structure used by send channels to queue packets for sending.
//
// Instances are recycled through the relay's queuedPacketPool, so a packet
// must not be touched after it has been handed back with putQueuedPacket.
type sessionQueuedPacket struct {
	// buf is the full packet buffer, including the headroom around the payload.
	buf []byte
	// start and length locate the unpacked payload inside buf.
	start  int
	length int
	// targetAddr is the destination address produced by unpacking the packet.
	targetAddr conn.Addr
	// clientAddrPort is the source address the packet was received from.
	clientAddrPort netip.AddrPort
}

// sessionClientAddrInfo stores a session's client address information.
//
// The receive loop builds it with an unkeyed composite literal,
// so the field order must not change.
type sessionClientAddrInfo struct {
	addrPort netip.AddrPort
	pktinfo  []byte
}

// session keeps track of a UDP session.
type session struct {
	// state synchronizes session initialization and shutdown.
	//
	//  - Swap the natConn in to signal initialization completion.
	//  - Swap the serverConn in to signal shutdown.
	//
	// Callers must check the swapped-out value to determine the next action.
	//
	//  - During initialization, if the swapped-out value is non-nil,
	//    initialization must not proceed.
	//  - During shutdown, if the swapped-out value is nil, proceed to the next entry.
	state atomic.Pointer[net.UDPConn]

	// clientAddrInfo holds the most recently published client address info.
	// It is stored by the serverConn receive loop and loaded by the downlink relay.
	clientAddrInfo atomic.Pointer[sessionClientAddrInfo]

	// clientAddrPortCache and clientPktinfoCache remember the last seen client
	// address and pktinfo control message, so the receive loop can detect changes.
	clientAddrPortCache netip.AddrPort
	clientPktinfoCache  []byte

	// natConnSendCh is the send side of the channel that feeds the uplink relay.
	natConnSendCh chan<- *sessionQueuedPacket

	// serverConnUnpacker and username are returned by the server's NewUnpacker
	// when the session is first seen.
	serverConnUnpacker zerocopy.SessionServerUnpacker
	username           string
}

// sessionUplinkGeneric is used for passing information about relay uplink to the relay goroutine.
58 type sessionUplinkGeneric struct { 59 csid uint64 60 natConn *net.UDPConn 61 natConnSendCh <-chan *sessionQueuedPacket 62 natConnPacker zerocopy.ClientPacker 63 username string 64 } 65 66 // sessionDownlinkGeneric is used for passing information about relay downlink to the relay goroutine. 67 type sessionDownlinkGeneric struct { 68 csid uint64 69 clientAddrInfop *sessionClientAddrInfo 70 clientAddrInfo *atomic.Pointer[sessionClientAddrInfo] 71 natConn *net.UDPConn 72 natConnRecvBufSize int 73 natConnUnpacker zerocopy.ClientUnpacker 74 serverConn *net.UDPConn 75 serverConnPacker zerocopy.ServerPacker 76 username string 77 } 78 79 // UDPSessionRelay is a session-based UDP relay service. 80 // 81 // Incoming UDP packets are dispatched to NAT sessions based on the client session ID. 82 type UDPSessionRelay struct { 83 serverName string 84 listenAddress string 85 serverIndex int 86 mtu int 87 packetBufFrontHeadroom int 88 packetBufRecvSize int 89 relayBatchSize int 90 serverRecvBatchSize int 91 sendChannelCapacity int 92 natTimeout time.Duration 93 server zerocopy.UDPSessionServer 94 serverConn *net.UDPConn 95 serverConnListenConfig tfo.ListenConfig 96 collector stats.Collector 97 router *router.Router 98 logger *zap.Logger 99 queuedPacketPool sync.Pool 100 wg sync.WaitGroup 101 mwg sync.WaitGroup 102 table map[uint64]*session 103 startFunc func() error 104 } 105 106 func NewUDPSessionRelay( 107 batchMode, serverName, listenAddress string, 108 relayBatchSize, serverRecvBatchSize, sendChannelCapacity, serverIndex, mtu int, 109 maxClientPackerHeadroom zerocopy.Headroom, 110 natTimeout time.Duration, 111 server zerocopy.UDPSessionServer, 112 serverConnListenConfig tfo.ListenConfig, 113 collector stats.Collector, 114 router *router.Router, 115 logger *zap.Logger, 116 ) *UDPSessionRelay { 117 serverInfo := server.Info() 118 packetBufHeadroom := zerocopy.UDPRelayHeadroom(maxClientPackerHeadroom, serverInfo.UnpackerHeadroom) 119 packetBufRecvSize := mtu - 
zerocopy.IPv4HeaderLength - zerocopy.UDPHeaderLength 120 packetBufSize := packetBufHeadroom.Front + packetBufRecvSize + packetBufHeadroom.Rear 121 s := UDPSessionRelay{ 122 serverName: serverName, 123 listenAddress: listenAddress, 124 serverIndex: serverIndex, 125 mtu: mtu, 126 packetBufFrontHeadroom: packetBufHeadroom.Front, 127 packetBufRecvSize: packetBufRecvSize, 128 relayBatchSize: relayBatchSize, 129 serverRecvBatchSize: serverRecvBatchSize, 130 sendChannelCapacity: sendChannelCapacity, 131 natTimeout: natTimeout, 132 server: server, 133 serverConnListenConfig: serverConnListenConfig, 134 collector: collector, 135 router: router, 136 logger: logger, 137 queuedPacketPool: sync.Pool{ 138 New: func() any { 139 return &sessionQueuedPacket{ 140 buf: make([]byte, packetBufSize), 141 } 142 }, 143 }, 144 table: make(map[uint64]*session), 145 } 146 s.setStartFunc(batchMode) 147 return &s 148 } 149 150 // String implements the Service String method. 151 func (s *UDPSessionRelay) String() string { 152 return "UDP session relay service for " + s.serverName 153 } 154 155 // Start implements the Service Start method. 
156 func (s *UDPSessionRelay) Start() error { 157 return s.startFunc() 158 } 159 160 func (s *UDPSessionRelay) startGeneric() error { 161 serverConn, err := conn.ListenUDP(s.serverConnListenConfig, "udp", s.listenAddress) 162 if err != nil { 163 return err 164 } 165 s.serverConn = serverConn 166 167 s.mwg.Add(1) 168 169 go func() { 170 s.recvFromServerConnGeneric(serverConn) 171 s.mwg.Done() 172 }() 173 174 s.logger.Info("Started UDP session relay service", 175 zap.String("server", s.serverName), 176 zap.String("listenAddress", s.listenAddress), 177 ) 178 179 return nil 180 } 181 182 func (s *UDPSessionRelay) recvFromServerConnGeneric(serverConn *net.UDPConn) { 183 cmsgBuf := make([]byte, conn.SocketControlMessageBufferSize) 184 185 var ( 186 n int 187 cmsgn int 188 flags int 189 err error 190 packetsReceived uint64 191 payloadBytesReceived uint64 192 ) 193 194 for { 195 queuedPacket := s.getQueuedPacket() 196 recvBuf := queuedPacket.buf[s.packetBufFrontHeadroom : s.packetBufFrontHeadroom+s.packetBufRecvSize] 197 198 n, cmsgn, flags, queuedPacket.clientAddrPort, err = serverConn.ReadMsgUDPAddrPort(recvBuf, cmsgBuf) 199 if err != nil { 200 if errors.Is(err, os.ErrDeadlineExceeded) { 201 s.putQueuedPacket(queuedPacket) 202 break 203 } 204 205 s.logger.Warn("Failed to read packet from serverConn", 206 zap.String("server", s.serverName), 207 zap.String("listenAddress", s.listenAddress), 208 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 209 zap.Int("packetLength", n), 210 zap.Error(err), 211 ) 212 213 s.putQueuedPacket(queuedPacket) 214 continue 215 } 216 err = conn.ParseFlagsForError(flags) 217 if err != nil { 218 s.logger.Warn("Failed to read packet from serverConn", 219 zap.String("server", s.serverName), 220 zap.String("listenAddress", s.listenAddress), 221 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 222 zap.Int("packetLength", n), 223 zap.Error(err), 224 ) 225 226 s.putQueuedPacket(queuedPacket) 227 continue 228 } 229 230 packet := 
recvBuf[:n] 231 232 csid, err := s.server.SessionInfo(packet) 233 if err != nil { 234 s.logger.Warn("Failed to extract session info from packet", 235 zap.String("server", s.serverName), 236 zap.String("listenAddress", s.listenAddress), 237 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 238 zap.Int("packetLength", n), 239 zap.Error(err), 240 ) 241 242 s.putQueuedPacket(queuedPacket) 243 continue 244 } 245 246 s.server.Lock() 247 248 entry, ok := s.table[csid] 249 if !ok { 250 entry = &session{} 251 252 entry.serverConnUnpacker, entry.username, err = s.server.NewUnpacker(packet, csid) 253 if err != nil { 254 s.logger.Warn("Failed to create unpacker for client session", 255 zap.String("server", s.serverName), 256 zap.String("listenAddress", s.listenAddress), 257 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 258 zap.Uint64("clientSessionID", csid), 259 zap.Int("packetLength", n), 260 zap.Error(err), 261 ) 262 263 s.putQueuedPacket(queuedPacket) 264 s.server.Unlock() 265 continue 266 } 267 } 268 269 queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length, err = entry.serverConnUnpacker.UnpackInPlace(queuedPacket.buf, queuedPacket.clientAddrPort, s.packetBufFrontHeadroom, n) 270 if err != nil { 271 s.logger.Warn("Failed to unpack packet", 272 zap.String("server", s.serverName), 273 zap.String("listenAddress", s.listenAddress), 274 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 275 zap.String("username", entry.username), 276 zap.Uint64("clientSessionID", csid), 277 zap.Int("packetLength", n), 278 zap.Error(err), 279 ) 280 281 s.putQueuedPacket(queuedPacket) 282 s.server.Unlock() 283 continue 284 } 285 286 packetsReceived++ 287 payloadBytesReceived += uint64(queuedPacket.length) 288 289 var clientAddrInfop *sessionClientAddrInfo 290 cmsg := cmsgBuf[:cmsgn] 291 292 updateClientAddrPort := entry.clientAddrPortCache != queuedPacket.clientAddrPort 293 updateClientPktinfo := !bytes.Equal(entry.clientPktinfoCache, cmsg) 294 295 
if updateClientAddrPort { 296 entry.clientAddrPortCache = queuedPacket.clientAddrPort 297 } 298 299 if updateClientPktinfo { 300 entry.clientPktinfoCache = make([]byte, len(cmsg)) 301 copy(entry.clientPktinfoCache, cmsg) 302 } 303 304 if updateClientAddrPort || updateClientPktinfo { 305 clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) 306 if err != nil { 307 s.logger.Warn("Failed to parse pktinfo control message from serverConn", 308 zap.String("server", s.serverName), 309 zap.String("listenAddress", s.listenAddress), 310 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 311 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 312 zap.String("username", entry.username), 313 zap.Uint64("clientSessionID", csid), 314 zap.Error(err), 315 ) 316 317 s.putQueuedPacket(queuedPacket) 318 s.server.Unlock() 319 continue 320 } 321 322 clientAddrInfop = &sessionClientAddrInfo{entry.clientAddrPortCache, entry.clientPktinfoCache} 323 entry.clientAddrInfo.Store(clientAddrInfop) 324 325 if ce := s.logger.Check(zap.DebugLevel, "Updated client address info"); ce != nil { 326 ce.Write( 327 zap.String("server", s.serverName), 328 zap.String("listenAddress", s.listenAddress), 329 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 330 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 331 zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), 332 zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), 333 zap.String("username", entry.username), 334 zap.Uint64("clientSessionID", csid), 335 ) 336 } 337 } 338 339 if !ok { 340 natConnSendCh := make(chan *sessionQueuedPacket, s.sendChannelCapacity) 341 entry.natConnSendCh = natConnSendCh 342 s.table[csid] = entry 343 344 go func() { 345 var sendChClean bool 346 347 defer func() { 348 s.server.Lock() 349 close(natConnSendCh) 350 delete(s.table, csid) 351 s.server.Unlock() 352 353 if !sendChClean { 354 for queuedPacket := range natConnSendCh { 355 s.putQueuedPacket(queuedPacket) 356 } 
357 } 358 }() 359 360 c, err := s.router.GetUDPClient(router.RequestInfo{ 361 ServerIndex: s.serverIndex, 362 Username: entry.username, 363 SourceAddrPort: queuedPacket.clientAddrPort, 364 TargetAddr: queuedPacket.targetAddr, 365 }) 366 if err != nil { 367 s.logger.Warn("Failed to get UDP client for new NAT session", 368 zap.String("server", s.serverName), 369 zap.String("listenAddress", s.listenAddress), 370 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 371 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 372 zap.String("username", entry.username), 373 zap.Uint64("clientSessionID", csid), 374 zap.Error(err), 375 ) 376 return 377 } 378 379 // Only add for the current goroutine here, since we don't want the router to block exiting. 380 s.wg.Add(1) 381 defer s.wg.Done() 382 383 clientInfo, natConnPacker, natConnUnpacker, err := c.NewSession() 384 if err != nil { 385 s.logger.Warn("Failed to create new UDP client session", 386 zap.String("server", s.serverName), 387 zap.String("client", clientInfo.Name), 388 zap.String("listenAddress", s.listenAddress), 389 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 390 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 391 zap.String("username", entry.username), 392 zap.Uint64("clientSessionID", csid), 393 zap.Error(err), 394 ) 395 return 396 } 397 398 serverConnPacker, err := entry.serverConnUnpacker.NewPacker() 399 if err != nil { 400 s.logger.Warn("Failed to create packer for client session", 401 zap.String("server", s.serverName), 402 zap.String("client", clientInfo.Name), 403 zap.String("listenAddress", s.listenAddress), 404 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 405 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 406 zap.String("username", entry.username), 407 zap.Uint64("clientSessionID", csid), 408 zap.Error(err), 409 ) 410 return 411 } 412 413 natConn, err := conn.ListenUDP(clientInfo.ListenConfig, "udp", "") 414 if err != nil { 415 
s.logger.Warn("Failed to create UDP socket for new NAT session", 416 zap.String("server", s.serverName), 417 zap.String("client", clientInfo.Name), 418 zap.String("listenAddress", s.listenAddress), 419 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 420 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 421 zap.String("username", entry.username), 422 zap.Uint64("clientSessionID", csid), 423 zap.Error(err), 424 ) 425 return 426 } 427 428 err = natConn.SetReadDeadline(time.Now().Add(s.natTimeout)) 429 if err != nil { 430 s.logger.Warn("Failed to set read deadline on natConn", 431 zap.String("server", s.serverName), 432 zap.String("client", clientInfo.Name), 433 zap.String("listenAddress", s.listenAddress), 434 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 435 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 436 zap.Duration("natTimeout", s.natTimeout), 437 zap.String("username", entry.username), 438 zap.Uint64("clientSessionID", csid), 439 zap.Error(err), 440 ) 441 natConn.Close() 442 return 443 } 444 445 oldState := entry.state.Swap(natConn) 446 if oldState != nil { 447 natConn.Close() 448 return 449 } 450 451 // No more early returns! 
452 sendChClean = true 453 454 s.logger.Info("UDP session relay started", 455 zap.String("server", s.serverName), 456 zap.String("client", clientInfo.Name), 457 zap.String("listenAddress", s.listenAddress), 458 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 459 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 460 zap.String("username", entry.username), 461 zap.Uint64("clientSessionID", csid), 462 ) 463 464 s.wg.Add(1) 465 466 go func() { 467 s.relayServerConnToNatConnGeneric(sessionUplinkGeneric{ 468 csid: csid, 469 natConn: natConn, 470 natConnSendCh: natConnSendCh, 471 natConnPacker: natConnPacker, 472 username: entry.username, 473 }) 474 natConn.Close() 475 s.wg.Done() 476 }() 477 478 s.relayNatConnToServerConnGeneric(sessionDownlinkGeneric{ 479 csid: csid, 480 clientAddrInfop: clientAddrInfop, 481 clientAddrInfo: &entry.clientAddrInfo, 482 natConn: natConn, 483 natConnRecvBufSize: clientInfo.MaxPacketSize, 484 natConnUnpacker: natConnUnpacker, 485 serverConn: serverConn, 486 serverConnPacker: serverConnPacker, 487 username: entry.username, 488 }) 489 }() 490 491 if ce := s.logger.Check(zap.DebugLevel, "New UDP session"); ce != nil { 492 ce.Write( 493 zap.String("server", s.serverName), 494 zap.String("listenAddress", s.listenAddress), 495 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 496 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 497 zap.String("username", entry.username), 498 zap.Uint64("clientSessionID", csid), 499 ) 500 } 501 } 502 503 select { 504 case entry.natConnSendCh <- queuedPacket: 505 default: 506 if ce := s.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil { 507 ce.Write( 508 zap.String("server", s.serverName), 509 zap.String("listenAddress", s.listenAddress), 510 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 511 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 512 zap.String("username", entry.username), 513 zap.Uint64("clientSessionID", 
csid), 514 ) 515 } 516 517 s.putQueuedPacket(queuedPacket) 518 } 519 520 s.server.Unlock() 521 } 522 523 s.logger.Info("Finished receiving from serverConn", 524 zap.String("server", s.serverName), 525 zap.String("listenAddress", s.listenAddress), 526 zap.Uint64("packetsReceived", packetsReceived), 527 zap.Uint64("payloadBytesReceived", payloadBytesReceived), 528 ) 529 } 530 531 func (s *UDPSessionRelay) relayServerConnToNatConnGeneric(uplink sessionUplinkGeneric) { 532 var ( 533 destAddrPort netip.AddrPort 534 packetStart int 535 packetLength int 536 err error 537 packetsSent uint64 538 payloadBytesSent uint64 539 ) 540 541 for queuedPacket := range uplink.natConnSendCh { 542 destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(queuedPacket.buf, queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length) 543 if err != nil { 544 s.logger.Warn("Failed to pack packet", 545 zap.String("server", s.serverName), 546 zap.String("listenAddress", s.listenAddress), 547 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 548 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 549 zap.String("username", uplink.username), 550 zap.Uint64("clientSessionID", uplink.csid), 551 zap.Int("payloadLength", queuedPacket.length), 552 zap.Error(err), 553 ) 554 555 s.putQueuedPacket(queuedPacket) 556 continue 557 } 558 559 _, err = uplink.natConn.WriteToUDPAddrPort(queuedPacket.buf[packetStart:packetStart+packetLength], destAddrPort) 560 if err != nil { 561 s.logger.Warn("Failed to write packet to natConn", 562 zap.String("server", s.serverName), 563 zap.String("listenAddress", s.listenAddress), 564 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 565 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 566 zap.Stringer("writeDestAddress", destAddrPort), 567 zap.String("username", uplink.username), 568 zap.Uint64("clientSessionID", uplink.csid), 569 zap.Error(err), 570 ) 571 } 572 573 err = 
uplink.natConn.SetReadDeadline(time.Now().Add(s.natTimeout)) 574 if err != nil { 575 s.logger.Warn("Failed to set read deadline on natConn", 576 zap.String("server", s.serverName), 577 zap.String("listenAddress", s.listenAddress), 578 zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), 579 zap.Duration("natTimeout", s.natTimeout), 580 zap.String("username", uplink.username), 581 zap.Uint64("clientSessionID", uplink.csid), 582 zap.Error(err), 583 ) 584 } 585 586 s.putQueuedPacket(queuedPacket) 587 packetsSent++ 588 payloadBytesSent += uint64(queuedPacket.length) 589 } 590 591 s.logger.Info("Finished relay serverConn -> natConn", 592 zap.String("server", s.serverName), 593 zap.String("listenAddress", s.listenAddress), 594 zap.Stringer("lastWriteDestAddress", destAddrPort), 595 zap.String("username", uplink.username), 596 zap.Uint64("clientSessionID", uplink.csid), 597 zap.Uint64("packetsSent", packetsSent), 598 zap.Uint64("payloadBytesSent", payloadBytesSent), 599 ) 600 601 s.collector.CollectUDPSessionUplink(uplink.username, packetsSent, payloadBytesSent) 602 } 603 604 func (s *UDPSessionRelay) relayNatConnToServerConnGeneric(downlink sessionDownlinkGeneric) { 605 clientAddrInfop := downlink.clientAddrInfop 606 clientAddrPort := clientAddrInfop.addrPort 607 clientPktinfo := clientAddrInfop.pktinfo 608 maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr()) 609 610 serverConnPackerInfo := downlink.serverConnPacker.ServerPackerInfo() 611 natConnUnpackerInfo := downlink.natConnUnpacker.ClientUnpackerInfo() 612 headroom := zerocopy.UDPRelayHeadroom(serverConnPackerInfo.Headroom, natConnUnpackerInfo.Headroom) 613 614 var ( 615 packetsSent uint64 616 payloadBytesSent uint64 617 ) 618 619 packetBuf := make([]byte, headroom.Front+downlink.natConnRecvBufSize+headroom.Rear) 620 recvBuf := packetBuf[headroom.Front : headroom.Front+downlink.natConnRecvBufSize] 621 622 for { 623 n, _, flags, packetSourceAddrPort, err := 
downlink.natConn.ReadMsgUDPAddrPort(recvBuf, nil) 624 if err != nil { 625 if errors.Is(err, os.ErrDeadlineExceeded) { 626 break 627 } 628 629 s.logger.Warn("Failed to read packet from natConn", 630 zap.String("server", s.serverName), 631 zap.String("listenAddress", s.listenAddress), 632 zap.Stringer("clientAddress", clientAddrPort), 633 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 634 zap.String("username", downlink.username), 635 zap.Uint64("clientSessionID", downlink.csid), 636 zap.Int("packetLength", n), 637 zap.Error(err), 638 ) 639 continue 640 } 641 err = conn.ParseFlagsForError(flags) 642 if err != nil { 643 s.logger.Warn("Failed to read packet from natConn", 644 zap.String("server", s.serverName), 645 zap.String("listenAddress", s.listenAddress), 646 zap.Stringer("clientAddress", clientAddrPort), 647 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 648 zap.String("username", downlink.username), 649 zap.Uint64("clientSessionID", downlink.csid), 650 zap.Int("packetLength", n), 651 zap.Error(err), 652 ) 653 continue 654 } 655 656 payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, headroom.Front, n) 657 if err != nil { 658 s.logger.Warn("Failed to unpack packet", 659 zap.String("server", s.serverName), 660 zap.String("listenAddress", s.listenAddress), 661 zap.Stringer("clientAddress", clientAddrPort), 662 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 663 zap.String("username", downlink.username), 664 zap.Uint64("clientSessionID", downlink.csid), 665 zap.Int("packetLength", n), 666 zap.Error(err), 667 ) 668 continue 669 } 670 671 if caip := downlink.clientAddrInfo.Load(); caip != clientAddrInfop { 672 clientAddrInfop = caip 673 clientAddrPort = caip.addrPort 674 clientPktinfo = caip.pktinfo 675 maxClientPacketSize = zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr()) 676 } 677 678 packetStart, packetLength, err := 
downlink.serverConnPacker.PackInPlace(packetBuf, payloadSourceAddrPort, payloadStart, payloadLength, maxClientPacketSize) 679 if err != nil { 680 s.logger.Warn("Failed to pack packet", 681 zap.String("server", s.serverName), 682 zap.String("listenAddress", s.listenAddress), 683 zap.Stringer("clientAddress", clientAddrPort), 684 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 685 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 686 zap.String("username", downlink.username), 687 zap.Uint64("clientSessionID", downlink.csid), 688 zap.Int("payloadLength", payloadLength), 689 zap.Int("maxClientPacketSize", maxClientPacketSize), 690 zap.Error(err), 691 ) 692 continue 693 } 694 695 _, _, err = downlink.serverConn.WriteMsgUDPAddrPort(packetBuf[packetStart:packetStart+packetLength], clientPktinfo, clientAddrPort) 696 if err != nil { 697 s.logger.Warn("Failed to write packet to serverConn", 698 zap.String("server", s.serverName), 699 zap.String("listenAddress", s.listenAddress), 700 zap.Stringer("clientAddress", clientAddrPort), 701 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 702 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 703 zap.String("username", downlink.username), 704 zap.Uint64("clientSessionID", downlink.csid), 705 zap.Error(err), 706 ) 707 } 708 709 packetsSent++ 710 payloadBytesSent += uint64(payloadLength) 711 } 712 713 s.logger.Info("Finished relay serverConn <- natConn", 714 zap.String("server", s.serverName), 715 zap.String("listenAddress", s.listenAddress), 716 zap.Stringer("clientAddress", clientAddrPort), 717 zap.String("username", downlink.username), 718 zap.Uint64("clientSessionID", downlink.csid), 719 zap.Uint64("packetsSent", packetsSent), 720 zap.Uint64("payloadBytesSent", payloadBytesSent), 721 ) 722 723 s.collector.CollectUDPSessionDownlink(downlink.username, packetsSent, payloadBytesSent) 724 } 725 726 // getQueuedPacket retrieves a queued packet from the pool. 
727 func (s *UDPSessionRelay) getQueuedPacket() *sessionQueuedPacket { 728 return s.queuedPacketPool.Get().(*sessionQueuedPacket) 729 } 730 731 // putQueuedPacket puts the queued packet back into the pool. 732 func (s *UDPSessionRelay) putQueuedPacket(queuedPacket *sessionQueuedPacket) { 733 s.queuedPacketPool.Put(queuedPacket) 734 } 735 736 // Stop implements the Service Stop method. 737 func (s *UDPSessionRelay) Stop() error { 738 if s.serverConn == nil { 739 return nil 740 } 741 742 now := time.Now() 743 744 if err := s.serverConn.SetReadDeadline(now); err != nil { 745 return err 746 } 747 748 // Wait for serverConn receive goroutines to exit, 749 // so there won't be any new sessions added to the table. 750 s.mwg.Wait() 751 752 s.server.Lock() 753 for csid, entry := range s.table { 754 natConn := entry.state.Swap(s.serverConn) 755 if natConn == nil { 756 continue 757 } 758 759 if err := natConn.SetReadDeadline(now); err != nil { 760 s.logger.Warn("Failed to set read deadline on natConn", 761 zap.String("server", s.serverName), 762 zap.String("listenAddress", s.listenAddress), 763 zap.Uint64("clientSessionID", csid), 764 zap.Error(err), 765 ) 766 } 767 } 768 s.server.Unlock() 769 770 // Wait for all relay goroutines to exit before closing serverConn, 771 // so in-flight packets can be written out. 772 s.wg.Wait() 773 774 return s.serverConn.Close() 775 }