github.com/database64128/shadowsocks-go@v1.7.0/service/udp_transparent_linux.go

package service

import (
	"errors"
	"net"
	"net/netip"
	"os"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/database64128/shadowsocks-go/conn"
	"github.com/database64128/shadowsocks-go/router"
	"github.com/database64128/shadowsocks-go/stats"
	"github.com/database64128/shadowsocks-go/zerocopy"
	"github.com/database64128/tfo-go/v2"
	"go.uber.org/zap"
	"golang.org/x/sys/unix"
)

// transparentQueuedPacket is the structure used by send channels to queue packets for sending.
type transparentQueuedPacket struct {
	buf            []byte
	targetAddrPort netip.AddrPort
	msglen         uint32
}

// transparentNATEntry is an entry in the tproxy NAT table.
type transparentNATEntry struct {
	// state synchronizes session initialization and shutdown.
	//
	//	- Swap the natConn in to signal initialization completion.
	//	- Swap the serverConn in to signal shutdown.
	//
	// Callers must check the swapped-out value to determine the next action.
	//
	//	- During initialization, if the swapped-out value is non-nil,
	//	  initialization must not proceed.
	//	- During shutdown, if the swapped-out value is nil, proceed to the next entry.
	state         atomic.Pointer[net.UDPConn]
	natConnSendCh chan<- *transparentQueuedPacket
}

// transparentUplink is used for passing information about relay uplink to the relay goroutine.
type transparentUplink struct {
	clientAddrPort netip.AddrPort
	natConn        *conn.MmsgWConn
	natConnSendCh  <-chan *transparentQueuedPacket
	natConnPacker  zerocopy.ClientPacker
}

// transparentDownlink is used for passing information about relay downlink to the relay goroutine.
type transparentDownlink struct {
	clientAddrPort     netip.AddrPort
	natConn            *conn.MmsgRConn
	natConnRecvBufSize int
	natConnUnpacker    zerocopy.ClientUnpacker
}

// UDPTransparentRelay is like [UDPNATRelay], but for transparent proxy.
type UDPTransparentRelay struct {
	serverName                  string
	listenAddress               string
	serverIndex                 int
	mtu                         int
	packetBufFrontHeadroom      int
	packetBufRecvSize           int
	relayBatchSize              int
	serverRecvBatchSize         int
	sendChannelCapacity         int
	natTimeout                  time.Duration
	serverConn                  *net.UDPConn
	serverConnlistenConfig      tfo.ListenConfig
	transparentConnListenConfig tfo.ListenConfig
	collector                   stats.Collector
	router                      *router.Router
	logger                      *zap.Logger
	queuedPacketPool            sync.Pool
	mu                          sync.Mutex
	wg                          sync.WaitGroup
	mwg                         sync.WaitGroup
	table                       map[netip.AddrPort]*transparentNATEntry
}

// NewUDPTransparentRelay creates a new UDP transparent relay service.
func NewUDPTransparentRelay(
	serverName, listenAddress string,
	relayBatchSize, serverRecvBatchSize, sendChannelCapacity, serverIndex, mtu int,
	maxClientPackerHeadroom zerocopy.Headroom,
	natTimeout time.Duration,
	serverConnlistenConfig, transparentConnListenConfig tfo.ListenConfig,
	collector stats.Collector,
	router *router.Router,
	logger *zap.Logger,
) (Relay, error) {
	packetBufRecvSize := mtu - zerocopy.IPv4HeaderLength - zerocopy.UDPHeaderLength
	packetBufSize := maxClientPackerHeadroom.Front + packetBufRecvSize + maxClientPackerHeadroom.Rear
	return &UDPTransparentRelay{
		serverName:                  serverName,
		listenAddress:               listenAddress,
		serverIndex:                 serverIndex,
		mtu:                         mtu,
		packetBufFrontHeadroom:      maxClientPackerHeadroom.Front,
		packetBufRecvSize:           packetBufRecvSize,
		relayBatchSize:              relayBatchSize,
		serverRecvBatchSize:         serverRecvBatchSize,
		sendChannelCapacity:         sendChannelCapacity,
		natTimeout:                  natTimeout,
		serverConnlistenConfig:      serverConnlistenConfig,
		transparentConnListenConfig: transparentConnListenConfig,
		collector:                   collector,
		router:                      router,
		logger:                      logger,
		queuedPacketPool: sync.Pool{
			New: func() any {
				return &transparentQueuedPacket{
					buf: make([]byte, packetBufSize),
				}
			},
		},
		table: make(map[netip.AddrPort]*transparentNATEntry),
	}, nil
}

// String implements the Relay String method.
func (s *UDPTransparentRelay) String() string {
	return "UDP transparent relay service for " + s.serverName
}

// Start implements the Relay Start method.
func (s *UDPTransparentRelay) Start() error {
	serverConn, err := conn.ListenUDPRawConn(s.serverConnlistenConfig, "udp", s.listenAddress)
	if err != nil {
		return err
	}
	s.serverConn = serverConn.UDPConn

	s.mwg.Add(1)

	go func() {
		s.recvFromServerConnRecvmmsg(serverConn.RConn())
		s.mwg.Done()
	}()

	s.logger.Info("Started UDP transparent relay service",
		zap.String("server", s.serverName),
		zap.String("listenAddress", s.listenAddress),
	)

	return nil
}

func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(serverConn *conn.MmsgRConn) {
	n := s.serverRecvBatchSize
	qpvec := make([]*transparentQueuedPacket, n)
	namevec := make([]unix.RawSockaddrInet6, n)
	iovec := make([]unix.Iovec, n)
	cmsgvec := make([][]byte, n)
	msgvec := make([]conn.Mmsghdr, n)

	for i := range msgvec {
		cmsgBuf := make([]byte, conn.TransparentSocketControlMessageBufferSize)
		cmsgvec[i] = cmsgBuf
		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
		msgvec[i].Msghdr.Iov = &iovec[i]
		msgvec[i].Msghdr.SetIovlen(1)
		msgvec[i].Msghdr.Control = &cmsgBuf[0]
	}

	var (
		err                  error
		recvmmsgCount        uint64
		packetsReceived      uint64
		payloadBytesReceived uint64
		burstBatchSize       int
	)

	for {
		for i := range iovec[:n] {
			queuedPacket := s.getQueuedPacket()
			qpvec[i] = queuedPacket
			iovec[i].Base = &queuedPacket.buf[s.packetBufFrontHeadroom]
			iovec[i].SetLen(s.packetBufRecvSize)
			msgvec[i].Msghdr.SetControllen(conn.TransparentSocketControlMessageBufferSize)
		}

		n, err = serverConn.ReadMsgs(msgvec, 0)
		if err != nil {
			if errors.Is(err, os.ErrDeadlineExceeded) {
				break
			}

			s.logger.Warn("Failed to batch read packets from serverConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Error(err),
			)

			n = 1
			s.putQueuedPacket(qpvec[0])
			continue
		}

		recvmmsgCount++
		packetsReceived += uint64(n)
		if burstBatchSize < n {
			burstBatchSize = n
		}

		s.mu.Lock()

		msgvecn := msgvec[:n]

		for i := range msgvecn {
			msg := &msgvecn[i]
			queuedPacket := qpvec[i]

			clientAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
			if err != nil {
				s.logger.Warn("Failed to parse sockaddr of packet from serverConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Error(err),
				)

				s.putQueuedPacket(queuedPacket)
				continue
			}

			if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil {
				s.logger.Warn("Packet from serverConn discarded",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", clientAddrPort),
					zap.Uint32("packetLength", msg.Msglen),
					zap.Error(err),
				)

				s.putQueuedPacket(queuedPacket)
				continue
			}

			queuedPacket.targetAddrPort, err = conn.ParseOrigDstAddrCmsg(cmsgvec[i][:msg.Msghdr.Controllen])
			if err != nil {
				s.logger.Warn("Failed to parse original destination address control message from serverConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", clientAddrPort),
					zap.Error(err),
				)

				s.putQueuedPacket(queuedPacket)
				continue
			}

			queuedPacket.msglen = msg.Msglen
			payloadBytesReceived += uint64(msg.Msglen)

			entry := s.table[clientAddrPort]
			if entry == nil {
				natConnSendCh := make(chan *transparentQueuedPacket, s.sendChannelCapacity)
				entry = &transparentNATEntry{natConnSendCh: natConnSendCh}
				s.table[clientAddrPort] = entry

				go func() {
					var sendChClean bool

					defer func() {
						s.mu.Lock()
						close(natConnSendCh)
						delete(s.table, clientAddrPort)
						s.mu.Unlock()

						if !sendChClean {
							for queuedPacket := range natConnSendCh {
								s.putQueuedPacket(queuedPacket)
							}
						}
					}()

					c, err := s.router.GetUDPClient(router.RequestInfo{
						ServerIndex:    s.serverIndex,
						SourceAddrPort: clientAddrPort,
						TargetAddr:     conn.AddrFromIPPort(queuedPacket.targetAddrPort),
					})
					if err != nil {
						s.logger.Warn("Failed to get UDP client for new NAT session",
							zap.String("server", s.serverName),
							zap.String("listenAddress", s.listenAddress),
							zap.Stringer("clientAddress", clientAddrPort),
							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
							zap.Error(err),
						)
						return
					}

					// Only add for the current goroutine here, since we don't want the router to block exiting.
					s.wg.Add(1)
					defer s.wg.Done()

					clientInfo, natConnPacker, natConnUnpacker, err := c.NewSession()
					if err != nil {
						s.logger.Warn("Failed to create new UDP client session",
							zap.String("server", s.serverName),
							zap.String("client", clientInfo.Name),
							zap.String("listenAddress", s.listenAddress),
							zap.Stringer("clientAddress", clientAddrPort),
							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
							zap.Error(err),
						)
						return
					}

					natConn, err := conn.ListenUDPRawConn(clientInfo.ListenConfig, "udp", "")
					if err != nil {
						s.logger.Warn("Failed to create UDP socket for new NAT session",
							zap.String("server", s.serverName),
							zap.String("client", clientInfo.Name),
							zap.String("listenAddress", s.listenAddress),
							zap.Stringer("clientAddress", clientAddrPort),
							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
							zap.Error(err),
						)
						return
					}

					if err = natConn.SetReadDeadline(time.Now().Add(s.natTimeout)); err != nil {
						s.logger.Warn("Failed to set read deadline on natConn",
							zap.String("server", s.serverName),
							zap.String("client", clientInfo.Name),
							zap.String("listenAddress", s.listenAddress),
							zap.Stringer("clientAddress", clientAddrPort),
							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
							zap.Duration("natTimeout", s.natTimeout),
							zap.Error(err),
						)
						natConn.Close()
						return
					}

					oldState := entry.state.Swap(natConn.UDPConn)
					if oldState != nil {
						natConn.Close()
						return
					}

					// No more early returns!
					sendChClean = true

					s.logger.Info("UDP transparent relay started",
						zap.String("server", s.serverName),
						zap.String("client", clientInfo.Name),
						zap.String("listenAddress", s.listenAddress),
						zap.Stringer("clientAddress", clientAddrPort),
						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
					)

					s.wg.Add(1)

					go func() {
						s.relayServerConnToNatConnSendmmsg(transparentUplink{
							clientAddrPort: clientAddrPort,
							natConn:        natConn.WConn(),
							natConnSendCh:  natConnSendCh,
							natConnPacker:  natConnPacker,
						})
						natConn.Close()
						s.wg.Done()
					}()

					s.relayNatConnToTransparentConnSendmmsg(transparentDownlink{
						clientAddrPort:     clientAddrPort,
						natConn:            natConn.RConn(),
						natConnRecvBufSize: clientInfo.MaxPacketSize,
						natConnUnpacker:    natConnUnpacker,
					})
				}()

				if ce := s.logger.Check(zap.DebugLevel, "New UDP transparent session"); ce != nil {
					ce.Write(
						zap.String("server", s.serverName),
						zap.String("listenAddress", s.listenAddress),
						zap.Stringer("clientAddress", clientAddrPort),
						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
					)
				}
			}

			select {
			case entry.natConnSendCh <- queuedPacket:
			default:
				if ce := s.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil {
					ce.Write(
						zap.String("server", s.serverName),
						zap.String("listenAddress", s.listenAddress),
						zap.Stringer("clientAddress", clientAddrPort),
						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
					)
				}

				s.putQueuedPacket(queuedPacket)
			}
		}

		s.mu.Unlock()
	}

	for i := range qpvec {
		s.putQueuedPacket(qpvec[i])
	}

	s.logger.Info("Finished receiving from serverConn",
		zap.String("server", s.serverName),
		zap.String("listenAddress", s.listenAddress),
		zap.Uint64("recvmmsgCount", recvmmsgCount),
		zap.Uint64("packetsReceived", packetsReceived),
		zap.Uint64("payloadBytesReceived", payloadBytesReceived),
		zap.Int("burstBatchSize", burstBatchSize),
	)
}

func (s *UDPTransparentRelay) relayServerConnToNatConnSendmmsg(uplink transparentUplink) {
	var (
		destAddrPort     netip.AddrPort
		packetStart      int
		packetLength     int
		err              error
		sendmmsgCount    uint64
		packetsSent      uint64
		payloadBytesSent uint64
		burstBatchSize   int
	)

	qpvec := make([]*transparentQueuedPacket, s.relayBatchSize)
	namevec := make([]unix.RawSockaddrInet6, s.relayBatchSize)
	iovec := make([]unix.Iovec, s.relayBatchSize)
	msgvec := make([]conn.Mmsghdr, s.relayBatchSize)

	for i := range msgvec {
		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
		msgvec[i].Msghdr.Iov = &iovec[i]
		msgvec[i].Msghdr.SetIovlen(1)
	}

main:
	for {
		var count int

		// Block on first dequeue op.
		queuedPacket, ok := <-uplink.natConnSendCh
		if !ok {
			break
		}

	dequeue:
		for {
			destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(queuedPacket.buf, conn.AddrFromIPPort(queuedPacket.targetAddrPort), s.packetBufFrontHeadroom, int(queuedPacket.msglen))
			if err != nil {
				s.logger.Warn("Failed to pack packet for natConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", uplink.clientAddrPort),
					zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
					zap.Uint32("payloadLength", queuedPacket.msglen),
					zap.Error(err),
				)

				s.putQueuedPacket(queuedPacket)

				if count == 0 {
					continue main
				}
				goto next
			}

			qpvec[count] = queuedPacket
			namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort)
			iovec[count].Base = &queuedPacket.buf[packetStart]
			iovec[count].SetLen(packetLength)
			count++
			payloadBytesSent += uint64(queuedPacket.msglen)

			if count == s.relayBatchSize {
				break
			}

		next:
			select {
			case queuedPacket, ok = <-uplink.natConnSendCh:
				if !ok {
					break dequeue
				}
			default:
				break dequeue
			}
		}

		if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil {
			s.logger.Warn("Failed to batch write packets to natConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", uplink.clientAddrPort),
				zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort),
				zap.Stringer("lastWriteDestAddress", destAddrPort),
				zap.Error(err),
			)
		}

		if err := uplink.natConn.SetReadDeadline(time.Now().Add(s.natTimeout)); err != nil {
			s.logger.Warn("Failed to set read deadline on natConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", uplink.clientAddrPort),
				zap.Duration("natTimeout", s.natTimeout),
				zap.Error(err),
			)
		}

		sendmmsgCount++
		packetsSent += uint64(count)
		if burstBatchSize < count {
			burstBatchSize = count
		}

		qpvecn := qpvec[:count]

		for i := range qpvecn {
			s.putQueuedPacket(qpvecn[i])
		}

		if !ok {
			break
		}
	}

	s.logger.Info("Finished relay serverConn -> natConn",
		zap.String("server", s.serverName),
		zap.String("listenAddress", s.listenAddress),
		zap.Stringer("clientAddress", uplink.clientAddrPort),
		zap.Stringer("lastWriteDestAddress", destAddrPort),
		zap.Uint64("sendmmsgCount", sendmmsgCount),
		zap.Uint64("packetsSent", packetsSent),
		zap.Uint64("payloadBytesSent", payloadBytesSent),
		zap.Int("burstBatchSize", burstBatchSize),
	)

	s.collector.CollectUDPSessionUplink("", packetsSent, payloadBytesSent)
}

// getQueuedPacket retrieves a queued packet from the pool.
func (s *UDPTransparentRelay) getQueuedPacket() *transparentQueuedPacket {
	return s.queuedPacketPool.Get().(*transparentQueuedPacket)
}

// putQueuedPacket puts the queued packet back into the pool.
func (s *UDPTransparentRelay) putQueuedPacket(queuedPacket *transparentQueuedPacket) {
	s.queuedPacketPool.Put(queuedPacket)
}

type transparentConn struct {
	mwc    *conn.MmsgWConn
	iovec  []unix.Iovec
	msgvec []conn.Mmsghdr
	n      int
}

func (s *UDPTransparentRelay) newTransparentConn(address string, name *byte, namelen uint32) (*transparentConn, error) {
	c, err := conn.ListenUDPRawConn(s.transparentConnListenConfig, "udp", address)
	if err != nil {
		return nil, err
	}

	iovec := make([]unix.Iovec, s.relayBatchSize)
	msgvec := make([]conn.Mmsghdr, s.relayBatchSize)

	for i := range msgvec {
		msgvec[i].Msghdr.Name = name
		msgvec[i].Msghdr.Namelen = namelen
		msgvec[i].Msghdr.Iov = &iovec[i]
		msgvec[i].Msghdr.SetIovlen(1)
	}

	return &transparentConn{
		mwc:    c.WConn(),
		iovec:  iovec,
		msgvec: msgvec,
	}, nil
}

func (tc *transparentConn) putMsg(base *byte, length int) {
	tc.iovec[tc.n].Base = base
	tc.iovec[tc.n].SetLen(length)
	tc.n++
}

func (tc *transparentConn) writeMsgvec() (sendmmsgCount, packetsSent int, err error) {
	if tc.n == 0 {
		return
	}
	packetsSent = tc.n
	tc.n = 0
	return 1, packetsSent, tc.mwc.WriteMsgs(tc.msgvec[:packetsSent], 0)
}

func (tc *transparentConn) close() error {
	return tc.mwc.Close()
}

func (s *UDPTransparentRelay) relayNatConnToTransparentConnSendmmsg(downlink transparentDownlink) {
	var (
		sendmmsgCount    uint64
		packetsSent      uint64
		payloadBytesSent uint64
		burstBatchSize   int
	)

	maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, downlink.clientAddrPort.Addr())
	name, namelen := conn.AddrPortUnmappedToSockaddr(downlink.clientAddrPort)
	tcMap := make(map[netip.AddrPort]*transparentConn)

	savec := make([]unix.RawSockaddrInet6, s.relayBatchSize)
	bufvec := make([][]byte, s.relayBatchSize)
	iovec := make([]unix.Iovec, s.relayBatchSize)
	msgvec := make([]conn.Mmsghdr, s.relayBatchSize)

	for i := 0; i < s.relayBatchSize; i++ {
		packetBuf := make([]byte, downlink.natConnRecvBufSize)
		bufvec[i] = packetBuf

		iovec[i].Base = &packetBuf[0]
		iovec[i].SetLen(downlink.natConnRecvBufSize)

		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&savec[i]))
		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
		msgvec[i].Msghdr.Iov = &iovec[i]
		msgvec[i].Msghdr.SetIovlen(1)
	}

	for {
		nr, err := downlink.natConn.ReadMsgs(msgvec, 0)
		if err != nil {
			if errors.Is(err, os.ErrDeadlineExceeded) {
				break
			}

			s.logger.Warn("Failed to batch read packets from natConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", downlink.clientAddrPort),
				zap.Error(err),
			)
			continue
		}

		var ns int
		msgvecn := msgvec[:nr]

		for i := range msgvecn {
			msg := &msgvecn[i]

			packetSourceAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
			if err != nil {
				s.logger.Warn("Failed to parse sockaddr of packet from natConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", downlink.clientAddrPort),
					zap.Error(err),
				)
				continue
			}

			if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil {
				s.logger.Warn("Packet from natConn discarded",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", downlink.clientAddrPort),
					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
					zap.Uint32("packetLength", msg.Msglen),
					zap.Error(err),
				)
				continue
			}

			packetBuf := bufvec[i]

			payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, 0, int(msg.Msglen))
			if err != nil {
				s.logger.Warn("Failed to unpack packet from natConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", downlink.clientAddrPort),
					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
					zap.Uint32("packetLength", msg.Msglen),
					zap.Error(err),
				)
				continue
			}

			if payloadLength > maxClientPacketSize {
				s.logger.Warn("Payload too large to send to client",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", downlink.clientAddrPort),
					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
					zap.Int("payloadLength", payloadLength),
					zap.Int("maxClientPacketSize", maxClientPacketSize),
				)
				continue
			}

			tc := tcMap[payloadSourceAddrPort]
			if tc == nil {
				tc, err = s.newTransparentConn(payloadSourceAddrPort.String(), name, namelen)
				if err != nil {
					s.logger.Warn("Failed to create transparentConn",
						zap.String("server", s.serverName),
						zap.String("listenAddress", s.listenAddress),
						zap.Stringer("clientAddress", downlink.clientAddrPort),
						zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
						zap.Error(err),
					)
					continue
				}
				tcMap[payloadSourceAddrPort] = tc
			}
			tc.putMsg(&packetBuf[payloadStart], payloadLength)
			ns++
			payloadBytesSent += uint64(payloadLength)
		}

		if ns == 0 {
			continue
		}

		for payloadSourceAddrPort, tc := range tcMap {
			sc, ps, err := tc.writeMsgvec()
			if err != nil {
				s.logger.Warn("Failed to batch write packets to transparentConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", downlink.clientAddrPort),
					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
					zap.Error(err),
				)
			}

			sendmmsgCount += uint64(sc)
			packetsSent += uint64(ps)
			if burstBatchSize < ps {
				burstBatchSize = ps
			}
		}
	}

	for payloadSourceAddrPort, tc := range tcMap {
		if err := tc.close(); err != nil {
			s.logger.Warn("Failed to close transparentConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", downlink.clientAddrPort),
				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
				zap.Error(err),
			)
		}
	}

	s.logger.Info("Finished relay transparentConn <- natConn",
		zap.String("server", s.serverName),
		zap.String("listenAddress", s.listenAddress),
		zap.Stringer("clientAddress", downlink.clientAddrPort),
		zap.Uint64("sendmmsgCount", sendmmsgCount),
		zap.Uint64("packetsSent", packetsSent),
		zap.Uint64("payloadBytesSent", payloadBytesSent),
		zap.Int("burstBatchSize", burstBatchSize),
	)

	s.collector.CollectUDPSessionDownlink("", packetsSent, payloadBytesSent)
}

// Stop implements the Relay Stop method.
func (s *UDPTransparentRelay) Stop() error {
	if s.serverConn == nil {
		return nil
	}

	now := time.Now()

	if err := s.serverConn.SetReadDeadline(now); err != nil {
		return err
	}

	// Wait for serverConn receive goroutines to exit,
	// so there won't be any new sessions added to the table.
	s.mwg.Wait()

	s.mu.Lock()
	for clientAddrPort, entry := range s.table {
		natConn := entry.state.Swap(s.serverConn)
		if natConn == nil {
			continue
		}

		if err := natConn.SetReadDeadline(now); err != nil {
			s.logger.Warn("Failed to set read deadline on natConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", clientAddrPort),
				zap.Error(err),
			)
		}
	}
	s.mu.Unlock()

	// Wait for all relay goroutines to exit before closing serverConn,
	// so in-flight packets can be written out.
	s.wg.Wait()

	return s.serverConn.Close()
}
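
// The sketch below is an illustrative addition, not part of the original file.
// It demonstrates, in isolation, the handshake documented on
// transparentNATEntry.state: the session goroutine publishes natConn to signal
// that initialization is complete, and Stop publishes serverConn to signal
// shutdown; each side inspects the swapped-out value to decide what to do next.
// The function name and its parameters are hypothetical.
func exampleNATEntryStateHandshake(entry *transparentNATEntry, natConn, serverConn *net.UDPConn) {
	// Initialization side (compare recvFromServerConnRecvmmsg): if a value was
	// already swapped in, shutdown won the race, so the new session backs out.
	if oldState := entry.state.Swap(natConn); oldState != nil {
		natConn.Close()
		return
	}

	// Shutdown side (compare Stop): a nil swapped-out value means the session
	// never finished initializing, so there is no natConn to unblock yet;
	// otherwise an immediate read deadline makes the relay goroutines drain
	// their batches and exit.
	if oldConn := entry.state.Swap(serverConn); oldConn != nil {
		oldConn.SetReadDeadline(time.Now())
	}
}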