github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_session_mmsg.go (about)

     1  //go:build linux || netbsd
     2  
     3  package service
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"errors"
     9  	"net/netip"
    10  	"os"
    11  	"sync/atomic"
    12  	"time"
    13  	"unsafe"
    14  
    15  	"github.com/database64128/shadowsocks-go/conn"
    16  	"github.com/database64128/shadowsocks-go/router"
    17  	"github.com/database64128/shadowsocks-go/zerocopy"
    18  	"go.uber.org/zap"
    19  	"golang.org/x/sys/unix"
    20  )
    21  
// sessionUplinkMmsg is used for passing information about relay uplink to the relay goroutine.
//
// It is consumed by relayServerConnToNatConnSendmmsg:
//   - csid, clientName, and username identify the session in log entries;
//     username is also the key under which uplink traffic is collected.
//   - natConnSendCh delivers the unpacked client packets to relay; the relay
//     returns once it is closed.
//   - natConnPacker packs each packet for the NAT connection, and natConn is
//     the batched writer the packed packets are sent through.
//   - natTimeout is how far natConn's read deadline is pushed out after each
//     written batch.
//   - relayBatchSize caps the number of packets per sendmmsg call.
//   - logger receives all log entries of the relay.
type sessionUplinkMmsg struct {
	csid           uint64
	clientName     string
	natConn        *conn.MmsgWConn
	natConnSendCh  <-chan *sessionQueuedPacket
	natConnPacker  zerocopy.ClientPacker
	natTimeout     time.Duration
	username       string
	relayBatchSize int
	logger         *zap.Logger
}
    34  
