github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_transparent_linux.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"net/netip"
     8  	"os"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  	"unsafe"
    13  
    14  	"github.com/database64128/shadowsocks-go/conn"
    15  	"github.com/database64128/shadowsocks-go/router"
    16  	"github.com/database64128/shadowsocks-go/stats"
    17  	"github.com/database64128/shadowsocks-go/zerocopy"
    18  	"go.uber.org/zap"
    19  	"golang.org/x/sys/unix"
    20  )
    21  
    22  // transparentQueuedPacket is the structure used by send channels to queue packets for sending.
    23  type transparentQueuedPacket struct {
    24  	buf            []byte
    25  	targetAddrPort netip.AddrPort
    26  	msglen         uint32
    27  }
    28  
    29  // transparentNATEntry is an entry in the tproxy NAT table.
    30  type transparentNATEntry struct {
    31  	// state synchronizes session initialization and shutdown.
    32  	//
    33  	//  - Swap the natConn in to signal initialization completion.
    34  	//  - Swap the serverConn in to signal shutdown.
    35  	//
    36  	// Callers must check the swapped-out value to determine the next action.
    37  	//
    38  	//  - During initialization, if the swapped-out value is non-nil,
    39  	//    initialization must not proceed.
    40  	//  - During shutdown, if the swapped-out value is nil, preceed to the next entry.
    41  	state         atomic.Pointer[net.UDPConn]
    42  	natConnSendCh chan<- *transparentQueuedPacket
    43  	serverConn    *net.UDPConn
    44  	logger        *zap.Logger
    45  }
    46  
    47  // transparentUplink is used for passing information about relay uplink to the relay goroutine.
    48  type transparentUplink struct {
    49  	clientName     string
    50  	clientAddrPort netip.AddrPort
    51  	natConn        *conn.MmsgWConn
    52  	natConnSendCh  <-chan *transparentQueuedPacket
    53  	natConnPacker  zerocopy.ClientPacker
    54  	natTimeout     time.Duration
    55  	relayBatchSize int
    56  	logger         *zap.Logger
    57  }
    58  
    59  // transparentDownlink is used for passing information about relay downlink to the relay goroutine.
    60  type transparentDownlink struct {
    61  	clientName         string
    62  	clientAddrPort     netip.AddrPort
    63  	natConn            *conn.MmsgRConn
    64  	natConnRecvBufSize int
    65  	natConnUnpacker    zerocopy.ClientUnpacker
    66  	relayBatchSize     int
    67  	logger             *zap.Logger
    68  }
    69  
    70  // UDPTransparentRelay is like [UDPNATRelay], but for transparent proxy.
    71  type UDPTransparentRelay struct {
    72  	serverName                  string
    73  	serverIndex                 int
    74  	mtu                         int
    75  	packetBufFrontHeadroom      int
    76  	packetBufRecvSize           int
    77  	listeners                   []udpRelayServerConn
    78  	transparentConnListenConfig conn.ListenConfig
    79  	collector                   stats.Collector
    80  	router                      *router.Router
    81  	logger                      *zap.Logger
    82  	queuedPacketPool            sync.Pool
    83  	mu                          sync.Mutex
    84  	wg                          sync.WaitGroup
    85  	mwg                         sync.WaitGroup
    86  	table                       map[netip.AddrPort]*transparentNATEntry
    87  }
    88  
    89  func NewUDPTransparentRelay(
    90  	serverName string,
    91  	serverIndex, mtu, packetBufFrontHeadroom, packetBufRecvSize, packetBufSize int,
    92  	listeners []udpRelayServerConn,
    93  	transparentConnListenConfig conn.ListenConfig,
    94  	collector stats.Collector,
    95  	router *router.Router,
    96  	logger *zap.Logger,
    97  ) (Relay, error) {
    98  	return &UDPTransparentRelay{
    99  		serverName:                  serverName,
   100  		serverIndex:                 serverIndex,
   101  		mtu:                         mtu,
   102  		packetBufFrontHeadroom:      packetBufFrontHeadroom,
   103  		packetBufRecvSize:           packetBufRecvSize,
   104  		listeners:                   listeners,
   105  		transparentConnListenConfig: transparentConnListenConfig,
   106  		collector:                   collector,
   107  		router:                      router,
   108  		logger:                      logger,
   109  		queuedPacketPool: sync.Pool{
   110  			New: func() any {
   111  				return &transparentQueuedPacket{
   112  					buf: make([]byte, packetBufSize),
   113  				}
   114  			},
   115  		},
   116  		table: make(map[netip.AddrPort]*transparentNATEntry),
   117  	}, nil
   118  }
   119  
   120  // String implements the Relay String method.
   121  func (s *UDPTransparentRelay) String() string {
   122  	return "UDP transparent relay service for " + s.serverName
   123  }
   124  
   125  // Start implements the Relay Start method.
   126  func (s *UDPTransparentRelay) Start(ctx context.Context) error {
   127  	for i := range s.listeners {
   128  		index := i
   129  		lnc := &s.listeners[index]
   130  
   131  		serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address)
   132  		if err != nil {
   133  			return err
   134  		}
   135  		lnc.serverConn = serverConn.UDPConn
   136  		lnc.address = serverConn.LocalAddr().String()
   137  		lnc.logger = s.logger.With(
   138  			zap.String("server", s.serverName),
   139  			zap.Int("listener", index),
   140  			zap.String("listenAddress", lnc.address),
   141  		)
   142  
   143  		s.mwg.Add(1)
   144  
   145  		go func() {
   146  			s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.RConn())
   147  			s.mwg.Done()
   148  		}()
   149  
   150  		lnc.logger.Info("Started UDP transparent relay service listener")
   151  	}
   152  	return nil
   153  }
   154  
   155  func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRelayServerConn, serverConn *conn.MmsgRConn) {
   156  	n := lnc.serverRecvBatchSize
   157  	qpvec := make([]*transparentQueuedPacket, n)
   158  	namevec := make([]unix.RawSockaddrInet6, n)
   159  	iovec := make([]unix.Iovec, n)
   160  	cmsgvec := make([][]byte, n)
   161  	msgvec := make([]conn.Mmsghdr, n)
   162  
   163  	for i := range msgvec {
   164  		cmsgBuf := make([]byte, conn.TransparentSocketControlMessageBufferSize)
   165  		cmsgvec[i] = cmsgBuf
   166  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
   167  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   168  		msgvec[i].Msghdr.Iov = &iovec[i]
   169  		msgvec[i].Msghdr.SetIovlen(1)
   170  		msgvec[i].Msghdr.Control = unsafe.SliceData(cmsgBuf)
   171  	}
   172  
   173  	var (
   174  		err                  error
   175  		recvmmsgCount        uint64
   176  		packetsReceived      uint64
   177  		payloadBytesReceived uint64
   178  		burstBatchSize       int
   179  	)
   180  
   181  	for {
   182  		for i := range iovec[:n] {
   183  			queuedPacket := s.getQueuedPacket()
   184  			qpvec[i] = queuedPacket
   185  			iovec[i].Base = &queuedPacket.buf[s.packetBufFrontHeadroom]
   186  			iovec[i].SetLen(s.packetBufRecvSize)
   187  			msgvec[i].Msghdr.SetControllen(conn.TransparentSocketControlMessageBufferSize)
   188  		}
   189  
   190  		n, err = serverConn.ReadMsgs(msgvec, 0)
   191  		if err != nil {
   192  			if errors.Is(err, os.ErrDeadlineExceeded) {
   193  				break
   194  			}
   195  
   196  			lnc.logger.Warn("Failed to batch read packets from serverConn", zap.Error(err))
   197  
   198  			n = 1
   199  			s.putQueuedPacket(qpvec[0])
   200  			continue
   201  		}
   202  
   203  		recvmmsgCount++
   204  		packetsReceived += uint64(n)
   205  		burstBatchSize = max(burstBatchSize, n)
   206  
   207  		s.mu.Lock()
   208  
   209  		msgvecn := msgvec[:n]
   210  
   211  		for i := range msgvecn {
   212  			msg := &msgvecn[i]
   213  			queuedPacket := qpvec[i]
   214  
   215  			clientAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
   216  			if err != nil {
   217  				lnc.logger.Warn("Failed to parse sockaddr of packet from serverConn", zap.Error(err))
   218  				s.putQueuedPacket(queuedPacket)
   219  				continue
   220  			}
   221  
   222  			if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil {
   223  				lnc.logger.Warn("Packet from serverConn discarded",
   224  					zap.Stringer("clientAddress", clientAddrPort),
   225  					zap.Uint32("packetLength", msg.Msglen),
   226  					zap.Error(err),
   227  				)
   228  
   229  				s.putQueuedPacket(queuedPacket)
   230  				continue
   231  			}
   232  
   233  			queuedPacket.targetAddrPort, err = conn.ParseOrigDstAddrCmsg(cmsgvec[i][:msg.Msghdr.Controllen])
   234  			if err != nil {
   235  				lnc.logger.Warn("Failed to parse original destination address control message from serverConn",
   236  					zap.Stringer("clientAddress", clientAddrPort),
   237  					zap.Error(err),
   238  				)
   239  				s.putQueuedPacket(queuedPacket)
   240  				continue
   241  			}
   242  
   243  			queuedPacket.msglen = msg.Msglen
   244  			payloadBytesReceived += uint64(msg.Msglen)
   245  
   246  			entry := s.table[clientAddrPort]
   247  			if entry == nil {
   248  				natConnSendCh := make(chan *transparentQueuedPacket, lnc.sendChannelCapacity)
   249  				entry = &transparentNATEntry{
   250  					natConnSendCh: natConnSendCh,
   251  					serverConn:    lnc.serverConn,
   252  					logger:        lnc.logger,
   253  				}
   254  				s.table[clientAddrPort] = entry
   255  				s.wg.Add(1)
   256  
   257  				go func() {
   258  					var sendChClean bool
   259  
   260  					defer func() {
   261  						s.mu.Lock()
   262  						close(natConnSendCh)
   263  						delete(s.table, clientAddrPort)
   264  						s.mu.Unlock()
   265  
   266  						if !sendChClean {
   267  							for queuedPacket := range natConnSendCh {
   268  								s.putQueuedPacket(queuedPacket)
   269  							}
   270  						}
   271  
   272  						s.wg.Done()
   273  					}()
   274  
   275  					c, err := s.router.GetUDPClient(ctx, router.RequestInfo{
   276  						ServerIndex:    s.serverIndex,
   277  						SourceAddrPort: clientAddrPort,
   278  						TargetAddr:     conn.AddrFromIPPort(queuedPacket.targetAddrPort),
   279  					})
   280  					if err != nil {
   281  						lnc.logger.Warn("Failed to get UDP client for new NAT session",
   282  							zap.Stringer("clientAddress", clientAddrPort),
   283  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   284  							zap.Error(err),
   285  						)
   286  						return
   287  					}
   288  
   289  					clientInfo, clientSession, err := c.NewSession(ctx)
   290  					if err != nil {
   291  						lnc.logger.Warn("Failed to create new UDP client session",
   292  							zap.Stringer("clientAddress", clientAddrPort),
   293  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   294  							zap.String("client", clientInfo.Name),
   295  							zap.Error(err),
   296  						)
   297  						return
   298  					}
   299  
   300  					natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "")
   301  					if err != nil {
   302  						lnc.logger.Warn("Failed to create UDP socket for new NAT session",
   303  							zap.Stringer("clientAddress", clientAddrPort),
   304  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   305  							zap.String("client", clientInfo.Name),
   306  							zap.Error(err),
   307  						)
   308  						clientSession.Close()
   309  						return
   310  					}
   311  
   312  					if err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout)); err != nil {
   313  						lnc.logger.Warn("Failed to set read deadline on natConn",
   314  							zap.Stringer("clientAddress", clientAddrPort),
   315  							zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   316  							zap.String("client", clientInfo.Name),
   317  							zap.Duration("natTimeout", lnc.natTimeout),
   318  							zap.Error(err),
   319  						)
   320  						natConn.Close()
   321  						clientSession.Close()
   322  						return
   323  					}
   324  
   325  					oldState := entry.state.Swap(natConn.UDPConn)
   326  					if oldState != nil {
   327  						natConn.Close()
   328  						clientSession.Close()
   329  						return
   330  					}
   331  
   332  					// No more early returns!
   333  					sendChClean = true
   334  
   335  					lnc.logger.Info("UDP transparent relay started",
   336  						zap.Stringer("clientAddress", clientAddrPort),
   337  						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   338  						zap.String("client", clientInfo.Name),
   339  					)
   340  
   341  					s.wg.Add(1)
   342  
   343  					go func() {
   344  						s.relayServerConnToNatConnSendmmsg(ctx, transparentUplink{
   345  							clientName:     clientInfo.Name,
   346  							clientAddrPort: clientAddrPort,
   347  							natConn:        natConn.WConn(),
   348  							natConnSendCh:  natConnSendCh,
   349  							natConnPacker:  clientSession.Packer,
   350  							natTimeout:     lnc.natTimeout,
   351  							relayBatchSize: lnc.relayBatchSize,
   352  							logger:         lnc.logger,
   353  						})
   354  						natConn.Close()
   355  						clientSession.Close()
   356  						s.wg.Done()
   357  					}()
   358  
   359  					s.relayNatConnToTransparentConnSendmmsg(ctx, transparentDownlink{
   360  						clientName:         clientInfo.Name,
   361  						clientAddrPort:     clientAddrPort,
   362  						natConn:            natConn.RConn(),
   363  						natConnRecvBufSize: clientSession.MaxPacketSize,
   364  						natConnUnpacker:    clientSession.Unpacker,
   365  						relayBatchSize:     lnc.relayBatchSize,
   366  					})
   367  				}()
   368  
   369  				if ce := lnc.logger.Check(zap.DebugLevel, "New UDP transparent session"); ce != nil {
   370  					ce.Write(
   371  						zap.Stringer("clientAddress", clientAddrPort),
   372  						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   373  					)
   374  				}
   375  			}
   376  
   377  			select {
   378  			case entry.natConnSendCh <- queuedPacket:
   379  			default:
   380  				if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil {
   381  					ce.Write(
   382  						zap.Stringer("clientAddress", clientAddrPort),
   383  						zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   384  					)
   385  				}
   386  
   387  				s.putQueuedPacket(queuedPacket)
   388  			}
   389  		}
   390  
   391  		s.mu.Unlock()
   392  	}
   393  
   394  	for i := range qpvec {
   395  		s.putQueuedPacket(qpvec[i])
   396  	}
   397  
   398  	lnc.logger.Info("Finished receiving from serverConn",
   399  		zap.Uint64("recvmmsgCount", recvmmsgCount),
   400  		zap.Uint64("packetsReceived", packetsReceived),
   401  		zap.Uint64("payloadBytesReceived", payloadBytesReceived),
   402  		zap.Int("burstBatchSize", burstBatchSize),
   403  	)
   404  }
   405  
   406  func (s *UDPTransparentRelay) relayServerConnToNatConnSendmmsg(ctx context.Context, uplink transparentUplink) {
   407  	var (
   408  		destAddrPort     netip.AddrPort
   409  		packetStart      int
   410  		packetLength     int
   411  		err              error
   412  		sendmmsgCount    uint64
   413  		packetsSent      uint64
   414  		payloadBytesSent uint64
   415  		burstBatchSize   int
   416  	)
   417  
   418  	qpvec := make([]*transparentQueuedPacket, uplink.relayBatchSize)
   419  	namevec := make([]unix.RawSockaddrInet6, uplink.relayBatchSize)
   420  	iovec := make([]unix.Iovec, uplink.relayBatchSize)
   421  	msgvec := make([]conn.Mmsghdr, uplink.relayBatchSize)
   422  
   423  	for i := range msgvec {
   424  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
   425  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   426  		msgvec[i].Msghdr.Iov = &iovec[i]
   427  		msgvec[i].Msghdr.SetIovlen(1)
   428  	}
   429  
   430  main:
   431  	for {
   432  		var count int
   433  
   434  		// Block on first dequeue op.
   435  		queuedPacket, ok := <-uplink.natConnSendCh
   436  		if !ok {
   437  			break
   438  		}
   439  
   440  	dequeue:
   441  		for {
   442  			destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, conn.AddrFromIPPort(queuedPacket.targetAddrPort), s.packetBufFrontHeadroom, int(queuedPacket.msglen))
   443  			if err != nil {
   444  				uplink.logger.Warn("Failed to pack packet for natConn",
   445  					zap.Stringer("clientAddress", uplink.clientAddrPort),
   446  					zap.Stringer("targetAddress", &queuedPacket.targetAddrPort),
   447  					zap.String("client", uplink.clientName),
   448  					zap.Uint32("payloadLength", queuedPacket.msglen),
   449  					zap.Error(err),
   450  				)
   451  
   452  				s.putQueuedPacket(queuedPacket)
   453  
   454  				if count == 0 {
   455  					continue main
   456  				}
   457  				goto next
   458  			}
   459  
   460  			qpvec[count] = queuedPacket
   461  			namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort)
   462  			iovec[count].Base = &queuedPacket.buf[packetStart]
   463  			iovec[count].SetLen(packetLength)
   464  			count++
   465  			payloadBytesSent += uint64(queuedPacket.msglen)
   466  
   467  			if count == uplink.relayBatchSize {
   468  				break
   469  			}
   470  
   471  		next:
   472  			select {
   473  			case queuedPacket, ok = <-uplink.natConnSendCh:
   474  				if !ok {
   475  					break dequeue
   476  				}
   477  			default:
   478  				break dequeue
   479  			}
   480  		}
   481  
   482  		if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil {
   483  			uplink.logger.Warn("Failed to batch write packets to natConn",
   484  				zap.Stringer("clientAddress", uplink.clientAddrPort),
   485  				zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort),
   486  				zap.String("client", uplink.clientName),
   487  				zap.Stringer("lastWriteDestAddress", destAddrPort),
   488  				zap.Error(err),
   489  			)
   490  		}
   491  
   492  		if err := uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)); err != nil {
   493  			uplink.logger.Warn("Failed to set read deadline on natConn",
   494  				zap.Stringer("clientAddress", uplink.clientAddrPort),
   495  				zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort),
   496  				zap.String("client", uplink.clientName),
   497  				zap.Stringer("lastWriteDestAddress", destAddrPort),
   498  				zap.Duration("natTimeout", uplink.natTimeout),
   499  				zap.Error(err),
   500  			)
   501  		}
   502  
   503  		sendmmsgCount++
   504  		packetsSent += uint64(count)
   505  		burstBatchSize = max(burstBatchSize, count)
   506  
   507  		qpvecn := qpvec[:count]
   508  
   509  		for i := range qpvecn {
   510  			s.putQueuedPacket(qpvecn[i])
   511  		}
   512  
   513  		if !ok {
   514  			break
   515  		}
   516  	}
   517  
   518  	uplink.logger.Info("Finished relay serverConn -> natConn",
   519  		zap.Stringer("clientAddress", uplink.clientAddrPort),
   520  		zap.String("client", uplink.clientName),
   521  		zap.Stringer("lastWriteDestAddress", destAddrPort),
   522  		zap.Uint64("sendmmsgCount", sendmmsgCount),
   523  		zap.Uint64("packetsSent", packetsSent),
   524  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   525  		zap.Int("burstBatchSize", burstBatchSize),
   526  	)
   527  
   528  	s.collector.CollectUDPSessionUplink("", packetsSent, payloadBytesSent)
   529  }
   530  
   531  // getQueuedPacket retrieves a queued packet from the pool.
   532  func (s *UDPTransparentRelay) getQueuedPacket() *transparentQueuedPacket {
   533  	return s.queuedPacketPool.Get().(*transparentQueuedPacket)
   534  }
   535  
   536  // putQueuedPacket puts the queued packet back into the pool.
   537  func (s *UDPTransparentRelay) putQueuedPacket(queuedPacket *transparentQueuedPacket) {
   538  	s.queuedPacketPool.Put(queuedPacket)
   539  }
   540  
   541  type transparentConn struct {
   542  	mwc    *conn.MmsgWConn
   543  	iovec  []unix.Iovec
   544  	msgvec []conn.Mmsghdr
   545  	n      int
   546  }
   547  
   548  func (s *UDPTransparentRelay) newTransparentConn(ctx context.Context, address string, relayBatchSize int, name *byte, namelen uint32) (*transparentConn, error) {
   549  	c, err := s.transparentConnListenConfig.ListenUDPRawConn(ctx, "udp", address)
   550  	if err != nil {
   551  		return nil, err
   552  	}
   553  
   554  	iovec := make([]unix.Iovec, relayBatchSize)
   555  	msgvec := make([]conn.Mmsghdr, relayBatchSize)
   556  
   557  	for i := range msgvec {
   558  		msgvec[i].Msghdr.Name = name
   559  		msgvec[i].Msghdr.Namelen = namelen
   560  		msgvec[i].Msghdr.Iov = &iovec[i]
   561  		msgvec[i].Msghdr.SetIovlen(1)
   562  	}
   563  
   564  	return &transparentConn{
   565  		mwc:    c.WConn(),
   566  		iovec:  iovec,
   567  		msgvec: msgvec,
   568  	}, nil
   569  }
   570  
   571  func (tc *transparentConn) putMsg(base *byte, length int) {
   572  	tc.iovec[tc.n].Base = base
   573  	tc.iovec[tc.n].SetLen(length)
   574  	tc.n++
   575  }
   576  
   577  func (tc *transparentConn) writeMsgvec() (sendmmsgCount, packetsSent int, err error) {
   578  	if tc.n == 0 {
   579  		return
   580  	}
   581  	packetsSent = tc.n
   582  	tc.n = 0
   583  	return 1, packetsSent, tc.mwc.WriteMsgs(tc.msgvec[:packetsSent], 0)
   584  }
   585  
   586  func (tc *transparentConn) close() error {
   587  	return tc.mwc.Close()
   588  }
   589  
   590  func (s *UDPTransparentRelay) relayNatConnToTransparentConnSendmmsg(ctx context.Context, downlink transparentDownlink) {
   591  	var (
   592  		sendmmsgCount    uint64
   593  		packetsSent      uint64
   594  		payloadBytesSent uint64
   595  		burstBatchSize   int
   596  	)
   597  
   598  	maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, downlink.clientAddrPort.Addr())
   599  	name, namelen := conn.AddrPortUnmappedToSockaddr(downlink.clientAddrPort)
   600  	tcMap := make(map[netip.AddrPort]*transparentConn)
   601  
   602  	savec := make([]unix.RawSockaddrInet6, downlink.relayBatchSize)
   603  	bufvec := make([][]byte, downlink.relayBatchSize)
   604  	iovec := make([]unix.Iovec, downlink.relayBatchSize)
   605  	msgvec := make([]conn.Mmsghdr, downlink.relayBatchSize)
   606  
   607  	for i := range downlink.relayBatchSize {
   608  		packetBuf := make([]byte, downlink.natConnRecvBufSize)
   609  		bufvec[i] = packetBuf
   610  
   611  		iovec[i].Base = &packetBuf[0]
   612  		iovec[i].SetLen(downlink.natConnRecvBufSize)
   613  
   614  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&savec[i]))
   615  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   616  		msgvec[i].Msghdr.Iov = &iovec[i]
   617  		msgvec[i].Msghdr.SetIovlen(1)
   618  	}
   619  
   620  	for {
   621  		nr, err := downlink.natConn.ReadMsgs(msgvec, 0)
   622  		if err != nil {
   623  			if errors.Is(err, os.ErrDeadlineExceeded) {
   624  				break
   625  			}
   626  
   627  			downlink.logger.Warn("Failed to batch read packets from natConn",
   628  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   629  				zap.String("client", downlink.clientName),
   630  				zap.Error(err),
   631  			)
   632  			continue
   633  		}
   634  
   635  		var ns int
   636  		msgvecn := msgvec[:nr]
   637  
   638  		for i := range msgvecn {
   639  			msg := &msgvecn[i]
   640  
   641  			packetSourceAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
   642  			if err != nil {
   643  				downlink.logger.Warn("Failed to parse sockaddr of packet from natConn",
   644  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   645  					zap.String("client", downlink.clientName),
   646  					zap.Error(err),
   647  				)
   648  				continue
   649  			}
   650  
   651  			if err = conn.ParseFlagsForError(int(msg.Msghdr.Flags)); err != nil {
   652  				downlink.logger.Warn("Packet from natConn discarded",
   653  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   654  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   655  					zap.String("client", downlink.clientName),
   656  					zap.Uint32("packetLength", msg.Msglen),
   657  					zap.Error(err),
   658  				)
   659  				continue
   660  			}
   661  
   662  			packetBuf := bufvec[i]
   663  
   664  			payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, 0, int(msg.Msglen))
   665  			if err != nil {
   666  				downlink.logger.Warn("Failed to unpack packet from natConn",
   667  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   668  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   669  					zap.String("client", downlink.clientName),
   670  					zap.Uint32("packetLength", msg.Msglen),
   671  					zap.Error(err),
   672  				)
   673  				continue
   674  			}
   675  
   676  			if payloadLength > maxClientPacketSize {
   677  				downlink.logger.Warn("Payload too large to send to client",
   678  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   679  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   680  					zap.String("client", downlink.clientName),
   681  					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   682  					zap.Int("payloadLength", payloadLength),
   683  					zap.Int("maxClientPacketSize", maxClientPacketSize),
   684  				)
   685  				continue
   686  			}
   687  
   688  			tc := tcMap[payloadSourceAddrPort]
   689  			if tc == nil {
   690  				tc, err = s.newTransparentConn(ctx, payloadSourceAddrPort.String(), downlink.relayBatchSize, name, namelen)
   691  				if err != nil {
   692  					downlink.logger.Warn("Failed to create transparentConn",
   693  						zap.Stringer("clientAddress", downlink.clientAddrPort),
   694  						zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   695  						zap.String("client", downlink.clientName),
   696  						zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   697  						zap.Error(err),
   698  					)
   699  					continue
   700  				}
   701  				tcMap[payloadSourceAddrPort] = tc
   702  			}
   703  			tc.putMsg(&packetBuf[payloadStart], payloadLength)
   704  			ns++
   705  			payloadBytesSent += uint64(payloadLength)
   706  		}
   707  
   708  		if ns == 0 {
   709  			continue
   710  		}
   711  
   712  		for payloadSourceAddrPort, tc := range tcMap {
   713  			sc, ps, err := tc.writeMsgvec()
   714  			if err != nil {
   715  				downlink.logger.Warn("Failed to batch write packets to transparentConn",
   716  					zap.Stringer("clientAddress", downlink.clientAddrPort),
   717  					zap.String("client", downlink.clientName),
   718  					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   719  					zap.Error(err),
   720  				)
   721  			}
   722  
   723  			sendmmsgCount += uint64(sc)
   724  			packetsSent += uint64(ps)
   725  			burstBatchSize = max(burstBatchSize, ps)
   726  		}
   727  	}
   728  
   729  	for payloadSourceAddrPort, tc := range tcMap {
   730  		if err := tc.close(); err != nil {
   731  			downlink.logger.Warn("Failed to close transparentConn",
   732  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   733  				zap.String("client", downlink.clientName),
   734  				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   735  				zap.Error(err),
   736  			)
   737  		}
   738  	}
   739  
   740  	downlink.logger.Info("Finished relay transparentConn <- natConn",
   741  		zap.Stringer("clientAddress", downlink.clientAddrPort),
   742  		zap.String("client", downlink.clientName),
   743  		zap.Uint64("sendmmsgCount", sendmmsgCount),
   744  		zap.Uint64("packetsSent", packetsSent),
   745  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   746  		zap.Int("burstBatchSize", burstBatchSize),
   747  	)
   748  
   749  	s.collector.CollectUDPSessionDownlink("", packetsSent, payloadBytesSent)
   750  }
   751  
   752  // Stop implements the Relay Stop method.
   753  func (s *UDPTransparentRelay) Stop() error {
   754  	for i := range s.listeners {
   755  		lnc := &s.listeners[i]
   756  		if err := lnc.serverConn.SetReadDeadline(conn.ALongTimeAgo); err != nil {
   757  			lnc.logger.Warn("Failed to set read deadline on serverConn", zap.Error(err))
   758  		}
   759  	}
   760  
   761  	// Wait for serverConn receive goroutines to exit,
   762  	// so there won't be any new sessions added to the table.
   763  	s.mwg.Wait()
   764  
   765  	s.mu.Lock()
   766  	for clientAddrPort, entry := range s.table {
   767  		natConn := entry.state.Swap(entry.serverConn)
   768  		if natConn == nil {
   769  			continue
   770  		}
   771  
   772  		if err := natConn.SetReadDeadline(conn.ALongTimeAgo); err != nil {
   773  			entry.logger.Warn("Failed to set read deadline on natConn",
   774  				zap.Stringer("clientAddress", clientAddrPort),
   775  				zap.Error(err),
   776  			)
   777  		}
   778  	}
   779  	s.mu.Unlock()
   780  
   781  	// Wait for all relay goroutines to exit before closing serverConn,
   782  	// so in-flight packets can be written out.
   783  	s.wg.Wait()
   784  
   785  	for i := range s.listeners {
   786  		lnc := &s.listeners[i]
   787  		if err := lnc.serverConn.Close(); err != nil {
   788  			lnc.logger.Warn("Failed to close serverConn", zap.Error(err))
   789  		}
   790  	}
   791  
   792  	return nil
   793  }