github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_transparent_linux.go (about) 1 package service 2 3 import ( 4 "context" 5 "errors" 6 "net" 7 "net/netip" 8 "os" 9 "sync" 10 "sync/atomic" 11 "time" 12 "unsafe" 13 14 "github.com/database64128/shadowsocks-go/conn" 15 "github.com/database64128/shadowsocks-go/router" 16 "github.com/database64128/shadowsocks-go/stats" 17 "github.com/database64128/shadowsocks-go/zerocopy" 18 "go.uber.org/zap" 19 "golang.org/x/sys/unix" 20 ) 21 22 // transparentQueuedPacket is the structure used by send channels to queue packets for sending. 23 type transparentQueuedPacket struct { 24 buf []byte 25 targetAddrPort netip.AddrPort 26 msglen uint32 27 } 28 29 // transparentNATEntry is an entry in the tproxy NAT table. 30 type transparentNATEntry struct { 31 // state synchronizes session initialization and shutdown. 32 // 33 // - Swap the natConn in to signal initialization completion. 34 // - Swap the serverConn in to signal shutdown. 35 // 36 // Callers must check the swapped-out value to determine the next action. 37 // 38 // - During initialization, if the swapped-out value is non-nil, 39 // initialization must not proceed. 40 // - During shutdown, if the swapped-out value is nil, preceed to the next entry. 41 state atomic.Pointer[net.UDPConn] 42 natConnSendCh chan<- *transparentQueuedPacket 43 serverConn *net.UDPConn 44 logger *zap.Logger 45 } 46 47 // transparentUplink is used for passing information about relay uplink to the relay goroutine. 48 type transparentUplink struct { 49 clientName string 50 clientAddrPort netip.AddrPort 51 natConn *conn.MmsgWConn 52 natConnSendCh <-chan *transparentQueuedPacket 53 natConnPacker zerocopy.ClientPacker 54 natTimeout time.Duration 55 relayBatchSize int 56 logger *zap.Logger 57 } 58 59 // transparentDownlink is used for passing information about relay downlink to the relay goroutine. 60 type transparentDownlink struct { 61 clientName string 62 clientAddrPort netip.AddrPort 63 natConn *conn.MmsgRConn 64 natConnRecvBufSize int 65 natConnUnpacker zerocopy.ClientUnpacker 66 relayBatchSize int 67 logger *zap.Logger 68 } 69 70 // UDPTransparentRelay is like [UDPNATRelay], but for transparent proxy. 71 type UDPTransparentRelay struct { 72 serverName string 73 serverIndex int 74 mtu int 75 packetBufFrontHeadroom int 76 packetBufRecvSize int 77 listeners []udpRelayServerConn 78 transparentConnListenConfig conn.ListenConfig 79 collector stats.Collector 80 router *router.Router 81 logger *zap.Logger 82 queuedPacketPool sync.Pool 83 mu sync.Mutex 84 wg sync.WaitGroup 85 mwg sync.WaitGroup 86 table map[netip.AddrPort]*transparentNATEntry 87 } 88 89 func NewUDPTransparentRelay( 90 serverName string, 91 serverIndex, mtu, packetBufFrontHeadroom, packetBufRecvSize, packetBufSize int, 92 listeners []udpRelayServerConn, 93 transparentConnListenConfig conn.ListenConfig, 94 collector stats.Collector, 95 router *router.Router, 96 logger *zap.Logger, 97 ) (Relay, error) { 98 return &UDPTransparentRelay{ 99 serverName: serverName, 100 serverIndex: serverIndex, 101 mtu: mtu, 102 packetBufFrontHeadroom: packetBufFrontHeadroom, 103 packetBufRecvSize: packetBufRecvSize, 104 listeners: listeners, 105 transparentConnListenConfig: transparentConnListenConfig, 106 collector: collector, 107 router: router, 108 logger: logger, 109 queuedPacketPool: sync.Pool{ 110 New: func() any { 111 return &transparentQueuedPacket{ 112 buf: make([]byte, packetBufSize), 113 } 114 }, 115 }, 116 table: make(map[netip.AddrPort]*transparentNATEntry), 117 }, nil 118 } 119 120 // String implements the Relay String method. 121 func (s *UDPTransparentRelay) String() string { 122 return "UDP transparent relay service for " + s.serverName 123 } 124 125 // Start implements the Relay Start method. 126 func (s *UDPTransparentRelay) Start(ctx context.Context) error { 127 for i := range s.listeners { 128 index := i 129 lnc := &s.listeners[index] 130 131 serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) 132 if err != nil { 133 return err 134 } 135 lnc.serverConn = serverConn.UDPConn 136 lnc.address = serverConn.LocalAddr().String() 137 lnc.logger = s.logger.With( 138 zap.String("server", s.serverName), 139 zap.Int("listener", index), 140 zap.String("listenAddress", lnc.address), 141 ) 142 143 s.mwg.Add(1) 144 145 go func() { 146 s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.RConn()) 147 s.mwg.Done() 148 }() 149 150 lnc.logger.Info("Started UDP transparent relay service listener") 151 } 152 return nil 153 } 154 155 func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRelayServerConn, serverConn *conn.MmsgRConn) { 156 n := lnc.serverRecvBatchSize 157 qpvec := make([]*transparentQueuedPacket, n) 158 namevec := make([]unix.RawSockaddrInet6, n) 159 iovec := make([]unix.Iovec, n) 160 cmsgvec := make([][]byte, n) 161 msgvec := make([]conn.Mmsghdr, n) 162 163 for i := range msgvec { 164 cmsgBuf := make([]byte, conn.TransparentSocketControlMessageBufferSize) 165 cmsgvec[i] = cmsgBuf 166 msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i])) 167 msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6 168 msgvec[i].Msghdr.Iov = &iovec[i] 169 msgvec[i].Msghdr.SetIovlen(1) 170 msgvec[i].Msghdr.Control = unsafe.SliceData(cmsgBuf) 171 } 172 173 var ( 174 err error 175 recvmmsgCount uint64 176 packetsReceived uint64 177 payloadBytesReceived uint64 178 burstBatchSize int 179 ) 180 181 for { 182 for i := range iovec[:n] { 183 queuedPacket := s.getQueuedPacket() 184 qpvec[i] = queuedPacket 185 iovec[i].Base = &queuedPacket.buf[s.packetBufFrontHeadroom] 186 iovec[i].SetLen(s.packetBufRecvSize) 187 msgvec[i].Msghdr.SetControllen(conn.TransparentSocketControlMessageBufferSize) 188 } 189 190 n, err = serverConn.ReadMsgs(msgvec, 0) 191 if err != nil { 192 if errors.Is(err, os.ErrDeadlineExceeded) { 193 break 194 } 195 196 lnc.logger.Warn("Failed to batch read packets from serverConn", zap.Error(err)) 197 198 n = 1 199 s.putQueuedPacket(qpvec[0]) 200 continue 201 } 202 203 recvmmsgCount++ 204 packetsReceived += uint64(n) 205 burstBatchSize = max(burstBatchSize, n) 206 207 s.mu.Lock() 208 209 msgvecn := msgvec[:n] 210 211 for i := range msgvecn { 212 msg := &msgvecn[i] 213 queuedPacket := qpvec[i] 214 215 clientAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen) 216 if err != nil { 217 lnc.logger.Warn("Failed to parse sockaddr of packet from serverConn", zap.Error(err)) 218 s.putQueuedPacket(queuedPacket) 219 continue 220 } 221 222 if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil { 223 lnc.logger.Warn("Packet from serverConn discarded", 224 zap.Stringer("clientAddress", clientAddrPort), 225 zap.Uint32("packetLength", msg.Msglen), 226 zap.Error(err), 227 ) 228 229 s.putQueuedPacket(queuedPacket) 230 continue 231 } 232 233 queuedPacket.targetAddrPort, err = conn.ParseOrigDstAddrCmsg(cmsgvec[i][:msg.Msghdr.Controllen]) 234 if err != nil { 235 lnc.logger.Warn("Failed to parse original destination address control message from serverConn", 236 zap.Stringer("clientAddress", clientAddrPort), 237 zap.Error(err), 238 ) 239 s.putQueuedPacket(queuedPacket) 240 continue 241 } 242 243 queuedPacket.msglen = msg.Msglen 244 payloadBytesReceived += uint64(msg.Msglen) 245 246 entry := s.table[clientAddrPort] 247 if entry == nil { 248 natConnSendCh := make(chan *transparentQueuedPacket, lnc.sendChannelCapacity) 249 entry = &transparentNATEntry{ 250 natConnSendCh: natConnSendCh, 251 serverConn: lnc.serverConn, 252 logger: lnc.logger, 253 } 254 s.table[clientAddrPort] = entry 255 s.wg.Add(1) 256 257 go func() { 258 var sendChClean bool 259 260 defer func() { 261 s.mu.Lock() 262 close(natConnSendCh) 263 delete(s.table, clientAddrPort) 264 s.mu.Unlock() 265 266 if !sendChClean { 267 for queuedPacket := range natConnSendCh { 268 s.putQueuedPacket(queuedPacket) 269 } 270 } 271 272 s.wg.Done() 273 }() 274 275 c, err := s.router.GetUDPClient(ctx, router.RequestInfo{ 276 ServerIndex: s.serverIndex, 277 SourceAddrPort: clientAddrPort, 278 TargetAddr: conn.AddrFromIPPort(queuedPacket.targetAddrPort), 279 }) 280 if err != nil { 281 lnc.logger.Warn("Failed to get UDP client for new NAT session", 282 zap.Stringer("clientAddress", clientAddrPort), 283 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 284 zap.Error(err), 285 ) 286 return 287 } 288 289 clientInfo, clientSession, err := c.NewSession(ctx) 290 if err != nil { 291 lnc.logger.Warn("Failed to create new UDP client session", 292 zap.Stringer("clientAddress", clientAddrPort), 293 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 294 zap.String("client", clientInfo.Name), 295 zap.Error(err), 296 ) 297 return 298 } 299 300 natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") 301 if err != nil { 302 lnc.logger.Warn("Failed to create UDP socket for new NAT session", 303 zap.Stringer("clientAddress", clientAddrPort), 304 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 305 zap.String("client", clientInfo.Name), 306 zap.Error(err), 307 ) 308 clientSession.Close() 309 return 310 } 311 312 if err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout)); err != nil { 313 lnc.logger.Warn("Failed to set read deadline on natConn", 314 zap.Stringer("clientAddress", clientAddrPort), 315 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 316 zap.String("client", clientInfo.Name), 317 zap.Duration("natTimeout", lnc.natTimeout), 318 zap.Error(err), 319 ) 320 natConn.Close() 321 clientSession.Close() 322 return 323 } 324 325 oldState := entry.state.Swap(natConn.UDPConn) 326 if oldState != nil { 327 natConn.Close() 328 clientSession.Close() 329 return 330 } 331 332 // No more early returns! 333 sendChClean = true 334 335 lnc.logger.Info("UDP transparent relay started", 336 zap.Stringer("clientAddress", clientAddrPort), 337 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 338 zap.String("client", clientInfo.Name), 339 ) 340 341 s.wg.Add(1) 342 343 go func() { 344 s.relayServerConnToNatConnSendmmsg(ctx, transparentUplink{ 345 clientName: clientInfo.Name, 346 clientAddrPort: clientAddrPort, 347 natConn: natConn.WConn(), 348 natConnSendCh: natConnSendCh, 349 natConnPacker: clientSession.Packer, 350 natTimeout: lnc.natTimeout, 351 relayBatchSize: lnc.relayBatchSize, 352 logger: lnc.logger, 353 }) 354 natConn.Close() 355 clientSession.Close() 356 s.wg.Done() 357 }() 358 359 s.relayNatConnToTransparentConnSendmmsg(ctx, transparentDownlink{ 360 clientName: clientInfo.Name, 361 clientAddrPort: clientAddrPort, 362 natConn: natConn.RConn(), 363 natConnRecvBufSize: clientSession.MaxPacketSize, 364 natConnUnpacker: clientSession.Unpacker, 365 relayBatchSize: lnc.relayBatchSize, 366 }) 367 }() 368 369 if ce := lnc.logger.Check(zap.DebugLevel, "New UDP transparent session"); ce != nil { 370 ce.Write( 371 zap.Stringer("clientAddress", clientAddrPort), 372 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 373 ) 374 } 375 } 376 377 select { 378 case entry.natConnSendCh <- queuedPacket: 379 default: 380 if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil { 381 ce.Write( 382 zap.Stringer("clientAddress", clientAddrPort), 383 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 384 ) 385 } 386 387 s.putQueuedPacket(queuedPacket) 388 } 389 } 390 391 s.mu.Unlock() 392 } 393 394 for i := range qpvec { 395 s.putQueuedPacket(qpvec[i]) 396 } 397 398 lnc.logger.Info("Finished receiving from serverConn", 399 zap.Uint64("recvmmsgCount", recvmmsgCount), 400 zap.Uint64("packetsReceived", packetsReceived), 401 zap.Uint64("payloadBytesReceived", payloadBytesReceived), 402 zap.Int("burstBatchSize", burstBatchSize), 403 ) 404 } 405 406 func (s *UDPTransparentRelay) relayServerConnToNatConnSendmmsg(ctx context.Context, uplink transparentUplink) { 407 var ( 408 destAddrPort netip.AddrPort 409 packetStart int 410 packetLength int 411 err error 412 sendmmsgCount uint64 413 packetsSent uint64 414 payloadBytesSent uint64 415 burstBatchSize int 416 ) 417 418 qpvec := make([]*transparentQueuedPacket, uplink.relayBatchSize) 419 namevec := make([]unix.RawSockaddrInet6, uplink.relayBatchSize) 420 iovec := make([]unix.Iovec, uplink.relayBatchSize) 421 msgvec := make([]conn.Mmsghdr, uplink.relayBatchSize) 422 423 for i := range msgvec { 424 msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i])) 425 msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6 426 msgvec[i].Msghdr.Iov = &iovec[i] 427 msgvec[i].Msghdr.SetIovlen(1) 428 } 429 430 main: 431 for { 432 var count int 433 434 // Block on first dequeue op. 435 queuedPacket, ok := <-uplink.natConnSendCh 436 if !ok { 437 break 438 } 439 440 dequeue: 441 for { 442 destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, conn.AddrFromIPPort(queuedPacket.targetAddrPort), s.packetBufFrontHeadroom, int(queuedPacket.msglen)) 443 if err != nil { 444 uplink.logger.Warn("Failed to pack packet for natConn", 445 zap.Stringer("clientAddress", uplink.clientAddrPort), 446 zap.Stringer("targetAddress", &queuedPacket.targetAddrPort), 447 zap.String("client", uplink.clientName), 448 zap.Uint32("payloadLength", queuedPacket.msglen), 449 zap.Error(err), 450 ) 451 452 s.putQueuedPacket(queuedPacket) 453 454 if count == 0 { 455 continue main 456 } 457 goto next 458 } 459 460 qpvec[count] = queuedPacket 461 namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort) 462 iovec[count].Base = &queuedPacket.buf[packetStart] 463 iovec[count].SetLen(packetLength) 464 count++ 465 payloadBytesSent += uint64(queuedPacket.msglen) 466 467 if count == uplink.relayBatchSize { 468 break 469 } 470 471 next: 472 select { 473 case queuedPacket, ok = <-uplink.natConnSendCh: 474 if !ok { 475 break dequeue 476 } 477 default: 478 break dequeue 479 } 480 } 481 482 if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil { 483 uplink.logger.Warn("Failed to batch write packets to natConn", 484 zap.Stringer("clientAddress", uplink.clientAddrPort), 485 zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort), 486 zap.String("client", uplink.clientName), 487 zap.Stringer("lastWriteDestAddress", destAddrPort), 488 zap.Error(err), 489 ) 490 } 491 492 if err := uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)); err != nil { 493 uplink.logger.Warn("Failed to set read deadline on natConn", 494 zap.Stringer("clientAddress", uplink.clientAddrPort), 495 zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort), 496 zap.String("client", uplink.clientName), 497 zap.Stringer("lastWriteDestAddress", destAddrPort), 498 zap.Duration("natTimeout", uplink.natTimeout), 499 zap.Error(err), 500 ) 501 } 502 503 sendmmsgCount++ 504 packetsSent += uint64(count) 505 burstBatchSize = max(burstBatchSize, count) 506 507 qpvecn := qpvec[:count] 508 509 for i := range qpvecn { 510 s.putQueuedPacket(qpvecn[i]) 511 } 512 513 if !ok { 514 break 515 } 516 } 517 518 uplink.logger.Info("Finished relay serverConn -> natConn", 519 zap.Stringer("clientAddress", uplink.clientAddrPort), 520 zap.String("client", uplink.clientName), 521 zap.Stringer("lastWriteDestAddress", destAddrPort), 522 zap.Uint64("sendmmsgCount", sendmmsgCount), 523 zap.Uint64("packetsSent", packetsSent), 524 zap.Uint64("payloadBytesSent", payloadBytesSent), 525 zap.Int("burstBatchSize", burstBatchSize), 526 ) 527 528 s.collector.CollectUDPSessionUplink("", packetsSent, payloadBytesSent) 529 } 530 531 // getQueuedPacket retrieves a queued packet from the pool. 532 func (s *UDPTransparentRelay) getQueuedPacket() *transparentQueuedPacket { 533 return s.queuedPacketPool.Get().(*transparentQueuedPacket) 534 } 535 536 // putQueuedPacket puts the queued packet back into the pool. 537 func (s *UDPTransparentRelay) putQueuedPacket(queuedPacket *transparentQueuedPacket) { 538 s.queuedPacketPool.Put(queuedPacket) 539 } 540 541 type transparentConn struct { 542 mwc *conn.MmsgWConn 543 iovec []unix.Iovec 544 msgvec []conn.Mmsghdr 545 n int 546 } 547 548 func (s *UDPTransparentRelay) newTransparentConn(ctx context.Context, address string, relayBatchSize int, name *byte, namelen uint32) (*transparentConn, error) { 549 c, err := s.transparentConnListenConfig.ListenUDPRawConn(ctx, "udp", address) 550 if err != nil { 551 return nil, err 552 } 553 554 iovec := make([]unix.Iovec, relayBatchSize) 555 msgvec := make([]conn.Mmsghdr, relayBatchSize) 556 557 for i := range msgvec { 558 msgvec[i].Msghdr.Name = name 559 msgvec[i].Msghdr.Namelen = namelen 560 msgvec[i].Msghdr.Iov = &iovec[i] 561 msgvec[i].Msghdr.SetIovlen(1) 562 } 563 564 return &transparentConn{ 565 mwc: c.WConn(), 566 iovec: iovec, 567 msgvec: msgvec, 568 }, nil 569 } 570 571 func (tc *transparentConn) putMsg(base *byte, length int) { 572 tc.iovec[tc.n].Base = base 573 tc.iovec[tc.n].SetLen(length) 574 tc.n++ 575 } 576 577 func (tc *transparentConn) writeMsgvec() (sendmmsgCount, packetsSent int, err error) { 578 if tc.n == 0 { 579 return 580 } 581 packetsSent = tc.n 582 tc.n = 0 583 return 1, packetsSent, tc.mwc.WriteMsgs(tc.msgvec[:packetsSent], 0) 584 } 585 586 func (tc *transparentConn) close() error { 587 return tc.mwc.Close() 588 } 589 590 func (s *UDPTransparentRelay) relayNatConnToTransparentConnSendmmsg(ctx context.Context, downlink transparentDownlink) { 591 var ( 592 sendmmsgCount uint64 593 packetsSent uint64 594 payloadBytesSent uint64 595 burstBatchSize int 596 ) 597 598 maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, downlink.clientAddrPort.Addr()) 599 name, namelen := conn.AddrPortUnmappedToSockaddr(downlink.clientAddrPort) 600 tcMap := make(map[netip.AddrPort]*transparentConn) 601 602 savec := make([]unix.RawSockaddrInet6, downlink.relayBatchSize) 603 bufvec := make([][]byte, downlink.relayBatchSize) 604 iovec := make([]unix.Iovec, downlink.relayBatchSize) 605 msgvec := make([]conn.Mmsghdr, downlink.relayBatchSize) 606 607 for i := range downlink.relayBatchSize { 608 packetBuf := make([]byte, downlink.natConnRecvBufSize) 609 bufvec[i] = packetBuf 610 611 iovec[i].Base = &packetBuf[0] 612 iovec[i].SetLen(downlink.natConnRecvBufSize) 613 614 msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&savec[i])) 615 msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6 616 msgvec[i].Msghdr.Iov = &iovec[i] 617 msgvec[i].Msghdr.SetIovlen(1) 618 } 619 620 for { 621 nr, err := downlink.natConn.ReadMsgs(msgvec, 0) 622 if err != nil { 623 if errors.Is(err, os.ErrDeadlineExceeded) { 624 break 625 } 626 627 downlink.logger.Warn("Failed to batch read packets from natConn", 628 zap.Stringer("clientAddress", downlink.clientAddrPort), 629 zap.String("client", downlink.clientName), 630 zap.Error(err), 631 ) 632 continue 633 } 634 635 var ns int 636 msgvecn := msgvec[:nr] 637 638 for i := range msgvecn { 639 msg := &msgvecn[i] 640 641 packetSourceAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen) 642 if err != nil { 643 downlink.logger.Warn("Failed to parse sockaddr of packet from natConn", 644 zap.Stringer("clientAddress", downlink.clientAddrPort), 645 zap.String("client", downlink.clientName), 646 zap.Error(err), 647 ) 648 continue 649 } 650 651 if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil { 652 downlink.logger.Warn("Packet from natConn discarded", 653 zap.Stringer("clientAddress", downlink.clientAddrPort), 654 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 655 zap.String("client", downlink.clientName), 656 zap.Uint32("packetLength", msg.Msglen), 657 zap.Error(err), 658 ) 659 continue 660 } 661 662 packetBuf := bufvec[i] 663 664 payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, 0, int(msg.Msglen)) 665 if err != nil { 666 downlink.logger.Warn("Failed to unpack packet from natConn", 667 zap.Stringer("clientAddress", downlink.clientAddrPort), 668 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 669 zap.String("client", downlink.clientName), 670 zap.Uint32("packetLength", msg.Msglen), 671 zap.Error(err), 672 ) 673 continue 674 } 675 676 if payloadLength > maxClientPacketSize { 677 downlink.logger.Warn("Payload too large to send to client", 678 zap.Stringer("clientAddress", downlink.clientAddrPort), 679 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 680 zap.String("client", downlink.clientName), 681 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 682 zap.Int("payloadLength", payloadLength), 683 zap.Int("maxClientPacketSize", maxClientPacketSize), 684 ) 685 continue 686 } 687 688 tc := tcMap[payloadSourceAddrPort] 689 if tc == nil { 690 tc, err = s.newTransparentConn(ctx, payloadSourceAddrPort.String(), downlink.relayBatchSize, name, namelen) 691 if err != nil { 692 downlink.logger.Warn("Failed to create transparentConn", 693 zap.Stringer("clientAddress", downlink.clientAddrPort), 694 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 695 zap.String("client", downlink.clientName), 696 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 697 zap.Error(err), 698 ) 699 continue 700 } 701 tcMap[payloadSourceAddrPort] = tc 702 } 703 tc.putMsg(&packetBuf[payloadStart], payloadLength) 704 ns++ 705 payloadBytesSent += uint64(payloadLength) 706 } 707 708 if ns == 0 { 709 continue 710 } 711 712 for payloadSourceAddrPort, tc := range tcMap { 713 sc, ps, err := tc.writeMsgvec() 714 if err != nil { 715 downlink.logger.Warn("Failed to batch write packets to transparentConn", 716 zap.Stringer("clientAddress", downlink.clientAddrPort), 717 zap.String("client", downlink.clientName), 718 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 719 zap.Error(err), 720 ) 721 } 722 723 sendmmsgCount += uint64(sc) 724 packetsSent += uint64(ps) 725 burstBatchSize = max(burstBatchSize, ps) 726 } 727 } 728 729 for payloadSourceAddrPort, tc := range tcMap { 730 if err := tc.close(); err != nil { 731 downlink.logger.Warn("Failed to close transparentConn", 732 zap.Stringer("clientAddress", downlink.clientAddrPort), 733 zap.String("client", downlink.clientName), 734 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 735 zap.Error(err), 736 ) 737 } 738 } 739 740 downlink.logger.Info("Finished relay transparentConn <- natConn", 741 zap.Stringer("clientAddress", downlink.clientAddrPort), 742 zap.String("client", downlink.clientName), 743 zap.Uint64("sendmmsgCount", sendmmsgCount), 744 zap.Uint64("packetsSent", packetsSent), 745 zap.Uint64("payloadBytesSent", payloadBytesSent), 746 zap.Int("burstBatchSize", burstBatchSize), 747 ) 748 749 s.collector.CollectUDPSessionDownlink("", packetsSent, payloadBytesSent) 750 } 751 752 // Stop implements the Relay Stop method. 753 func (s *UDPTransparentRelay) Stop() error { 754 for i := range s.listeners { 755 lnc := &s.listeners[i] 756 if err := lnc.serverConn.SetReadDeadline(conn.ALongTimeAgo); err != nil { 757 lnc.logger.Warn("Failed to set read deadline on serverConn", zap.Error(err)) 758 } 759 } 760 761 // Wait for serverConn receive goroutines to exit, 762 // so there won't be any new sessions added to the table. 763 s.mwg.Wait() 764 765 s.mu.Lock() 766 for clientAddrPort, entry := range s.table { 767 natConn := entry.state.Swap(entry.serverConn) 768 if natConn == nil { 769 continue 770 } 771 772 if err := natConn.SetReadDeadline(conn.ALongTimeAgo); err != nil { 773 entry.logger.Warn("Failed to set read deadline on natConn", 774 zap.Stringer("clientAddress", clientAddrPort), 775 zap.Error(err), 776 ) 777 } 778 } 779 s.mu.Unlock() 780 781 // Wait for all relay goroutines to exit before closing serverConn, 782 // so in-flight packets can be written out. 783 s.wg.Wait() 784 785 for i := range s.listeners { 786 lnc := &s.listeners[i] 787 if err := lnc.serverConn.Close(); err != nil { 788 lnc.logger.Warn("Failed to close serverConn", zap.Error(err)) 789 } 790 } 791 792 return nil 793 }