// sessionDownlinkMmsg is used for passing information about relay downlink to the relay goroutine.
//
// It is consumed by relayNatConnToServerConnSendmmsg:
//   - csid, clientName, and username identify the session in log entries;
//     username is also the key under which downlink traffic is collected.
//   - clientAddrInfop is the client address info to start sending to, and
//     clientAddrInfo is polled before each batch for a newer value published
//     by the listener's receive loop.
//   - natConn is the batched reader for packets from the NAT connection,
//     received into buffers of natConnRecvBufSize bytes (plus relay headroom),
//     and natConnUnpacker unpacks them.
//   - serverConnPacker packs each payload for the client, and serverConn is the
//     batched writer the packed packets are sent back through.
//   - relayBatchSize caps the number of packets per recvmmsg/sendmmsg call.
//   - logger receives all log entries of the relay.
type sessionDownlinkMmsg struct {
	csid               uint64
	clientName         string
	clientAddrInfop    *sessionClientAddrInfo
	clientAddrInfo     *atomic.Pointer[sessionClientAddrInfo]
	natConn            *conn.MmsgRConn
	natConnRecvBufSize int
	natConnUnpacker    zerocopy.ClientUnpacker
	serverConn         *conn.MmsgWConn
	serverConnPacker   zerocopy.ServerPacker
	username           string
	relayBatchSize     int
	logger             *zap.Logger
}
    50  
    51  func (s *UDPSessionRelay) start(ctx context.Context, index int, lnc *udpRelayServerConn) error {
    52  	switch lnc.batchMode {
    53  	case "sendmmsg", "":
    54  		return s.startMmsg(ctx, index, lnc)
    55  	default:
    56  		return s.startGeneric(ctx, index, lnc)
    57  	}
    58  }
    59  
    60  func (s *UDPSessionRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error {
    61  	serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	lnc.serverConn = serverConn.UDPConn
    66  	lnc.address = serverConn.LocalAddr().String()
    67  	lnc.logger = s.logger.With(
    68  		zap.String("server", s.serverName),
    69  		zap.Int("listener", index),
    70  		zap.String("listenAddress", lnc.address),
    71  	)
    72  
    73  	s.mwg.Add(1)
    74  
    75  	go func() {
    76  		s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.RConn())
    77  		s.mwg.Done()
    78  	}()
    79  
    80  	lnc.logger.Info("Started UDP session relay service listener")
    81  	return nil
    82  }
    83  
    84  func (s *UDPSessionRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRelayServerConn, serverConn *conn.MmsgRConn) {
    85  	n := lnc.serverRecvBatchSize
    86  	qpvec := make([]*sessionQueuedPacket, n)
    87  	namevec := make([]unix.RawSockaddrInet6, n)
    88  	iovec := make([]unix.Iovec, n)
    89  	cmsgvec := make([][]byte, n)
    90  	msgvec := make([]conn.Mmsghdr, n)
    91  
    92  	for i := range msgvec {
    93  		cmsgBuf := make([]byte, conn.SocketControlMessageBufferSize)
    94  		cmsgvec[i] = cmsgBuf
    95  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
    96  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
    97  		msgvec[i].Msghdr.Iov = &iovec[i]
    98  		msgvec[i].Msghdr.SetIovlen(1)
    99  		msgvec[i].Msghdr.Control = unsafe.SliceData(cmsgBuf)
   100  	}
   101  
   102  	var (
   103  		err                  error
   104  		recvmmsgCount        uint64
   105  		packetsReceived      uint64
   106  		payloadBytesReceived uint64
   107  		burstBatchSize       int
   108  	)
   109  
   110  	for {
   111  		for i := range iovec[:n] {
   112  			queuedPacket := s.getQueuedPacket()
   113  			qpvec[i] = queuedPacket
   114  			iovec[i].Base = &queuedPacket.buf[s.packetBufFrontHeadroom]
   115  			iovec[i].SetLen(s.packetBufRecvSize)
   116  			msgvec[i].Msghdr.SetControllen(conn.SocketControlMessageBufferSize)
   117  		}
   118  
   119  		n, err = serverConn.ReadMsgs(msgvec, 0)
   120  		if err != nil {
   121  			if errors.Is(err, os.ErrDeadlineExceeded) {
   122  				break
   123  			}
   124  
   125  			lnc.logger.Warn("Failed to batch read packets from serverConn", zap.Error(err))
   126  
   127  			n = 1
   128  			s.putQueuedPacket(qpvec[0])
   129  			continue
   130  		}
   131  
   132  		recvmmsgCount++
   133  		packetsReceived += uint64(n)
   134  		burstBatchSize = max(burstBatchSize, n)
   135  
   136  		s.server.Lock()
   137  
   138  		msgvecn := msgvec[:n]
   139  
   140  		for i := range msgvecn {
   141  			msg := &msgvecn[i]
   142  			queuedPacket := qpvec[i]
   143  
   144  			if msg.Msghdr.Controllen == 0 {
   145  				lnc.logger.Warn("Skipping packet with no control message from serverConn")
   146  				s.putQueuedPacket(queuedPacket)
   147  				continue
   148  			}
   149  
   150  			queuedPacket.clientAddrPort, err = conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
   151  			if err != nil {
   152  				lnc.logger.Warn("Failed to parse sockaddr of packet from serverConn", zap.Error(err))
   153  				s.putQueuedPacket(queuedPacket)
   154  				continue
   155  			}
   156  
   157  			err = conn.ParseFlagsForError(int(msg.Msghdr.Flags))
   158  			if err != nil {
   159  				lnc.logger.Warn("Packet from serverConn discarded",
   160  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   161  					zap.Uint32("packetLength", msg.Msglen),
   162  					zap.Error(err),
   163  				)
   164  
   165  				s.putQueuedPacket(queuedPacket)
   166  				continue
   167  			}
   168  
   169  			packet := queuedPacket.buf[s.packetBufFrontHeadroom : s.packetBufFrontHeadroom+int(msg.Msglen)]
   170  
   171  			csid, err := s.server.SessionInfo(packet)
   172  			if err != nil {
   173  				lnc.logger.Warn("Failed to extract session info from packet",
   174  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   175  					zap.Uint32("packetLength", msg.Msglen),
   176  					zap.Error(err),
   177  				)
   178  
   179  				s.putQueuedPacket(queuedPacket)
   180  				continue
   181  			}
   182  
   183  			entry, ok := s.table[csid]
   184  			if !ok {
   185  				entry = &session{
   186  					serverConn: lnc.serverConn,
   187  					logger:     lnc.logger,
   188  				}
   189  
   190  				entry.serverConnUnpacker, entry.username, err = s.server.NewUnpacker(packet, csid)
   191  				if err != nil {
   192  					lnc.logger.Warn("Failed to create unpacker for client session",
   193  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   194  						zap.Uint64("clientSessionID", csid),
   195  						zap.Uint32("packetLength", msg.Msglen),
   196  						zap.Error(err),
   197  					)
   198  
   199  					s.putQueuedPacket(queuedPacket)
   200  					continue
   201  				}
   202  			}
   203  
   204  			queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length, err = entry.serverConnUnpacker.UnpackInPlace(queuedPacket.buf, queuedPacket.clientAddrPort, s.packetBufFrontHeadroom, int(msg.Msglen))
   205  			if err != nil {
   206  				lnc.logger.Warn("Failed to unpack packet from serverConn",
   207  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   208  					zap.String("username", entry.username),
   209  					zap.Uint64("clientSessionID", csid),
   210  					zap.Uint32("packetLength", msg.Msglen),
   211  					zap.Error(err),
   212  				)
   213  
   214  				s.putQueuedPacket(queuedPacket)
   215  				continue
   216  			}
   217  
   218  			payloadBytesReceived += uint64(queuedPacket.length)
   219  
   220  			var clientAddrInfop *sessionClientAddrInfo
   221  			cmsg := cmsgvec[i][:msg.Msghdr.Controllen]
   222  
   223  			updateClientAddrPort := entry.clientAddrPortCache != queuedPacket.clientAddrPort
   224  			updateClientPktinfo := !bytes.Equal(entry.clientPktinfoCache, cmsg)
   225  
   226  			if updateClientAddrPort {
   227  				entry.clientAddrPortCache = queuedPacket.clientAddrPort
   228  			}
   229  
   230  			if updateClientPktinfo {
   231  				entry.clientPktinfoCache = make([]byte, len(cmsg))
   232  				copy(entry.clientPktinfoCache, cmsg)
   233  			}
   234  
   235  			if updateClientAddrPort || updateClientPktinfo {
   236  				clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg)
   237  				if err != nil {
   238  					lnc.logger.Warn("Failed to parse pktinfo control message from serverConn",
   239  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   240  						zap.String("username", entry.username),
   241  						zap.Uint64("clientSessionID", csid),
   242  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   243  						zap.Error(err),
   244  					)
   245  
   246  					s.putQueuedPacket(queuedPacket)
   247  					continue
   248  				}
   249  
   250  				clientAddrInfop = &sessionClientAddrInfo{entry.clientAddrPortCache, entry.clientPktinfoCache}
   251  				entry.clientAddrInfo.Store(clientAddrInfop)
   252  
   253  				if ce := lnc.logger.Check(zap.DebugLevel, "Updated client address info"); ce != nil {
   254  					ce.Write(
   255  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   256  						zap.String("username", entry.username),
   257  						zap.Uint64("clientSessionID", csid),
   258  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   259  						zap.Stringer("clientPktinfoAddr", clientPktinfoAddr),
   260  						zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex),
   261  					)
   262  				}
   263  			}
   264  
   265  			if !ok {
   266  				natConnSendCh := make(chan *sessionQueuedPacket, lnc.sendChannelCapacity)
   267  				entry.natConnSendCh = natConnSendCh
   268  				s.table[csid] = entry
   269  				s.wg.Add(1)
   270  
   271  				go func() {
   272  					var sendChClean bool
   273  
   274  					defer func() {
   275  						s.server.Lock()
   276  						close(natConnSendCh)
   277  						delete(s.table, csid)
   278  						s.server.Unlock()
   279  
   280  						if !sendChClean {
   281  							for queuedPacket := range natConnSendCh {
   282  								s.putQueuedPacket(queuedPacket)
   283  							}
   284  						}
   285  
   286  						s.wg.Done()
   287  					}()
   288  
   289  					c, err := s.router.GetUDPClient(ctx, router.RequestInfo{
   290  						ServerIndex:    s.serverIndex,
   291  						Username:       entry.username,
   292  						SourceAddrPort: queuedPacket.clientAddrPort,
   293  						TargetAddr:     queuedPacket.targetAddr,
   294  					})
   295  					if err != nil {
   296  						lnc.logger.Warn("Failed to get UDP client for new NAT session",
   297  							zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   298  							zap.String("username", entry.username),
   299  							zap.Uint64("clientSessionID", csid),
   300  							zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   301  							zap.Error(err),
   302  						)
   303  						return
   304  					}
   305  
   306  					clientInfo, clientSession, err := c.NewSession(ctx)
   307  					if err != nil {
   308  						lnc.logger.Warn("Failed to create new UDP client session",
   309  							zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   310  							zap.String("username", entry.username),
   311  							zap.Uint64("clientSessionID", csid),
   312  							zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   313  							zap.String("client", clientInfo.Name),
   314  							zap.Error(err),
   315  						)
   316  						return
   317  					}
   318  
   319  					natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "")
   320  					if err != nil {
   321  						lnc.logger.Warn("Failed to create UDP socket for new NAT session",
   322  							zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   323  							zap.String("username", entry.username),
   324  							zap.Uint64("clientSessionID", csid),
   325  							zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   326  							zap.String("client", clientInfo.Name),
   327  							zap.Error(err),
   328  						)
   329  						clientSession.Close()
   330  						return
   331  					}
   332  
   333  					err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout))
   334  					if err != nil {
   335  						lnc.logger.Warn("Failed to set read deadline on natConn",
   336  							zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   337  							zap.String("username", entry.username),
   338  							zap.Uint64("clientSessionID", csid),
   339  							zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   340  							zap.String("client", clientInfo.Name),
   341  							zap.Duration("natTimeout", lnc.natTimeout),
   342  							zap.Error(err),
   343  						)
   344  						natConn.Close()
   345  						clientSession.Close()
   346  						return
   347  					}
   348  
   349  					serverConnPacker, err := entry.serverConnUnpacker.NewPacker()
   350  					if err != nil {
   351  						lnc.logger.Warn("Failed to create packer for client session",
   352  							zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   353  							zap.String("username", entry.username),
   354  							zap.Uint64("clientSessionID", csid),
   355  							zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   356  							zap.Error(err),
   357  						)
   358  						natConn.Close()
   359  						clientSession.Close()
   360  						return
   361  					}
   362  
   363  					oldState := entry.state.Swap(natConn.UDPConn)
   364  					if oldState != nil {
   365  						natConn.Close()
   366  						clientSession.Close()
   367  						return
   368  					}
   369  
   370  					// No more early returns!
   371  					sendChClean = true
   372  
   373  					lnc.logger.Info("UDP session relay started",
   374  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   375  						zap.String("username", entry.username),
   376  						zap.Uint64("clientSessionID", csid),
   377  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   378  						zap.String("client", clientInfo.Name),
   379  					)
   380  
   381  					s.wg.Add(1)
   382  
   383  					go func() {
   384  						s.relayServerConnToNatConnSendmmsg(ctx, sessionUplinkMmsg{
   385  							csid:           csid,
   386  							clientName:     clientInfo.Name,
   387  							natConn:        natConn.WConn(),
   388  							natConnSendCh:  natConnSendCh,
   389  							natConnPacker:  clientSession.Packer,
   390  							natTimeout:     lnc.natTimeout,
   391  							username:       entry.username,
   392  							relayBatchSize: lnc.relayBatchSize,
   393  							logger:         lnc.logger,
   394  						})
   395  						natConn.Close()
   396  						clientSession.Close()
   397  						s.wg.Done()
   398  					}()
   399  
   400  					s.relayNatConnToServerConnSendmmsg(sessionDownlinkMmsg{
   401  						csid:               csid,
   402  						clientName:         clientInfo.Name,
   403  						clientAddrInfop:    clientAddrInfop,
   404  						clientAddrInfo:     &entry.clientAddrInfo,
   405  						natConn:            natConn.RConn(),
   406  						natConnRecvBufSize: clientSession.MaxPacketSize,
   407  						natConnUnpacker:    clientSession.Unpacker,
   408  						serverConn:         serverConn.WConn(),
   409  						serverConnPacker:   serverConnPacker,
   410  						username:           entry.username,
   411  						relayBatchSize:     lnc.relayBatchSize,
   412  						logger:             lnc.logger,
   413  					})
   414  				}()
   415  
   416  				if ce := lnc.logger.Check(zap.DebugLevel, "New UDP session"); ce != nil {
   417  					ce.Write(
   418  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   419  						zap.String("username", entry.username),
   420  						zap.Uint64("clientSessionID", csid),
   421  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   422  					)
   423  				}
   424  			}
   425  
   426  			select {
   427  			case entry.natConnSendCh <- queuedPacket:
   428  			default:
   429  				if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil {
   430  					ce.Write(
   431  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   432  						zap.String("username", entry.username),
   433  						zap.Uint64("clientSessionID", csid),
   434  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   435  					)
   436  				}
   437  
   438  				s.putQueuedPacket(queuedPacket)
   439  			}
   440  		}
   441  
   442  		s.server.Unlock()
   443  	}
   444  
   445  	for i := range qpvec {
   446  		s.putQueuedPacket(qpvec[i])
   447  	}
   448  
   449  	lnc.logger.Info("Finished receiving from serverConn",
   450  		zap.Uint64("recvmmsgCount", recvmmsgCount),
   451  		zap.Uint64("packetsReceived", packetsReceived),
   452  		zap.Uint64("payloadBytesReceived", payloadBytesReceived),
   453  		zap.Int("burstBatchSize", burstBatchSize),
   454  	)
   455  }
   456  
   457  func (s *UDPSessionRelay) relayServerConnToNatConnSendmmsg(ctx context.Context, uplink sessionUplinkMmsg) {
   458  	var (
   459  		destAddrPort     netip.AddrPort
   460  		packetStart      int
   461  		packetLength     int
   462  		err              error
   463  		sendmmsgCount    uint64
   464  		packetsSent      uint64
   465  		payloadBytesSent uint64
   466  		burstBatchSize   int
   467  	)
   468  
   469  	qpvec := make([]*sessionQueuedPacket, uplink.relayBatchSize)
   470  	namevec := make([]unix.RawSockaddrInet6, uplink.relayBatchSize)
   471  	iovec := make([]unix.Iovec, uplink.relayBatchSize)
   472  	msgvec := make([]conn.Mmsghdr, uplink.relayBatchSize)
   473  
   474  	for i := range msgvec {
   475  		msgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&namevec[i]))
   476  		msgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   477  		msgvec[i].Msghdr.Iov = &iovec[i]
   478  		msgvec[i].Msghdr.SetIovlen(1)
   479  	}
   480  
   481  main:
   482  	for {
   483  		var count int
   484  
   485  		// Block on first dequeue op.
   486  		queuedPacket, ok := <-uplink.natConnSendCh
   487  		if !ok {
   488  			break
   489  		}
   490  
   491  	dequeue:
   492  		for {
   493  			destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length)
   494  			if err != nil {
   495  				uplink.logger.Warn("Failed to pack packet for natConn",
   496  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   497  					zap.String("username", uplink.username),
   498  					zap.Uint64("clientSessionID", uplink.csid),
   499  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   500  					zap.String("client", uplink.clientName),
   501  					zap.Int("payloadLength", queuedPacket.length),
   502  					zap.Error(err),
   503  				)
   504  
   505  				s.putQueuedPacket(queuedPacket)
   506  
   507  				if count == 0 {
   508  					continue main
   509  				}
   510  				goto next
   511  			}
   512  
   513  			qpvec[count] = queuedPacket
   514  			namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort)
   515  			iovec[count].Base = &queuedPacket.buf[packetStart]
   516  			iovec[count].SetLen(packetLength)
   517  			count++
   518  			payloadBytesSent += uint64(queuedPacket.length)
   519  
   520  			if count == uplink.relayBatchSize {
   521  				break
   522  			}
   523  
   524  		next:
   525  			select {
   526  			case queuedPacket, ok = <-uplink.natConnSendCh:
   527  				if !ok {
   528  					break dequeue
   529  				}
   530  			default:
   531  				break dequeue
   532  			}
   533  		}
   534  
   535  		if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil {
   536  			uplink.logger.Warn("Failed to batch write packets to natConn",
   537  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   538  				zap.String("username", uplink.username),
   539  				zap.Uint64("clientSessionID", uplink.csid),
   540  				zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddr),
   541  				zap.String("client", uplink.clientName),
   542  				zap.Stringer("lastWriteDestAddress", destAddrPort),
   543  				zap.Error(err),
   544  			)
   545  		}
   546  
   547  		if err := uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)); err != nil {
   548  			uplink.logger.Warn("Failed to set read deadline on natConn",
   549  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   550  				zap.String("username", uplink.username),
   551  				zap.Uint64("clientSessionID", uplink.csid),
   552  				zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddr),
   553  				zap.String("client", uplink.clientName),
   554  				zap.Stringer("lastWriteDestAddress", destAddrPort),
   555  				zap.Duration("natTimeout", uplink.natTimeout),
   556  				zap.Error(err),
   557  			)
   558  		}
   559  
   560  		sendmmsgCount++
   561  		packetsSent += uint64(count)
   562  		burstBatchSize = max(burstBatchSize, count)
   563  
   564  		qpvecn := qpvec[:count]
   565  
   566  		for i := range qpvecn {
   567  			s.putQueuedPacket(qpvecn[i])
   568  		}
   569  
   570  		if !ok {
   571  			break
   572  		}
   573  	}
   574  
   575  	uplink.logger.Info("Finished relay serverConn -> natConn",
   576  		zap.String("username", uplink.username),
   577  		zap.Uint64("clientSessionID", uplink.csid),
   578  		zap.String("client", uplink.clientName),
   579  		zap.Stringer("lastWriteDestAddress", destAddrPort),
   580  		zap.Uint64("sendmmsgCount", sendmmsgCount),
   581  		zap.Uint64("packetsSent", packetsSent),
   582  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   583  		zap.Int("burstBatchSize", burstBatchSize),
   584  	)
   585  
   586  	s.collector.CollectUDPSessionUplink(uplink.username, packetsSent, payloadBytesSent)
   587  }
   588  
   589  func (s *UDPSessionRelay) relayNatConnToServerConnSendmmsg(downlink sessionDownlinkMmsg) {
   590  	clientAddrInfop := downlink.clientAddrInfop
   591  	clientAddrPort := downlink.clientAddrInfop.addrPort
   592  	clientPktinfo := downlink.clientAddrInfop.pktinfo
   593  	maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr())
   594  
   595  	serverConnPackerInfo := downlink.serverConnPacker.ServerPackerInfo()
   596  	natConnUnpackerInfo := downlink.natConnUnpacker.ClientUnpackerInfo()
   597  	headroom := zerocopy.UDPRelayHeadroom(serverConnPackerInfo.Headroom, natConnUnpackerInfo.Headroom)
   598  
   599  	var (
   600  		sendmmsgCount    uint64
   601  		packetsSent      uint64
   602  		payloadBytesSent uint64
   603  		burstBatchSize   int
   604  	)
   605  
   606  	rsa6, namelen := conn.AddrPortToSockaddrValue(clientAddrPort)
   607  	savec := make([]unix.RawSockaddrInet6, downlink.relayBatchSize)
   608  	bufvec := make([][]byte, downlink.relayBatchSize)
   609  	riovec := make([]unix.Iovec, downlink.relayBatchSize)
   610  	siovec := make([]unix.Iovec, downlink.relayBatchSize)
   611  	rmsgvec := make([]conn.Mmsghdr, downlink.relayBatchSize)
   612  	smsgvec := make([]conn.Mmsghdr, downlink.relayBatchSize)
   613  
   614  	for i := range downlink.relayBatchSize {
   615  		packetBuf := make([]byte, headroom.Front+downlink.natConnRecvBufSize+headroom.Rear)
   616  		bufvec[i] = packetBuf
   617  
   618  		riovec[i].Base = &packetBuf[headroom.Front]
   619  		riovec[i].SetLen(downlink.natConnRecvBufSize)
   620  
   621  		rmsgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&savec[i]))
   622  		rmsgvec[i].Msghdr.Namelen = unix.SizeofSockaddrInet6
   623  		rmsgvec[i].Msghdr.Iov = &riovec[i]
   624  		rmsgvec[i].Msghdr.SetIovlen(1)
   625  
   626  		smsgvec[i].Msghdr.Name = (*byte)(unsafe.Pointer(&rsa6))
   627  		smsgvec[i].Msghdr.Namelen = namelen
   628  		smsgvec[i].Msghdr.Iov = &siovec[i]
   629  		smsgvec[i].Msghdr.SetIovlen(1)
   630  		smsgvec[i].Msghdr.Control = unsafe.SliceData(clientPktinfo)
   631  		smsgvec[i].Msghdr.SetControllen(len(clientPktinfo))
   632  	}
   633  
   634  	for {
   635  		nr, err := downlink.natConn.ReadMsgs(rmsgvec, 0)
   636  		if err != nil {
   637  			if errors.Is(err, os.ErrDeadlineExceeded) {
   638  				break
   639  			}
   640  
   641  			downlink.logger.Warn("Failed to batch read packets from natConn",
   642  				zap.Stringer("clientAddress", clientAddrPort),
   643  				zap.String("username", downlink.username),
   644  				zap.Uint64("clientSessionID", downlink.csid),
   645  				zap.String("client", downlink.clientName),
   646  				zap.Error(err),
   647  			)
   648  			continue
   649  		}
   650  
   651  		if caip := downlink.clientAddrInfo.Load(); caip != clientAddrInfop {
   652  			clientAddrInfop = caip
   653  			clientAddrPort = caip.addrPort
   654  			clientPktinfo = caip.pktinfo
   655  			maxClientPacketSize = zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr())
   656  			rsa6, _ = conn.AddrPortToSockaddrValue(clientAddrPort) // namelen won't change
   657  
   658  			for i := range smsgvec {
   659  				smsgvec[i].Msghdr.Control = unsafe.SliceData(clientPktinfo)
   660  				smsgvec[i].Msghdr.SetControllen(len(clientPktinfo))
   661  			}
   662  		}
   663  
   664  		var ns int
   665  		rmsgvecn := rmsgvec[:nr]
   666  
   667  		for i := range rmsgvecn {
   668  			msg := &rmsgvecn[i]
   669  
   670  			packetSourceAddrPort, err := conn.SockaddrToAddrPort(msg.Msghdr.Name, msg.Msghdr.Namelen)
   671  			if err != nil {
   672  				downlink.logger.Warn("Failed to parse sockaddr of packet from natConn",
   673  					zap.Stringer("clientAddress", clientAddrPort),
   674  					zap.String("username", downlink.username),
   675  					zap.Uint64("clientSessionID", downlink.csid),
   676  					zap.String("client", downlink.clientName),
   677  					zap.Error(err),
   678  				)
   679  				continue
   680  			}
   681  
   682  			err = conn.ParseFlagsForError(int(msg.Msghdr.Flags))
   683  			if err != nil {
   684  				downlink.logger.Warn("Failed to read packet from natConn",
   685  					zap.Stringer("clientAddress", clientAddrPort),
   686  					zap.String("username", downlink.username),
   687  					zap.Uint64("clientSessionID", downlink.csid),
   688  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   689  					zap.String("client", downlink.clientName),
   690  					zap.Uint32("packetLength", msg.Msglen),
   691  					zap.Error(err),
   692  				)
   693  				continue
   694  			}
   695  
   696  			packetBuf := bufvec[i]
   697  
   698  			payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, headroom.Front, int(msg.Msglen))
   699  			if err != nil {
   700  				downlink.logger.Warn("Failed to unpack packet from natConn",
   701  					zap.Stringer("clientAddress", clientAddrPort),
   702  					zap.String("username", downlink.username),
   703  					zap.Uint64("clientSessionID", downlink.csid),
   704  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   705  					zap.String("client", downlink.clientName),
   706  					zap.Uint32("packetLength", msg.Msglen),
   707  					zap.Error(err),
   708  				)
   709  				continue
   710  			}
   711  
   712  			packetStart, packetLength, err := downlink.serverConnPacker.PackInPlace(packetBuf, payloadSourceAddrPort, payloadStart, payloadLength, maxClientPacketSize)
   713  			if err != nil {
   714  				downlink.logger.Warn("Failed to pack packet for serverConn",
   715  					zap.Stringer("clientAddress", clientAddrPort),
   716  					zap.String("username", downlink.username),
   717  					zap.Uint64("clientSessionID", downlink.csid),
   718  					zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   719  					zap.String("client", downlink.clientName),
   720  					zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   721  					zap.Int("payloadLength", payloadLength),
   722  					zap.Int("maxClientPacketSize", maxClientPacketSize),
   723  					zap.Error(err),
   724  				)
   725  				continue
   726  			}
   727  
   728  			siovec[ns].Base = &packetBuf[packetStart]
   729  			siovec[ns].SetLen(packetLength)
   730  			ns++
   731  			payloadBytesSent += uint64(payloadLength)
   732  		}
   733  
   734  		if ns == 0 {
   735  			continue
   736  		}
   737  
   738  		err = downlink.serverConn.WriteMsgs(smsgvec[:ns], 0)
   739  		if err != nil {
   740  			downlink.logger.Warn("Failed to batch write packets to serverConn",
   741  				zap.Stringer("clientAddress", clientAddrPort),
   742  				zap.String("username", downlink.username),
   743  				zap.Uint64("clientSessionID", downlink.csid),
   744  				zap.String("client", downlink.clientName),
   745  				zap.Error(err),
   746  			)
   747  		}
   748  
   749  		sendmmsgCount++
   750  		packetsSent += uint64(ns)
   751  		burstBatchSize = max(burstBatchSize, ns)
   752  	}
   753  
   754  	downlink.logger.Info("Finished relay serverConn <- natConn",
   755  		zap.Stringer("clientAddress", clientAddrPort),
   756  		zap.String("username", downlink.username),
   757  		zap.Uint64("clientSessionID", downlink.csid),
   758  		zap.String("client", downlink.clientName),
   759  		zap.Uint64("sendmmsgCount", sendmmsgCount),
   760  		zap.Uint64("packetsSent", packetsSent),
   761  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   762  		zap.Int("burstBatchSize", burstBatchSize),
   763  	)
   764  
   765  	s.collector.CollectUDPSessionDownlink(downlink.username, packetsSent, payloadBytesSent)
   766  }