github.com/database64128/shadowsocks-go@v1.7.0/service/udp_transparent_linux.go (about)

     1  package service
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"net/netip"
     7  	"os"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  	"unsafe"
    12  
    13  	"github.com/database64128/shadowsocks-go/conn"
    14  	"github.com/database64128/shadowsocks-go/router"
    15  	"github.com/database64128/shadowsocks-go/stats"
    16  	"github.com/database64128/shadowsocks-go/zerocopy"
    17  	"github.com/database64128/tfo-go/v2"
    18  	"go.uber.org/zap"
    19  	"golang.org/x/sys/unix"
    20  )
    21  
    22  // transparentQueuedPacket is the structure used by send channels to queue packets for sending.
    23  type transparentQueuedPacket struct {
    24  	buf            []byte
    25  	targetAddrPort netip.AddrPort
    26  	msglen         uint32
    27  }
    28  
    29  // transparentNATEntry is an entry in the tproxy NAT table.
    30  type transparentNATEntry struct {
    31  	// state synchronizes session initialization and shutdown.
    32  	//
    33  	//  - Swap the natConn in to signal initialization completion.
    34  	//  - Swap the serverConn in to signal shutdown.
    35  	//
    36  	// Callers must check the swapped-out value to determine the next action.
    37  	//
    38  	//  - During initialization, if the swapped-out value is non-nil,
    39  	//    initialization must not proceed.
    40  	//  - During shutdown, if the swapped-out value is nil, preceed to the next entry.
    41  	state         atomic.Pointer[net.UDPConn]
    42  	natConnSendCh chan<- *transparentQueuedPacket
    43  }
    44  
    45  // transparentUplink is used for passing information about relay uplink to the relay goroutine.
    46  type transparentUplink struct {
    47  	clientAddrPort netip.AddrPort
    48  	natConn        *conn.MmsgWConn
    49  	natConnSendCh  <-chan *transparentQueuedPacket
    50  	natConnPacker  zerocopy.ClientPacker
    51  }
    52  
    53  // transparentDownlink is used for passing information about relay downlink to the relay goroutine.
    54  type transparentDownlink struct {
    55  	clientAddrPort     netip.AddrPort
    56  	natConn            *conn.MmsgRConn
    57  	natConnRecvBufSize int
    58  	natConnUnpacker    zerocopy.ClientUnpacker
    59  }
    60  
    61  // UDPTransparentRelay is like [UDPNATRelay], but for transparent proxy.
    62  type UDPTransparentRelay struct {
    63  	serverName                  string
    64  	listenAddress               string
    65  	serverIndex                 int
    66  	mtu                         int
    67  	packetBufFrontHeadroom      int
    68  	packetBufRecvSize           int
    69  	relayBatchSize              int
    70  	serverRecvBatchSize         int
    71  	sendChannelCapacity         int
    72  	natTimeout                  time.Duration
    73  	serverConn                  *net.UDPConn
    74  	serverConnlistenConfig      tfo.ListenConfig
    75  	transparentConnListenConfig tfo.ListenConfig
    76  	collector                   stats.Collector
    77  	router                      *router.Router
    78  	logger                      *zap.Logger
    79  	queuedPacketPool            sync.Pool
    80  	mu                          sync.Mutex
    81  	wg                          sync.WaitGroup
    82  	mwg                         sync.WaitGroup
    83  	table                       map[netip.AddrPort]*transparentNATEntry
    84  }
    85  
    86  func NewUDPTransparentRelay(
    87  	serverName, listenAddress string,
    88  	relayBatchSize, serverRecvBatchSize, sendChannelCapacity, serverIndex, mtu int,
    89  	maxClientPackerHeadroom zerocopy.Headroom,
    90  	natTimeout time.Duration,
    91  	serverConnlistenConfig, transparentConnListenConfig tfo.ListenConfig,
    92  	collector stats.Collector,
    93  	router *router.Router,
    94  	logger *zap.Logger,
    95  ) (Relay, error) {
    96  	packetBufRecvSize := mtu - zerocopy.IPv4HeaderLength - zerocopy.UDPHeaderLength
    97  	packetBufSize := maxClientPackerHeadroom.Front + packetBufRecvSize + maxClientPackerHeadroom.Rear
    98  	return &UDPTransparentRelay{
    99  		serverName:                  serverName,
   100  		listenAddress:               listenAddress,
   101  		serverIndex:                 serverIndex,
   102  		mtu:                         mtu,
   103  		packetBufFrontHeadroom:      maxClientPackerHeadroom.Front,
   104  		packetBufRecvSize:           packetBufRecvSize,
   105  		relayBatchSize:              relayBatchSize,
   106  		serverRecvBatchSize:         serverRecvBatchSize,
   107  		sendChannelCapacity:         sendChannelCapacity,
   108  		natTimeout:                  natTimeout,
   109  		serverConnlistenConfig:      serverConnlistenConfig,
   110  		transparentConnListenConfig: transparentConnListenConfig,
   111  		collector:                   collector,
   112  		router:                      router,
   113  		logger:                      logger,
   114  		queuedPacketPool: sync.Pool{
   115  			New: func() any {
   116  				return &transparentQueuedPacket{
   117  					buf: make([]byte, packetBufSize),
   118  				}
   119  			},
   120  		},
   121  		table: make(map[netip.AddrPort]*transparentNATEntry),
   122  	}, nil
   123  }
   124  
   125  // String implements the Relay String method.
   126  func (s *UDPTransparentRelay) String() string {
   127  	return "UDP transparent relay service for " + s.serverName
   128  }
   129  
   130  // Start implements the Relay Start method.
   131  func (s *UDPTransparentRelay) Start() error {
   132  	serverConn, err := conn.ListenUDPRawConn(s.serverConnlistenConfig, "udp", s.listenAddress)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	s.serverConn = serverConn.UDPConn
   137  
   138  	s.mwg.Add(1)
   139  
   140  	go func() {
   141  		s.recvFromServerConnRecvmmsg(serverConn.RConn())
   142  		s.mwg.Done()
   143  	}()
   144  
   145  	s.logger.Info("Started UDP transparent relay service",
   146  		zap.String("server", s.serverName),
   147  		zap.String("listenAddress", s.listenAddress),
   148  	)
   149  
   150  	return nil
   151  }
   152  
   153  func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(serverConn *conn.MmsgRConn) {
   154  	n := s.serverRecvBatchSize
   155  	qpvec := make([]*transparentQueuedPacket, n)
   156  	namevec := make([]unix.RawSockaddrInet6, n)
   157  	iovec := make([]unix.Iovec, n)
   158  	cmsgvec := make([][]byte, n)
   159  	msgvec := make([]conn.Mmsghdr, n)
   160  
   161  	for i := range msgvec {
   162  		cmsgBuf := make([]byte, conn.TransparentSocketControlMessageBufferSize)
   163  		cmsgvec[i] = cmsgBuf
   164  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
   165  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   166  		msgvec[i].Msghdr.Iov = &iovec[i]
   167  		msgvec[i].Msghdr.SetIovlen(1)
   168  		msgvec[i].Msghdr.Control = &cmsgBuf[0]
   169  	}
   170  
   171  	var (
   172  		err                  error
   173  		recvmmsgCount        uint64
   174  		packetsReceived      uint64
   175  		payloadBytesReceived uint64
   176  		burstBatchSize       int
   177  	)
   178  
   179  	for {
   180  		for i := range iovec[:n] {
   181  			queuedPacket := s.getQueuedPacket()
   182  			qpvec[i] = queuedPacket
   183  			iovec[i].Base = &queuedPacket.buf[s.packetBufFrontHeadroom]
   184  			iovec[i].SetLen(s.packetBufRecvSize)
   185  			msgvec[i].Msghdr.SetControllen(conn.TransparentSocketControlMessageBufferSize)
   186  		}
   187  
   188  		n, err = serverConn.ReadMsgs(msgvec, 0)
   189  		if err != nil {
   190  			if errors.Is(err, os.ErrDeadlineExceeded) {
   191  				break
   192  			}
   193  
   194  			s.logger.Warn("Failed to batch read packets from serverConn",
   195  				zap.String("server", s.serverName),
   196  				zap.String("listenAddress", s.listenAddress),
   197  				zap.Error(err),
   198  			)
   199  
   200  			n = 1
   201  			s.putQueuedPacket(qpvec[0])
   202  			continue
   203  		}
   204  
   205  		recvmmsgCount++
   206  		packetsReceived += uint64(n)
   207  		if burstBatchSize < n {
   208  			burstBatchSize = n
   209  		}
   210  
   211  		s.mu.Lock()
   212  
   213  		msgvecn := msgvec[:n]
   214  
   215  		for i := range msgvecn {
   216  			msg := &msgvecn[i]
   217  			queuedPacket := qpvec[i]
   218  
   219  			clientAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
   220  			if err != nil {
   221  				s.logger.Warn("Failed to parse sockaddr of packet from serverConn",
   222  					zap.String("server", s.serverName),
   223  					zap.String("listenAddress", s.listenAddress),
   224  					zap.Error(err),
   225  				)
   226  
   227  				s.putQueuedPacket(queuedPacket)
   228  				continue
   229  			}
   230  
   231  			if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil {
   232  				s.logger.Warn("Packet from serverConn discarded",
   233  					zap.String("server", s.serverName),
   234  					zap.String("listenAddress", s.listenAddress),
   235  					zap.Stringer("clientAddress", clientAddrPort),
   236  					zap.Uint32("packetLength", msg.Msglen),
   237  					zap.Error(err),
   238  				)
   239  
   240  				s.putQueuedPacket(queuedPacket)
   241  				continue
   242  			}
   243  
   244  			queuedPacket.targetAddrPort, err = conn.ParseOrigDstAddrCmsg(cmsgvec[i][:msg.Msghdr.Controllen])
   245  			if err != nil {
   246  				s.logger.Warn("Failed to parse original destination address control message from serverConn",
   247  					zap.String("server", s.serverName),
   248  					zap.String("listenAddress", s.listenAddress),
   249  					zap.Stringer("clientAddress", clientAddrPort),
   250  					zap.Error(err),
   251  				)
   252  
   253  				s.putQueuedPacket(queuedPacket)
   254  				continue
   255  			}
   256  
   257  			queuedPacket.msglen = msg.Msglen
   258  			payloadBytesReceived += uint64(msg.Msglen)
   259  
   260  			entry := s.table[clientAddrPort]
   261  			if entry == nil {
   262  				natConnSendCh := make(chan *transparentQueuedPacket, s.sendChannelCapacity)
   263  				entry = &transparentNATEntry{natConnSendCh: natConnSendCh}
   264  				s.table[clientAddrPort] = entry
   265  
   266  				go func() {
   267  					var sendChClean bool
   268  
   269  					defer func() {
   270  						s.mu.Lock()
   271  						close(natConnSendCh)
   272  						delete(s.table, clientAddrPort)
   273  						s.mu.Unlock()
   274  
   275  						if !sendChClean {
   276  							for queuedPacket := range natConnSendCh {
   277  								s.putQueuedPacket(queuedPacket)
   278  							}
   279  						}
   280  					}()
   281  
   282  					c, err := s.router.GetUDPClient(router.RequestInfo{
   283  						ServerIndex:    s.serverIndex,
   284  						SourceAddrPort: clientAddrPort,
   285  						TargetAddr:     conn.AddrFromIPPort(queuedPacket.targetAddrPort),
   286  					})
   287  					if err != nil {
   288  						s.logger.Warn("Failed to get UDP client for new NAT session",
   289  							zap.String("server", s.serverName),
   290  							zap.String("listenAddress", s.listenAddress),
   291  							zap.Stringer("clientAddress", clientAddrPort),
   292  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   293  							zap.Error(err),
   294  						)
   295  						return
   296  					}
   297  
   298  					// Only add for the current goroutine here, since we don't want the router to block exiting.
   299  					s.wg.Add(1)
   300  					defer s.wg.Done()
   301  
   302  					clientInfo, natConnPacker, natConnUnpacker, err := c.NewSession()
   303  					if err != nil {
   304  						s.logger.Warn("Failed to create new UDP client session",
   305  							zap.String("server", s.serverName),
   306  							zap.String("client", clientInfo.Name),
   307  							zap.String("listenAddress", s.listenAddress),
   308  							zap.Stringer("clientAddress", clientAddrPort),
   309  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   310  							zap.Error(err),
   311  						)
   312  						return
   313  					}
   314  
   315  					natConn, err := conn.ListenUDPRawConn(clientInfo.ListenConfig, "udp", "")
   316  					if err != nil {
   317  						s.logger.Warn("Failed to create UDP socket for new NAT session",
   318  							zap.String("server", s.serverName),
   319  							zap.String("client", clientInfo.Name),
   320  							zap.String("listenAddress", s.listenAddress),
   321  							zap.Stringer("clientAddress", clientAddrPort),
   322  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   323  							zap.Error(err),
   324  						)
   325  						return
   326  					}
   327  
   328  					if err = natConn.SetReadDeadline(time.Now().Add(s.natTimeout)); err != nil {
   329  						s.logger.Warn("Failed to set read deadline on natConn",
   330  							zap.String("server", s.serverName),
   331  							zap.String("client", clientInfo.Name),
   332  							zap.String("listenAddress", s.listenAddress),
   333  							zap.Stringer("clientAddress", clientAddrPort),
   334  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   335  							zap.Duration("natTimeout", s.natTimeout),
   336  							zap.Error(err),
   337  						)
   338  						natConn.Close()
   339  						return
   340  					}
   341  
   342  					oldState := entry.state.Swap(natConn.UDPConn)
   343  					if oldState != nil {
   344  						natConn.Close()
   345  						return
   346  					}
   347  
   348  					// No more early returns!
   349  					sendChClean = true
   350  
   351  					s.logger.Info("UDP transparent relay started",
   352  						zap.String("server", s.serverName),
   353  						zap.String("client", clientInfo.Name),
   354  						zap.String("listenAddress", s.listenAddress),
   355  						zap.Stringer("clientAddress", clientAddrPort),
   356  						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   357  					)
   358  
   359  					s.wg.Add(1)
   360  
   361  					go func() {
   362  						s.relayServerConnToNatConnSendmmsg(transparentUplink{
   363  							clientAddrPort: clientAddrPort,
   364  							natConn:        natConn.WConn(),
   365  							natConnSendCh:  natConnSendCh,
   366  							natConnPacker:  natConnPacker,
   367  						})
   368  						natConn.Close()
   369  						s.wg.Done()
   370  					}()
   371  
   372  					s.relayNatConnToTransparentConnSendmmsg(transparentDownlink{
   373  						clientAddrPort:     clientAddrPort,
   374  						natConn:            natConn.RConn(),
   375  						natConnRecvBufSize: clientInfo.MaxPacketSize,
   376  						natConnUnpacker:    natConnUnpacker,
   377  					})
   378  				}()
   379  
   380  				if ce := s.logger.Check(zap.DebugLevel, "New UDP transparent session"); ce != nil {
   381  					ce.Write(
   382  						zap.String("server", s.serverName),
   383  						zap.String("listenAddress", s.listenAddress),
   384  						zap.Stringer("clientAddress", clientAddrPort),
   385  						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   386  					)
   387  				}
   388  			}
   389  
   390  			select {
   391  			case entry.natConnSendCh <- queuedPacket:
   392  			default:
   393  				if ce := s.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil {
   394  					ce.Write(
   395  						zap.String("server", s.serverName),
   396  						zap.String("listenAddress", s.listenAddress),
   397  						zap.Stringer("clientAddress", clientAddrPort),
   398  						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   399  					)
   400  				}
   401  
   402  				s.putQueuedPacket(queuedPacket)
   403  			}
   404  		}
   405  
   406  		s.mu.Unlock()
   407  	}
   408  
   409  	for i := range qpvec {
   410  		s.putQueuedPacket(qpvec[i])
   411  	}
   412  
   413  	s.logger.Info("Finished receiving from serverConn",
   414  		zap.String("server", s.serverName),
   415  		zap.String("listenAddress", s.listenAddress),
   416  		zap.Uint64("recvmmsgCount", recvmmsgCount),
   417  		zap.Uint64("packetsReceived", packetsReceived),
   418  		zap.Uint64("payloadBytesReceived", payloadBytesReceived),
   419  		zap.Int("burstBatchSize", burstBatchSize),
   420  	)
   421  }
   422  
