github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_session.go (about)

     1  package service
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/database64128/shadowsocks-go/conn"
    15  	"github.com/database64128/shadowsocks-go/router"
    16  	"github.com/database64128/shadowsocks-go/stats"
    17  	"github.com/database64128/shadowsocks-go/zerocopy"
    18  	"go.uber.org/zap"
    19  )
    20  
    21  // sessionQueuedPacket is the structure used by send channels to queue packets for sending.
    22  type sessionQueuedPacket struct {
    23  	buf            []byte
    24  	start          int
    25  	length         int
    26  	targetAddr     conn.Addr
    27  	clientAddrPort netip.AddrPort
    28  }
    29  
    30  // sessionClientAddrInfo stores a session's client address information.
    31  type sessionClientAddrInfo struct {
    32  	addrPort netip.AddrPort
    33  	pktinfo  []byte
    34  }
    35  
    36  // session keeps track of a UDP session.
    37  type session struct {
    38  	// state synchronizes session initialization and shutdown.
    39  	//
    40  	//  - Swap the natConn in to signal initialization completion.
    41  	//  - Swap the serverConn in to signal shutdown.
    42  	//
    43  	// Callers must check the swapped-out value to determine the next action.
    44  	//
    45  	//  - During initialization, if the swapped-out value is non-nil,
    46  	//    initialization must not proceed.
    47  	//  - During shutdown, if the swapped-out value is nil, preceed to the next entry.
    48  	state               atomic.Pointer[net.UDPConn]
    49  	clientAddrInfo      atomic.Pointer[sessionClientAddrInfo]
    50  	clientAddrPortCache netip.AddrPort
    51  	clientPktinfoCache  []byte
    52  	natConnSendCh       chan<- *sessionQueuedPacket
    53  	serverConn          *net.UDPConn
    54  	serverConnUnpacker  zerocopy.ServerUnpacker
    55  	username            string
    56  	logger              *zap.Logger
    57  }
    58  
    59  // sessionUplinkGeneric is used for passing information about relay uplink to the relay goroutine.
    60  type sessionUplinkGeneric struct {
    61  	csid          uint64
    62  	clientName    string
    63  	natConn       *net.UDPConn
    64  	natConnSendCh <-chan *sessionQueuedPacket
    65  	natConnPacker zerocopy.ClientPacker
    66  	natTimeout    time.Duration
    67  	username      string
    68  	logger        *zap.Logger
    69  }
    70  
    71  // sessionDownlinkGeneric is used for passing information about relay downlink to the relay goroutine.
    72  type sessionDownlinkGeneric struct {
    73  	csid               uint64
    74  	clientName         string
    75  	clientAddrInfop    *sessionClientAddrInfo
    76  	clientAddrInfo     *atomic.Pointer[sessionClientAddrInfo]
    77  	natConn            *net.UDPConn
    78  	natConnRecvBufSize int
    79  	natConnUnpacker    zerocopy.ClientUnpacker
    80  	serverConn         *net.UDPConn
    81  	serverConnPacker   zerocopy.ServerPacker
    82  	username           string
    83  	logger             *zap.Logger
    84  }
    85  
    86  // UDPSessionRelay is a session-based UDP relay service.
    87  //
    88  // Incoming UDP packets are dispatched to NAT sessions based on the client session ID.
    89  type UDPSessionRelay struct {
    90  	serverName             string
    91  	serverIndex            int
    92  	mtu                    int
    93  	packetBufFrontHeadroom int
    94  	packetBufRecvSize      int
    95  	listeners              []udpRelayServerConn
    96  	server                 zerocopy.UDPSessionServer
    97  	collector              stats.Collector
    98  	router                 *router.Router
    99  	logger                 *zap.Logger
   100  	queuedPacketPool       sync.Pool
   101  	wg                     sync.WaitGroup
   102  	mwg                    sync.WaitGroup
   103  	table                  map[uint64]*session
   104  }
   105  
   106  func NewUDPSessionRelay(
   107  	serverName string,
   108  	serverIndex, mtu, packetBufFrontHeadroom, packetBufRecvSize, packetBufSize int,
   109  	listeners []udpRelayServerConn,
   110  	server zerocopy.UDPSessionServer,
   111  	collector stats.Collector,
   112  	router *router.Router,
   113  	logger *zap.Logger,
   114  ) *UDPSessionRelay {
   115  	return &UDPSessionRelay{
   116  		serverName:             serverName,
   117  		serverIndex:            serverIndex,
   118  		mtu:                    mtu,
   119  		packetBufFrontHeadroom: packetBufFrontHeadroom,
   120  		packetBufRecvSize:      packetBufRecvSize,
   121  		listeners:              listeners,
   122  		server:                 server,
   123  		collector:              collector,
   124  		router:                 router,
   125  		logger:                 logger,
   126  		queuedPacketPool: sync.Pool{
   127  			New: func() any {
   128  				return &sessionQueuedPacket{
   129  					buf: make([]byte, packetBufSize),
   130  				}
   131  			},
   132  		},
   133  		table: make(map[uint64]*session),
   134  	}
   135  }
   136  
   137  // String implements the Service String method.
   138  func (s *UDPSessionRelay) String() string {
   139  	return "UDP session relay service for " + s.serverName
   140  }
   141  
   142  // Start implements the Service Start method.
   143  func (s *UDPSessionRelay) Start(ctx context.Context) error {
   144  	for i := range s.listeners {
   145  		if err := s.start(ctx, i, &s.listeners[i]); err != nil {
   146  			return err
   147  		}
   148  	}
   149  	return nil
   150  }
   151  
   152  func (s *UDPSessionRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) {
   153  	lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address)
   154  	if err != nil {
   155  		return
   156  	}
   157  	lnc.address = lnc.serverConn.LocalAddr().String()
   158  	lnc.logger = s.logger.With(
   159  		zap.String("server", s.serverName),
   160  		zap.Int("listener", index),
   161  		zap.String("listenAddress", lnc.address),
   162  	)
   163  
   164  	s.mwg.Add(1)
   165  
   166  	go func() {
   167  		s.recvFromServerConnGeneric(ctx, lnc)
   168  		s.mwg.Done()
   169  	}()
   170  
   171  	lnc.logger.Info("Started UDP session relay service listener")
   172  	return
   173  }
   174  
   175  func (s *UDPSessionRelay) recvFromServerConnGeneric(ctx context.Context, lnc *udpRelayServerConn) {
   176  	cmsgBuf := make([]byte, conn.SocketControlMessageBufferSize)
   177  
   178  	var (
   179  		n                    int
   180  		cmsgn                int
   181  		flags                int
   182  		err                  error
   183  		packetsReceived      uint64
   184  		payloadBytesReceived uint64
   185  	)
   186  
   187  	for {
   188  		queuedPacket := s.getQueuedPacket()
   189  		recvBuf := queuedPacket.buf[s.packetBufFrontHeadroom : s.packetBufFrontHeadroom+s.packetBufRecvSize]
   190  
   191  		n, cmsgn, flags, queuedPacket.clientAddrPort, err = lnc.serverConn.ReadMsgUDPAddrPort(recvBuf, cmsgBuf)
   192  		if err != nil {
   193  			if errors.Is(err, os.ErrDeadlineExceeded) {
   194  				s.putQueuedPacket(queuedPacket)
   195  				break
   196  			}
   197  
   198  			lnc.logger.Warn("Failed to read packet from serverConn",
   199  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   200  				zap.Int("packetLength", n),
   201  				zap.Error(err),
   202  			)
   203  
   204  			s.putQueuedPacket(queuedPacket)
   205  			continue
   206  		}
   207  		err = conn.ParseFlagsForError(flags)
   208  		if err != nil {
   209  			lnc.logger.Warn("Failed to read packet from serverConn",
   210  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   211  				zap.Int("packetLength", n),
   212  				zap.Error(err),
   213  			)
   214  
   215  			s.putQueuedPacket(queuedPacket)
   216  			continue
   217  		}
   218  
   219  		packet := recvBuf[:n]
   220  
   221  		csid, err := s.server.SessionInfo(packet)
   222  		if err != nil {
   223  			lnc.logger.Warn("Failed to extract session info from packet",
   224  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   225  				zap.Int("packetLength", n),
   226  				zap.Error(err),
   227  			)
   228  
   229  			s.putQueuedPacket(queuedPacket)
   230  			continue
   231  		}
   232  
   233  		s.server.Lock()
   234  
   235  		entry, ok := s.table[csid]
   236  		if !ok {
   237  			entry = &session{
   238  				serverConn: lnc.serverConn,
   239  				logger:     lnc.logger,
   240  			}
   241  
   242  			entry.serverConnUnpacker, entry.username, err = s.server.NewUnpacker(packet, csid)
   243  			if err != nil {
   244  				lnc.logger.Warn("Failed to create unpacker for client session",
   245  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   246  					zap.Uint64("clientSessionID", csid),
   247  					zap.Int("packetLength", n),
   248  					zap.Error(err),
   249  				)
   250  
   251  				s.putQueuedPacket(queuedPacket)
   252  				s.server.Unlock()
   253  				continue
   254  			}
   255  		}
   256  
   257  		queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length, err = entry.serverConnUnpacker.UnpackInPlace(queuedPacket.buf, queuedPacket.clientAddrPort, s.packetBufFrontHeadroom, n)
   258  		if err != nil {
   259  			lnc.logger.Warn("Failed to unpack packet",
   260  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   261  				zap.String("username", entry.username),
   262  				zap.Uint64("clientSessionID", csid),
   263  				zap.Int("packetLength", n),
   264  				zap.Error(err),
   265  			)
   266  
   267  			s.putQueuedPacket(queuedPacket)
   268  			s.server.Unlock()
   269  			continue
   270  		}
   271  
   272  		packetsReceived++
   273  		payloadBytesReceived += uint64(queuedPacket.length)
   274  
   275  		var clientAddrInfop *sessionClientAddrInfo
   276  		cmsg := cmsgBuf[:cmsgn]
   277  
   278  		updateClientAddrPort := entry.clientAddrPortCache != queuedPacket.clientAddrPort
   279  		updateClientPktinfo := !bytes.Equal(entry.clientPktinfoCache, cmsg)
   280  
   281  		if updateClientAddrPort {
   282  			entry.clientAddrPortCache = queuedPacket.clientAddrPort
   283  		}
   284  
   285  		if updateClientPktinfo {
   286  			entry.clientPktinfoCache = make([]byte, len(cmsg))
   287  			copy(entry.clientPktinfoCache, cmsg)
   288  		}
   289  
   290  		if updateClientAddrPort || updateClientPktinfo {
   291  			clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg)
   292  			if err != nil {
   293  				lnc.logger.Warn("Failed to parse pktinfo control message from serverConn",
   294  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   295  					zap.String("username", entry.username),
   296  					zap.Uint64("clientSessionID", csid),
   297  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   298  					zap.Error(err),
   299  				)
   300  
   301  				s.putQueuedPacket(queuedPacket)
   302  				s.server.Unlock()
   303  				continue
   304  			}
   305  
   306  			clientAddrInfop = &sessionClientAddrInfo{entry.clientAddrPortCache, entry.clientPktinfoCache}
   307  			entry.clientAddrInfo.Store(clientAddrInfop)
   308  
   309  			if ce := lnc.logger.Check(zap.DebugLevel, "Updated client address info"); ce != nil {
   310  				ce.Write(
   311  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   312  					zap.String("username", entry.username),
   313  					zap.Uint64("clientSessionID", csid),
   314  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   315  					zap.Stringer("clientPktinfoAddr", clientPktinfoAddr),
   316  					zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex),
   317  				)
   318  			}
   319  		}
   320  
   321  		if !ok {
   322  			natConnSendCh := make(chan *sessionQueuedPacket, lnc.sendChannelCapacity)
   323  			entry.natConnSendCh = natConnSendCh
   324  			s.table[csid] = entry
   325  			s.wg.Add(1)
   326  
   327  			go func() {
   328  				var sendChClean bool
   329  
   330  				defer func() {
   331  					s.server.Lock()
   332  					close(natConnSendCh)
   333  					delete(s.table, csid)
   334  					s.server.Unlock()
   335  
   336  					if !sendChClean {
   337  						for queuedPacket := range natConnSendCh {
   338  							s.putQueuedPacket(queuedPacket)
   339  						}
   340  					}
   341  
   342  					s.wg.Done()
   343  				}()
   344  
   345  				c, err := s.router.GetUDPClient(ctx, router.RequestInfo{
   346  					ServerIndex:    s.serverIndex,
   347  					Username:       entry.username,
   348  					SourceAddrPort: queuedPacket.clientAddrPort,
   349  					TargetAddr:     queuedPacket.targetAddr,
   350  				})
   351  				if err != nil {
   352  					lnc.logger.Warn("Failed to get UDP client for new NAT session",
   353  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   354  						zap.String("username", entry.username),
   355  						zap.Uint64("clientSessionID", csid),
   356  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   357  						zap.Error(err),
   358  					)
   359  					return
   360  				}
   361  
   362  				clientInfo, clientSession, err := c.NewSession(ctx)
   363  				if err != nil {
   364  					lnc.logger.Warn("Failed to create new UDP client session",
   365  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   366  						zap.String("username", entry.username),
   367  						zap.Uint64("clientSessionID", csid),
   368  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   369  						zap.String("client", clientInfo.Name),
   370  						zap.Error(err),
   371  					)
   372  					return
   373  				}
   374  
   375  				natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
   376  				if err != nil {
   377  					lnc.logger.Warn("Failed to create UDP socket for new NAT session",
   378  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   379  						zap.String("username", entry.username),
   380  						zap.Uint64("clientSessionID", csid),
   381  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   382  						zap.String("client", clientInfo.Name),
   383  						zap.Error(err),
   384  					)
   385  					clientSession.Close()
   386  					return
   387  				}
   388  
   389  				err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout))
   390  				if err != nil {
   391  					lnc.logger.Warn("Failed to set read deadline on natConn",
   392  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   393  						zap.String("username", entry.username),
   394  						zap.Uint64("clientSessionID", csid),
   395  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   396  						zap.String("client", clientInfo.Name),
   397  						zap.Duration("natTimeout", lnc.natTimeout),
   398  						zap.Error(err),
   399  					)
   400  					natConn.Close()
   401  					clientSession.Close()
   402  					return
   403  				}
   404  
   405  				serverConnPacker, err := entry.serverConnUnpacker.NewPacker()
   406  				if err != nil {
   407  					lnc.logger.Warn("Failed to create packer for client session",
   408  						zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   409  						zap.String("username", entry.username),
   410  						zap.Uint64("clientSessionID", csid),
   411  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   412  						zap.Error(err),
   413  					)
   414  					natConn.Close()
   415  					clientSession.Close()
   416  					return
   417  				}
   418  
   419  				oldState := entry.state.Swap(natConn)
   420  				if oldState != nil {
   421  					natConn.Close()
   422  					clientSession.Close()
   423  					return
   424  				}
   425  
   426  				// No more early returns!
   427  				sendChClean = true
   428  
   429  				lnc.logger.Info("UDP session relay started",
   430  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   431  					zap.String("username", entry.username),
   432  					zap.Uint64("clientSessionID", csid),
   433  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   434  					zap.String("client", clientInfo.Name),
   435  				)
   436  
   437  				s.wg.Add(1)
   438  
   439  				go func() {
   440  					s.relayServerConnToNatConnGeneric(ctx, sessionUplinkGeneric{
   441  						csid:          csid,
   442  						clientName:    clientInfo.Name,
   443  						natConn:       natConn,
   444  						natConnSendCh: natConnSendCh,
   445  						natConnPacker: clientSession.Packer,
   446  						natTimeout:    lnc.natTimeout,
   447  						username:      entry.username,
   448  						logger:        lnc.logger,
   449  					})
   450  					natConn.Close()
   451  					clientSession.Close()
   452  					s.wg.Done()
   453  				}()
   454  
   455  				s.relayNatConnToServerConnGeneric(sessionDownlinkGeneric{
   456  					csid:               csid,
   457  					clientName:         clientInfo.Name,
   458  					clientAddrInfop:    clientAddrInfop,
   459  					clientAddrInfo:     &entry.clientAddrInfo,
   460  					natConn:            natConn,
   461  					natConnRecvBufSize: clientSession.MaxPacketSize,
   462  					natConnUnpacker:    clientSession.Unpacker,
   463  					serverConn:         lnc.serverConn,
   464  					serverConnPacker:   serverConnPacker,
   465  					username:           entry.username,
   466  					logger:             lnc.logger,
   467  				})
   468  			}()
   469  
   470  			if ce := lnc.logger.Check(zap.DebugLevel, "New UDP session"); ce != nil {
   471  				ce.Write(
   472  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   473  					zap.String("username", entry.username),
   474  					zap.Uint64("clientSessionID", csid),
   475  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   476  				)
   477  			}
   478  		}
   479  
   480  		select {
   481  		case entry.natConnSendCh <- queuedPacket:
   482  		default:
   483  			if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil {
   484  				ce.Write(
   485  					zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   486  					zap.String("username", entry.username),
   487  					zap.Uint64("clientSessionID", csid),
   488  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   489  				)
   490  			}
   491  
   492  			s.putQueuedPacket(queuedPacket)
   493  		}
   494  
   495  		s.server.Unlock()
   496  	}
   497  
   498  	lnc.logger.Info("Finished receiving from serverConn",
   499  		zap.Uint64("packetsReceived", packetsReceived),
   500  		zap.Uint64("payloadBytesReceived", payloadBytesReceived),
   501  	)
   502  }
   503  
   504  func (s *UDPSessionRelay) relayServerConnToNatConnGeneric(ctx context.Context, uplink sessionUplinkGeneric) {
   505  	var (
   506  		destAddrPort     netip.AddrPort
   507  		packetStart      int
   508  		packetLength     int
   509  		err              error
   510  		packetsSent      uint64
   511  		payloadBytesSent uint64
   512  	)
   513  
   514  	for queuedPacket := range uplink.natConnSendCh {
   515  		destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length)
   516  		if err != nil {
   517  			uplink.logger.Warn("Failed to pack packet",
   518  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   519  				zap.String("username", uplink.username),
   520  				zap.Uint64("clientSessionID", uplink.csid),
   521  				zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   522  				zap.String("client", uplink.clientName),
   523  				zap.Int("payloadLength", queuedPacket.length),
   524  				zap.Error(err),
   525  			)
   526  
   527  			s.putQueuedPacket(queuedPacket)
   528  			continue
   529  		}
   530  
   531  		_, err = uplink.natConn.WriteToUDPAddrPort(queuedPacket.buf[packetStart:packetStart+packetLength], destAddrPort)
   532  		if err != nil {
   533  			uplink.logger.Warn("Failed to write packet to natConn",
   534  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   535  				zap.String("username", uplink.username),
   536  				zap.Uint64("clientSessionID", uplink.csid),
   537  				zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   538  				zap.String("client", uplink.clientName),
   539  				zap.Stringer("writeDestAddress", destAddrPort),
   540  				zap.Int("packetLength", packetLength),
   541  				zap.Error(err),
   542  			)
   543  		}
   544  
   545  		err = uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout))
   546  		if err != nil {
   547  			uplink.logger.Warn("Failed to set read deadline on natConn",
   548  				zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
   549  				zap.String("username", uplink.username),
   550  				zap.Uint64("clientSessionID", uplink.csid),
   551  				zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   552  				zap.String("client", uplink.clientName),
   553  				zap.Stringer("writeDestAddress", destAddrPort),
   554  				zap.Duration("natTimeout", uplink.natTimeout),
   555  				zap.Error(err),
   556  			)
   557  		}
   558  
   559  		s.putQueuedPacket(queuedPacket)
   560  		packetsSent++
   561  		payloadBytesSent += uint64(queuedPacket.length)
   562  	}
   563  
   564  	uplink.logger.Info("Finished relay serverConn -> natConn",
   565  		zap.String("username", uplink.username),
   566  		zap.Uint64("clientSessionID", uplink.csid),
   567  		zap.String("client", uplink.clientName),
   568  		zap.Stringer("lastWriteDestAddress", destAddrPort),
   569  		zap.Uint64("packetsSent", packetsSent),
   570  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   571  	)
   572  
   573  	s.collector.CollectUDPSessionUplink(uplink.username, packetsSent, payloadBytesSent)
   574  }
   575  
   576  func (s *UDPSessionRelay) relayNatConnToServerConnGeneric(downlink sessionDownlinkGeneric) {
   577  	clientAddrInfop := downlink.clientAddrInfop
   578  	clientAddrPort := clientAddrInfop.addrPort
   579  	clientPktinfo := clientAddrInfop.pktinfo
   580  	maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr())
   581  
   582  	serverConnPackerInfo := downlink.serverConnPacker.ServerPackerInfo()
   583  	natConnUnpackerInfo := downlink.natConnUnpacker.ClientUnpackerInfo()
   584  	headroom := zerocopy.UDPRelayHeadroom(serverConnPackerInfo.Headroom, natConnUnpackerInfo.Headroom)
   585  
   586  	var (
   587  		packetsSent      uint64
   588  		payloadBytesSent uint64
   589  	)
   590  
   591  	packetBuf := make([]byte, headroom.Front+downlink.natConnRecvBufSize+headroom.Rear)
   592  	recvBuf := packetBuf[headroom.Front : headroom.Front+downlink.natConnRecvBufSize]
   593  
   594  	for {
   595  		n, _, flags, packetSourceAddrPort, err := downlink.natConn.ReadMsgUDPAddrPort(recvBuf, nil)
   596  		if err != nil {
   597  			if errors.Is(err, os.ErrDeadlineExceeded) {
   598  				break
   599  			}
   600  
   601  			downlink.logger.Warn("Failed to read packet from natConn",
   602  				zap.Stringer("clientAddress", clientAddrPort),
   603  				zap.String("username", downlink.username),
   604  				zap.Uint64("clientSessionID", downlink.csid),
   605  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   606  				zap.String("client", downlink.clientName),
   607  				zap.Int("packetLength", n),
   608  				zap.Error(err),
   609  			)
   610  			continue
   611  		}
   612  		err = conn.ParseFlagsForError(flags)
   613  		if err != nil {
   614  			downlink.logger.Warn("Failed to read packet from natConn",
   615  				zap.Stringer("clientAddress", clientAddrPort),
   616  				zap.String("username", downlink.username),
   617  				zap.Uint64("clientSessionID", downlink.csid),
   618  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   619  				zap.String("client", downlink.clientName),
   620  				zap.Int("packetLength", n),
   621  				zap.Error(err),
   622  			)
   623  			continue
   624  		}
   625  
   626  		payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, headroom.Front, n)
   627  		if err != nil {
   628  			downlink.logger.Warn("Failed to unpack packet",
   629  				zap.Stringer("clientAddress", clientAddrPort),
   630  				zap.String("username", downlink.username),
   631  				zap.Uint64("clientSessionID", downlink.csid),
   632  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   633  				zap.String("client", downlink.clientName),
   634  				zap.Int("packetLength", n),
   635  				zap.Error(err),
   636  			)
   637  			continue
   638  		}
   639  
   640  		if caip := downlink.clientAddrInfo.Load(); caip != clientAddrInfop {
   641  			clientAddrInfop = caip
   642  			clientAddrPort = caip.addrPort
   643  			clientPktinfo = caip.pktinfo
   644  			maxClientPacketSize = zerocopy.MaxPacketSizeForAddr(s.mtu, clientAddrPort.Addr())
   645  		}
   646  
   647  		packetStart, packetLength, err := downlink.serverConnPacker.PackInPlace(packetBuf, payloadSourceAddrPort, payloadStart, payloadLength, maxClientPacketSize)
   648  		if err != nil {
   649  			downlink.logger.Warn("Failed to pack packet",
   650  				zap.Stringer("clientAddress", clientAddrPort),
   651  				zap.String("username", downlink.username),
   652  				zap.Uint64("clientSessionID", downlink.csid),
   653  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   654  				zap.String("client", downlink.clientName),
   655  				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   656  				zap.Int("payloadLength", payloadLength),
   657  				zap.Int("maxClientPacketSize", maxClientPacketSize),
   658  				zap.Error(err),
   659  			)
   660  			continue
   661  		}
   662  
   663  		_, _, err = downlink.serverConn.WriteMsgUDPAddrPort(packetBuf[packetStart:packetStart+packetLength], clientPktinfo, clientAddrPort)
   664  		if err != nil {
   665  			downlink.logger.Warn("Failed to write packet to serverConn",
   666  				zap.Stringer("clientAddress", clientAddrPort),
   667  				zap.String("username", downlink.username),
   668  				zap.Uint64("clientSessionID", downlink.csid),
   669  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   670  				zap.String("client", downlink.clientName),
   671  				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   672  				zap.Int("packetLength", packetLength),
   673  				zap.Error(err),
   674  			)
   675  		}
   676  
   677  		packetsSent++
   678  		payloadBytesSent += uint64(payloadLength)
   679  	}
   680  
   681  	downlink.logger.Info("Finished relay serverConn <- natConn",
   682  		zap.Stringer("clientAddress", clientAddrPort),
   683  		zap.String("username", downlink.username),
   684  		zap.Uint64("clientSessionID", downlink.csid),
   685  		zap.Uint64("packetsSent", packetsSent),
   686  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   687  	)
   688  
   689  	s.collector.CollectUDPSessionDownlink(downlink.username, packetsSent, payloadBytesSent)
   690  }
   691  
   692  // getQueuedPacket retrieves a queued packet from the pool.
   693  func (s *UDPSessionRelay) getQueuedPacket() *sessionQueuedPacket {
   694  	return s.queuedPacketPool.Get().(*sessionQueuedPacket)
   695  }
   696  
   697  // putQueuedPacket puts the queued packet back into the pool.
   698  func (s *UDPSessionRelay) putQueuedPacket(queuedPacket *sessionQueuedPacket) {
   699  	s.queuedPacketPool.Put(queuedPacket)
   700  }
   701  
   702  // Stop implements the Service Stop method.
   703  func (s *UDPSessionRelay) Stop() error {
   704  	for i := range s.listeners {
   705  		lnc := &s.listeners[i]
   706  		if err := lnc.serverConn.SetReadDeadline(conn.ALongTimeAgo); err != nil {
   707  			lnc.logger.Warn("Failed to set read deadline on serverConn", zap.Error(err))
   708  		}
   709  	}
   710  
   711  	// Wait for serverConn receive goroutines to exit,
   712  	// so there won't be any new sessions added to the table.
   713  	s.mwg.Wait()
   714  
   715  	s.server.Lock()
   716  	for csid, entry := range s.table {
   717  		natConn := entry.state.Swap(entry.serverConn)
   718  		if natConn == nil {
   719  			continue
   720  		}
   721  
   722  		if err := natConn.SetReadDeadline(conn.ALongTimeAgo); err != nil {
   723  			entry.logger.Warn("Failed to set read deadline on natConn",
   724  				zap.Uint64("clientSessionID", csid),
   725  				zap.Error(err),
   726  			)
   727  		}
   728  	}
   729  	s.server.Unlock()
   730  
   731  	// Wait for all relay goroutines to exit before closing serverConn,
   732  	// so in-flight packets can be written out.
   733  	s.wg.Wait()
   734  
   735  	for i := range s.listeners {
   736  		lnc := &s.listeners[i]
   737  		if err := lnc.serverConn.Close(); err != nil {
   738  			lnc.logger.Warn("Failed to close serverConn", zap.Error(err))
   739  		}
   740  	}
   741  
   742  	return nil
   743  }