// relayServerConnToNatConnSendmmsg is the uplink half of a NAT session.
//
// It takes packets queued on uplink.natConnSendCh and packs each one in place
// with uplink.natConnPacker. Packets are written to uplink.natConn in batches
// of up to s.relayBatchSize, one WriteMsgs call per batch.
//
// After each batch it calls SetReadDeadline(now + s.natTimeout) on
// uplink.natConn, presumably the same socket the downlink reads from, so
// outbound traffic keeps the session alive. Every dequeued packet is returned
// to the pool, whether it was sent or failed to pack.
//
// The function returns once natConnSendCh is closed, then reports uplink
// statistics to s.collector.
//
// NOTE(review): Stop sets the natConn read deadline to "now" to end sessions.
// A batch written here after that point extends the deadline again, which
// presumably can delay shutdown by up to natTimeout. Confirm.
func (s *UDPTransparentRelay) relayServerConnToNatConnSendmmsg(uplink transparentUplink) {
	var (
		destAddrPort     netip.AddrPort
		packetStart      int
		packetLength     int
		err              error
		sendmmsgCount    uint64
		packetsSent      uint64
		payloadBytesSent uint64
		burstBatchSize   int
	)

	// One slot per batch entry. Each msgvec header points at its own
	// destination sockaddr and iovec.
	qpvec := make([]*transparentQueuedPacket, s.relayBatchSize)
	namevec := make([]unix.RawSockaddrInet6, s.relayBatchSize)
	iovec := make([]unix.Iovec, s.relayBatchSize)
	msgvec := make([]conn.Mmsghdr, s.relayBatchSize)

	for i := range msgvec {
		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
		msgvec[i].Msghdr.Iov = &iovec[i]
		msgvec[i].Msghdr.SetIovlen(1)
	}

main:
	for {
		// count is the number of packets placed in the current batch.
		var count int

		// Block on first dequeue op.
		queuedPacket, ok := <-uplink.natConnSendCh
		if !ok {
			break
		}

		// Fill the batch until it is full or the channel has nothing ready.
	dequeue:
		for {
			destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(queuedPacket.buf, conn.AddrFromIPPort(queuedPacket.targetAddrPort), s.packetBufFrontHeadroom, int(queuedPacket.msglen))
			if err != nil {
				s.logger.Warn("Failed to pack packet for natConn",
					zap.String("server", s.serverName),
					zap.String("listenAddress", s.listenAddress),
					zap.Stringer("clientAddress", uplink.clientAddrPort),
					zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
					zap.Uint32("payloadLength", queuedPacket.msglen),
					zap.Error(err),
				)

				s.putQueuedPacket(queuedPacket)

				// Drop the packet. With an empty batch, go back to a blocking
				// receive; otherwise try to add the next ready packet.
				if count == 0 {
					continue main
				}
				goto next
			}

			qpvec[count] = queuedPacket
			namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort)
			iovec[count].Base = &queuedPacket.buf[packetStart]
			iovec[count].SetLen(packetLength)
			count++
			payloadBytesSent += uint64(queuedPacket.msglen)

			if count == s.relayBatchSize {
				break
			}

			// Non-blocking receive: send the batch as soon as the channel is
			// momentarily empty or has been closed.
		next:
			select {
			case queuedPacket, ok = <-uplink.natConnSendCh:
				if !ok {
					break dequeue
				}
			default:
				break dequeue
			}
		}

		// count >= 1 here, so qpvec[count-1] is valid.
		if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil {
			s.logger.Warn("Failed to batch write packets to natConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", uplink.clientAddrPort),
				zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort),
				zap.Stringer("lastWriteDestAddress", destAddrPort),
				zap.Error(err),
			)
		}

		// Outbound activity extends the NAT session's idle timeout.
		if err := uplink.natConn.SetReadDeadline(time.Now().Add(s.natTimeout)); err != nil {
			s.logger.Warn("Failed to set read deadline on natConn",
				zap.String("server", s.serverName),
				zap.String("listenAddress", s.listenAddress),
				zap.Stringer("clientAddress", uplink.clientAddrPort),
				zap.Duration("natTimeout", s.natTimeout),
				zap.Error(err),
			)
		}

		sendmmsgCount++
		packetsSent += uint64(count)
		if burstBatchSize < count {
			burstBatchSize = count
		}

		qpvecn := qpvec[:count]

		for i := range qpvecn {
			s.putQueuedPacket(qpvecn[i])
		}

		// ok is false when the channel was closed while this batch was being built.
		if !ok {
			break
		}
	}

	s.logger.Info("Finished relay serverConn -> natConn",
		zap.String("server", s.serverName),
		zap.String("listenAddress", s.listenAddress),
		zap.Stringer("clientAddress", uplink.clientAddrPort),
		zap.Stringer("lastWriteDestAddress", destAddrPort),
		zap.Uint64("sendmmsgCount", sendmmsgCount),
		zap.Uint64("packetsSent", packetsSent),
		zap.Uint64("payloadBytesSent", payloadBytesSent),
		zap.Int("burstBatchSize", burstBatchSize),
	)

	s.collector.CollectUDPSessionUplink("", packetsSent, payloadBytesSent)
}
   551  
   552  // getQueuedPacket retrieves a queued packet from the pool.
   553  func (s *UDPTransparentRelay) getQueuedPacket() *transparentQueuedPacket {
   554  	return s.queuedPacketPool.Get().(*transparentQueuedPacket)
   555  }
   556  
   557  // putQueuedPacket puts the queued packet back into the pool.
   558  func (s *UDPTransparentRelay) putQueuedPacket(queuedPacket *transparentQueuedPacket) {
   559  	s.queuedPacketPool.Put(queuedPacket)
   560  }
   561  
   562  type transparentConn struct {
   563  	mwc    *conn.MmsgWConn
   564  	iovec  []unix.Iovec
   565  	msgvec []conn.Mmsghdr
   566  	n      int
   567  }
   568  
   569  func (s *UDPTransparentRelay) newTransparentConn(address string, name *byte, namelen uint32) (*transparentConn, error) {
   570  	c, err := conn.ListenUDPRawConn(s.transparentConnListenConfig, "udp", address)
   571  	if err != nil {
   572  		return nil, err
   573  	}
   574  
   575  	iovec := make([]unix.Iovec, s.relayBatchSize)
   576  	msgvec := make([]conn.Mmsghdr, s.relayBatchSize)
   577  
   578  	for i := range msgvec {
   579  		msgvec[i].Msghdr.Name = name
   580  		msgvec[i].Msghdr.Namelen = namelen
   581  		msgvec[i].Msghdr.Iov = &iovec[i]
   582  		msgvec[i].Msghdr.SetIovlen(1)
   583  	}
   584  
   585  	return &transparentConn{
   586  		mwc:    c.WConn(),
   587  		iovec:  iovec,
   588  		msgvec: msgvec,
   589  	}, nil
   590  }
   591  
   592  func (tc *transparentConn) putMsg(base *byte, length int) {
   593  	tc.iovec[tc.n].Base = base
   594  	tc.iovec[tc.n].SetLen(length)
   595  	tc.n++
   596  }
   597  
   598  func (tc *transparentConn) writeMsgvec() (sendmmsgCount, packetsSent int, err error) {
   599  	if tc.n == 0 {
   600  		return
   601  	}
   602  	packetsSent = tc.n
   603  	tc.n = 0
   604  	return 1, packetsSent, tc.mwc.WriteMsgs(tc.msgvec[:packetsSent], 0)
   605  }
   606  
   607  func (tc *transparentConn) close() error {
   608  	return tc.mwc.Close()
   609  }
   610  
   611  func (s *UDPTransparentRelay) relayNatConnToTransparentConnSendmmsg(downlink transparentDownlink) {
   612  	var (
   613  		sendmmsgCount    uint64
   614  		packetsSent      uint64
   615  		payloadBytesSent uint64
   616  		burstBatchSize   int
   617  	)
   618  
   619  	maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, downlink.clientAddrPort.Addr())
   620  	name, namelen := conn.AddrPortUnmappedToSockaddr(downlink.clientAddrPort)
   621  	tcMap := make(map[netip.AddrPort]*transparentConn)
   622  
   623  	savec := make([]unix.RawSockaddrInet6, s.relayBatchSize)
   624  	bufvec := make([][]byte, s.relayBatchSize)
   625  	iovec := make([]unix.Iovec, s.relayBatchSize)
   626  	msgvec := make([]conn.Mmsghdr, s.relayBatchSize)
   627  
   628  	for i := 0; i < s.relayBatchSize; i++ {
   629  		packetBuf := make([]byte, downlink.natConnRecvBufSize)
   630  		bufvec[i] = packetBuf
   631  
   632  		iovec[i].Base = &packetBuf[0]
   633  		iovec[i].SetLen(downlink.natConnRecvBufSize)
   634  
   635  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&savec[i]))
   636  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   637  		msgvec[i].Msghdr.Iov = &iovec[i]
   638  		msgvec[i].Msghdr.SetIovlen(1)
   639  	}
   640  
   641  	for {
   642  		nr, err := downlink.natConn.ReadMsgs(msgvec, 0)
   643  		if err != nil {
   644  			if errors.Is(err, os.ErrDeadlineExceeded) {
   645  				break
   646  			}
   647  
   648  			s.logger.Warn("Failed to batch read packets from natConn",
   649  				zap.String("server", s.serverName),
   650  				zap.String("listenAddress", s.listenAddress),
   651  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   652  				zap.Error(err),
   653  			)
   654  			continue
   655  		}
   656  
   657  		var ns int
   658  		msgvecn := msgvec[:nr]
   659  
   660  		for i := range msgvecn {
   661  			msg := &msgvecn[i]
   662  
   663  			packetSourceAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
   664  			if err != nil {
   665  				s.logger.Warn("Failed to parse sockaddr of packet from natConn",
   666  					zap.String("server", s.serverName),
   667  					zap.String("listenAddress", s.listenAddress),
   668  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   669  					zap.Error(err),
   670  				)
   671  				continue
   672  			}
   673  
   674  			if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil {
   675  				s.logger.Warn("Packet from natConn discarded",
   676  					zap.String("server", s.serverName),
   677  					zap.String("listenAddress", s.listenAddress),
   678  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   679  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   680  					zap.Uint32("packetLength", msg.Msglen),
   681  					zap.Error(err),
   682  				)
   683  				continue
   684  			}
   685  
   686  			packetBuf := bufvec[i]
   687  
   688  			payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, 0, int(msg.Msglen))
   689  			if err != nil {
   690  				s.logger.Warn("Failed to unpack packet from natConn",
   691  					zap.String("server", s.serverName),
   692  					zap.String("listenAddress", s.listenAddress),
   693  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   694  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   695  					zap.Uint32("packetLength", msg.Msglen),
   696  					zap.Error(err),
   697  				)
   698  				continue
   699  			}
   700  
   701  			if payloadLength > maxClientPacketSize {
   702  				s.logger.Warn("Payload too large to send to client",
   703  					zap.String("server", s.serverName),
   704  					zap.String("listenAddress", s.listenAddress),
   705  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   706  					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   707  					zap.Int("payloadLength", payloadLength),
   708  					zap.Int("maxClientPacketSize", maxClientPacketSize),
   709  				)
   710  				continue
   711  			}
   712  
   713  			tc := tcMap[payloadSourceAddrPort]
   714  			if tc == nil {
   715  				tc, err = s.newTransparentConn(payloadSourceAddrPort.String(), name, namelen)
   716  				if err != nil {
   717  					s.logger.Warn("Failed to create transparentConn",
   718  						zap.String("server", s.serverName),
   719  						zap.String("listenAddress", s.listenAddress),
   720  						zap.Stringer("clientAddress", downlink.clientAddrPort),
   721  						zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   722  						zap.Error(err),
   723  					)
   724  					continue
   725  				}
   726  				tcMap[payloadSourceAddrPort] = tc
   727  			}
   728  			tc.putMsg(&packetBuf[payloadStart], payloadLength)
   729  			ns++
   730  			payloadBytesSent += uint64(payloadLength)
   731  		}
   732  
   733  		if ns == 0 {
   734  			continue
   735  		}
   736  
   737  		for payloadSourceAddrPort, tc := range tcMap {
   738  			sc, ps, err := tc.writeMsgvec()
   739  			if err != nil {
   740  				s.logger.Warn("Failed to batch write packets to transparentConn",
   741  					zap.String("server", s.serverName),
   742  					zap.String("listenAddress", s.listenAddress),
   743  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   744  					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   745  					zap.Error(err),
   746  				)
   747  			}
   748  
   749  			sendmmsgCount += uint64(sc)
   750  			packetsSent += uint64(ps)
   751  			if burstBatchSize < ps {
   752  				burstBatchSize = ps
   753  			}
   754  		}
   755  	}
   756  
   757  	for payloadSourceAddrPort, tc := range tcMap {
   758  		if err := tc.close(); err != nil {
   759  			s.logger.Warn("Failed to close transparentConn",
   760  				zap.String("server", s.serverName),
   761  				zap.String("listenAddress", s.listenAddress),
   762  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   763  				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   764  				zap.Error(err),
   765  			)
   766  		}
   767  	}
   768  
   769  	s.logger.Info("Finished relay transparentConn <- natConn",
   770  		zap.String("server", s.serverName),
   771  		zap.String("listenAddress", s.listenAddress),
   772  		zap.Stringer("clientAddress", downlink.clientAddrPort),
   773  		zap.Uint64("sendmmsgCount", sendmmsgCount),
   774  		zap.Uint64("packetsSent", packetsSent),
   775  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   776  		zap.Int("burstBatchSize", burstBatchSize),
   777  	)
   778  
   779  	s.collector.CollectUDPSessionDownlink("", packetsSent, payloadBytesSent)
   780  }
   781  
   782  // Stop implements the Relay Stop method.
   783  func (s *UDPTransparentRelay) Stop() error {
   784  	if s.serverConn == nil {
   785  		return nil
   786  	}
   787  
   788  	now := time.Now()
   789  
   790  	if err := s.serverConn.SetReadDeadline(now); err != nil {
   791  		return err
   792  	}
   793  
   794  	// Wait for serverConn receive goroutines to exit,
   795  	// so there won't be any new sessions added to the table.
   796  	s.mwg.Wait()
   797  
   798  	s.mu.Lock()
   799  	for clientAddrPort, entry := range s.table {
   800  		natConn := entry.state.Swap(s.serverConn)
   801  		if natConn == nil {
   802  			continue
   803  		}
   804  
   805  		if err := natConn.SetReadDeadline(now); err != nil {
   806  			s.logger.Warn("Failed to set read deadline on natConn",
   807  				zap.String("server", s.serverName),
   808  				zap.String("listenAddress", s.listenAddress),
   809  				zap.Stringer("clientAddress", clientAddrPort),
   810  				zap.Error(err),
   811  			)
   812  		}
   813  	}
   814  	s.mu.Unlock()
   815  
   816  	// Wait for all relay goroutines to exit before closing serverConn,
   817  	// so in-flight packets can be written out.
   818  	s.wg.Wait()
   819  
   820  	return s.serverConn.Close()
   821  }