github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_nat.go (about)

     1  package service
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/database64128/shadowsocks-go/conn"
    15  	"github.com/database64128/shadowsocks-go/router"
    16  	"github.com/database64128/shadowsocks-go/stats"
    17  	"github.com/database64128/shadowsocks-go/zerocopy"
    18  	"go.uber.org/zap"
    19  )
    20  
// natQueuedPacket is the structure used by send channels to queue packets for sending.
//
// Instances are recycled through UDPNATRelay.queuedPacketPool; once a packet has
// been put back into the pool, none of its fields may be read or written.
type natQueuedPacket struct {
	// buf is the pooled packet buffer (packetBufSize bytes, allocated by the pool's New).
	buf        []byte
	// start is the offset of the unpacked payload within buf, as returned by UnpackInPlace.
	start      int
	// length is the length of the unpacked payload starting at start.
	length     int
	// targetAddr is the destination address decoded from the client's packet.
	targetAddr conn.Addr
}
    28  
// natEntry is an entry in the NAT table.
type natEntry struct {
	// state synchronizes session initialization and shutdown.
	//
	//  - Swap the natConn in to signal initialization completion.
	//  - Swap the serverConn in to signal shutdown.
	//
	// Callers must check the swapped-out value to determine the next action.
	//
	//  - During initialization, if the swapped-out value is non-nil,
	//    initialization must not proceed.
	//  - During shutdown, if the swapped-out value is nil, proceed to the next entry.
	state              atomic.Pointer[net.UDPConn]
	// clientPktinfo holds the latest pktinfo control message received from the
	// client. The receive loop stores it and the downlink relay loads it when
	// writing replies.
	clientPktinfo      atomic.Pointer[[]byte]
	// clientPktinfoCache is the receive loop's copy of the last stored pktinfo,
	// compared against each new control message to detect changes. It is only
	// accessed while holding UDPNATRelay.mu.
	clientPktinfoCache []byte
	// natConnSendCh is the send side of the session's packet queue, consumed by
	// the uplink relay.
	natConnSendCh      chan<- *natQueuedPacket
	// serverConn is the listener socket the client talks to. It is also the
	// value swapped into state to signal shutdown.
	serverConn         *net.UDPConn
	// serverConnUnpacker unpacks this client's packets. The session's
	// serverConn packer is derived from it via NewPacker.
	serverConnUnpacker zerocopy.ServerUnpacker
	// logger is the listener's logger, used by Stop for this entry.
	logger             *zap.Logger
}
    49  
// natUplinkGeneric is used for passing information about relay uplink to the relay goroutine.
//
// The uplink direction carries packets queued by the receive loop
// (client -> serverConn -> natConn -> remote).
type natUplinkGeneric struct {
	// clientName is the name of the UDP client chosen by the router, for logging.
	clientName     string
	// clientAddrPort is the address of the client that owns the session, for logging.
	clientAddrPort netip.AddrPort
	// natConn is the session's outbound socket.
	natConn        *net.UDPConn
	// natConnSendCh is the receive side of the session's packet queue; the
	// relay runs until it is closed.
	natConnSendCh  <-chan *natQueuedPacket
	// natConnPacker packs queued payloads for the client session before they are written to natConn.
	natConnPacker  zerocopy.ClientPacker
	// natTimeout is how far natConn's read deadline is pushed forward after each packet.
	natTimeout     time.Duration
	logger         *zap.Logger
}
    60  
// natDownlinkGeneric is used for passing information about relay downlink to the relay goroutine.
//
// The downlink direction carries replies back to the client
// (remote -> natConn -> serverConn -> client).
type natDownlinkGeneric struct {
	// clientName is the name of the UDP client chosen by the router, for logging.
	clientName         string
	// clientAddrPort is the client address that replies are written to.
	clientAddrPort     netip.AddrPort
	// clientPktinfo points at the NAT entry's latest client pktinfo control
	// message, attached to every reply written to serverConn.
	clientPktinfo      *atomic.Pointer[[]byte]
	// natConn is the session's outbound socket that replies are read from.
	natConn            *net.UDPConn
	// natConnRecvBufSize is the size of the receive window used when reading natConn.
	natConnRecvBufSize int
	// natConnUnpacker unpacks replies received on natConn.
	natConnUnpacker    zerocopy.ClientUnpacker
	// serverConn is the listener socket used to write replies to the client.
	serverConn         *net.UDPConn
	// serverConnPacker packs replies for the client before they are written to serverConn.
	serverConnPacker   zerocopy.ServerPacker
	logger             *zap.Logger
}
    73  
// UDPNATRelay is an address-based UDP relay service.
//
// Incoming UDP packets are dispatched to NAT sessions based on the source address and port.
type UDPNATRelay struct {
	// serverName identifies the server in String and in log fields.
	serverName             string
	// serverIndex is passed to the router when selecting a client for a new session.
	serverIndex            int
	// mtu is used to bound the size of packets written back to clients.
	mtu                    int
	// packetBufFrontHeadroom and packetBufRecvSize describe the receive window
	// inside each pooled packet buffer: reads land at
	// buf[packetBufFrontHeadroom : packetBufFrontHeadroom+packetBufRecvSize].
	packetBufFrontHeadroom int
	packetBufRecvSize      int
	listeners              []udpRelayServerConn
	// server creates a fresh unpacker for each new client.
	server                 zerocopy.UDPNATServer
	// collector receives per-session uplink and downlink packet/byte counts.
	collector              stats.Collector
	// router selects the UDP client used for each new NAT session.
	router                 *router.Router
	logger                 *zap.Logger
	// queuedPacketPool recycles *natQueuedPacket values and their buffers.
	queuedPacketPool       sync.Pool
	// mu guards table.
	mu                     sync.Mutex
	// wg tracks NAT session goroutines (setup + downlink, and uplink).
	wg                     sync.WaitGroup
	// mwg tracks the per-listener receive loops.
	mwg                    sync.WaitGroup
	// table maps a client's source address and port to its NAT session.
	table                  map[netip.AddrPort]*natEntry
}
    94  
    95  func NewUDPNATRelay(
    96  	serverName string,
    97  	serverIndex, mtu, packetBufFrontHeadroom, packetBufRecvSize, packetBufSize int,
    98  	listeners []udpRelayServerConn,
    99  	server zerocopy.UDPNATServer,
   100  	collector stats.Collector,
   101  	router *router.Router,
   102  	logger *zap.Logger,
   103  ) *UDPNATRelay {
   104  	return &UDPNATRelay{
   105  		serverName:             serverName,
   106  		serverIndex:            serverIndex,
   107  		mtu:                    mtu,
   108  		packetBufFrontHeadroom: packetBufFrontHeadroom,
   109  		packetBufRecvSize:      packetBufRecvSize,
   110  		listeners:              listeners,
   111  		server:                 server,
   112  		collector:              collector,
   113  		router:                 router,
   114  		logger:                 logger,
   115  		queuedPacketPool: sync.Pool{
   116  			New: func() any {
   117  				return &natQueuedPacket{
   118  					buf: make([]byte, packetBufSize),
   119  				}
   120  			},
   121  		},
   122  		table: make(map[netip.AddrPort]*natEntry),
   123  	}
   124  }
   125  
   126  // String implements the Service String method.
   127  func (s *UDPNATRelay) String() string {
   128  	return "UDP NAT relay service for " + s.serverName
   129  }
   130  
   131  // Start implements the Service Start method.
   132  func (s *UDPNATRelay) Start(ctx context.Context) error {
   133  	for i := range s.listeners {
   134  		if err := s.start(ctx, i, &s.listeners[i]); err != nil {
   135  			return err
   136  		}
   137  	}
   138  	return nil
   139  }
   140  
   141  func (s *UDPNATRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) {
   142  	lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address)
   143  	if err != nil {
   144  		return
   145  	}
   146  	lnc.address = lnc.serverConn.LocalAddr().String()
   147  	lnc.logger = s.logger.With(
   148  		zap.String("server", s.serverName),
   149  		zap.Int("listener", index),
   150  		zap.String("listenAddress", lnc.address),
   151  	)
   152  
   153  	s.mwg.Add(1)
   154  
   155  	go func() {
   156  		s.recvFromServerConnGeneric(ctx, lnc)
   157  		s.mwg.Done()
   158  	}()
   159  
   160  	lnc.logger.Info("Started UDP NAT relay service listener")
   161  	return
   162  }
   163  
   164  func (s *UDPNATRelay) recvFromServerConnGeneric(ctx context.Context, lnc *udpRelayServerConn) {
   165  	cmsgBuf := make([]byte, conn.SocketControlMessageBufferSize)
   166  
   167  	var (
   168  		packetsReceived      uint64
   169  		payloadBytesReceived uint64
   170  	)
   171  
   172  	for {
   173  		queuedPacket := s.getQueuedPacket()
   174  		packetBuf := queuedPacket.buf
   175  		recvBuf := packetBuf[s.packetBufFrontHeadroom : s.packetBufFrontHeadroom+s.packetBufRecvSize]
   176  
   177  		n, cmsgn, flags, clientAddrPort, err := lnc.serverConn.ReadMsgUDPAddrPort(recvBuf, cmsgBuf)
   178  		if err != nil {
   179  			if errors.Is(err, os.ErrDeadlineExceeded) {
   180  				s.putQueuedPacket(queuedPacket)
   181  				break
   182  			}
   183  
   184  			lnc.logger.Warn("Failed to read packet from serverConn",
   185  				zap.Stringer("clientAddress", clientAddrPort),
   186  				zap.Int("packetLength", n),
   187  				zap.Error(err),
   188  			)
   189  
   190  			s.putQueuedPacket(queuedPacket)
   191  			continue
   192  		}
   193  		err = conn.ParseFlagsForError(flags)
   194  		if err != nil {
   195  			lnc.logger.Warn("Failed to read packet from serverConn",
   196  				zap.Stringer("clientAddress", clientAddrPort),
   197  				zap.Int("packetLength", n),
   198  				zap.Error(err),
   199  			)
   200  
   201  			s.putQueuedPacket(queuedPacket)
   202  			continue
   203  		}
   204  
   205  		s.mu.Lock()
   206  
   207  		entry, ok := s.table[clientAddrPort]
   208  		if !ok {
   209  			entry = &natEntry{
   210  				serverConn: lnc.serverConn,
   211  				logger:     lnc.logger,
   212  			}
   213  
   214  			entry.serverConnUnpacker, err = s.server.NewUnpacker()
   215  			if err != nil {
   216  				lnc.logger.Warn("Failed to create unpacker for serverConn",
   217  					zap.Stringer("clientAddress", clientAddrPort),
   218  					zap.Error(err),
   219  				)
   220  
   221  				s.putQueuedPacket(queuedPacket)
   222  				s.mu.Unlock()
   223  				continue
   224  			}
   225  		}
   226  
   227  		queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length, err = entry.serverConnUnpacker.UnpackInPlace(packetBuf, clientAddrPort, s.packetBufFrontHeadroom, n)
   228  		if err != nil {
   229  			lnc.logger.Warn("Failed to unpack packet from serverConn",
   230  				zap.Stringer("clientAddress", clientAddrPort),
   231  				zap.Int("packetLength", n),
   232  				zap.Error(err),
   233  			)
   234  
   235  			s.putQueuedPacket(queuedPacket)
   236  			s.mu.Unlock()
   237  			continue
   238  		}
   239  
   240  		packetsReceived++
   241  		payloadBytesReceived += uint64(queuedPacket.length)
   242  
   243  		cmsg := cmsgBuf[:cmsgn]
   244  
   245  		if !bytes.Equal(entry.clientPktinfoCache, cmsg) {
   246  			clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg)
   247  			if err != nil {
   248  				lnc.logger.Warn("Failed to parse pktinfo control message from serverConn",
   249  					zap.Stringer("clientAddress", clientAddrPort),
   250  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   251  					zap.Error(err),
   252  				)
   253  
   254  				s.putQueuedPacket(queuedPacket)
   255  				s.mu.Unlock()
   256  				continue
   257  			}
   258  
   259  			clientPktinfoCache := make([]byte, len(cmsg))
   260  			copy(clientPktinfoCache, cmsg)
   261  			entry.clientPktinfo.Store(&clientPktinfoCache)
   262  			entry.clientPktinfoCache = clientPktinfoCache
   263  
   264  			if ce := lnc.logger.Check(zap.DebugLevel, "Updated client pktinfo"); ce != nil {
   265  				ce.Write(
   266  					zap.String("server", s.serverName),
   267  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   268  					zap.Stringer("clientPktinfoAddr", clientPktinfoAddr),
   269  					zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex),
   270  				)
   271  			}
   272  		}
   273  
   274  		if !ok {
   275  			natConnSendCh := make(chan *natQueuedPacket, lnc.sendChannelCapacity)
   276  			entry.natConnSendCh = natConnSendCh
   277  			s.table[clientAddrPort] = entry
   278  			s.wg.Add(1)
   279  
   280  			go func() {
   281  				var sendChClean bool
   282  
   283  				defer func() {
   284  					s.mu.Lock()
   285  					close(natConnSendCh)
   286  					delete(s.table, clientAddrPort)
   287  					s.mu.Unlock()
   288  
   289  					if !sendChClean {
   290  						for queuedPacket := range natConnSendCh {
   291  							s.putQueuedPacket(queuedPacket)
   292  						}
   293  					}
   294  
   295  					s.wg.Done()
   296  				}()
   297  
   298  				c, err := s.router.GetUDPClient(ctx, router.RequestInfo{
   299  					ServerIndex:    s.serverIndex,
   300  					SourceAddrPort: clientAddrPort,
   301  					TargetAddr:     queuedPacket.targetAddr,
   302  				})
   303  				if err != nil {
   304  					lnc.logger.Warn("Failed to get UDP client for new NAT session",
   305  						zap.Stringer("clientAddress", clientAddrPort),
   306  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   307  						zap.Error(err),
   308  					)
   309  					return
   310  				}
   311  
   312  				clientInfo, clientSession, err := c.NewSession(ctx)
   313  				if err != nil {
   314  					lnc.logger.Warn("Failed to create new UDP client session",
   315  						zap.Stringer("clientAddress", clientAddrPort),
   316  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   317  						zap.String("client", clientInfo.Name),
   318  						zap.Error(err),
   319  					)
   320  					return
   321  				}
   322  
   323  				natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
   324  				if err != nil {
   325  					lnc.logger.Warn("Failed to create UDP socket for new NAT session",
   326  						zap.Stringer("clientAddress", clientAddrPort),
   327  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   328  						zap.String("client", clientInfo.Name),
   329  						zap.Error(err),
   330  					)
   331  					clientSession.Close()
   332  					return
   333  				}
   334  
   335  				err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout))
   336  				if err != nil {
   337  					lnc.logger.Warn("Failed to set read deadline on natConn",
   338  						zap.Stringer("clientAddress", clientAddrPort),
   339  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   340  						zap.String("client", clientInfo.Name),
   341  						zap.Duration("natTimeout", lnc.natTimeout),
   342  						zap.Error(err),
   343  					)
   344  					natConn.Close()
   345  					clientSession.Close()
   346  					return
   347  				}
   348  
   349  				serverConnPacker, err := entry.serverConnUnpacker.NewPacker()
   350  				if err != nil {
   351  					lnc.logger.Warn("Failed to create packer for serverConn",
   352  						zap.Stringer("clientAddress", clientAddrPort),
   353  						zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   354  						zap.Error(err),
   355  					)
   356  					natConn.Close()
   357  					clientSession.Close()
   358  					return
   359  				}
   360  
   361  				oldState := entry.state.Swap(natConn)
   362  				if oldState != nil {
   363  					natConn.Close()
   364  					clientSession.Close()
   365  					return
   366  				}
   367  
   368  				// No more early returns!
   369  				sendChClean = true
   370  
   371  				lnc.logger.Info("UDP NAT relay started",
   372  					zap.Stringer("clientAddress", clientAddrPort),
   373  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   374  					zap.String("client", clientInfo.Name),
   375  				)
   376  
   377  				s.wg.Add(1)
   378  
   379  				go func() {
   380  					s.relayServerConnToNatConnGeneric(ctx, natUplinkGeneric{
   381  						clientName:     clientInfo.Name,
   382  						clientAddrPort: clientAddrPort,
   383  						natConn:        natConn,
   384  						natConnSendCh:  natConnSendCh,
   385  						natConnPacker:  clientSession.Packer,
   386  						natTimeout:     lnc.natTimeout,
   387  						logger:         lnc.logger,
   388  					})
   389  					natConn.Close()
   390  					clientSession.Close()
   391  					s.wg.Done()
   392  				}()
   393  
   394  				s.relayNatConnToServerConnGeneric(natDownlinkGeneric{
   395  					clientName:         clientInfo.Name,
   396  					clientAddrPort:     clientAddrPort,
   397  					clientPktinfo:      &entry.clientPktinfo,
   398  					natConn:            natConn,
   399  					natConnRecvBufSize: clientSession.MaxPacketSize,
   400  					natConnUnpacker:    clientSession.Unpacker,
   401  					serverConn:         lnc.serverConn,
   402  					serverConnPacker:   serverConnPacker,
   403  					logger:             lnc.logger,
   404  				})
   405  			}()
   406  
   407  			if ce := lnc.logger.Check(zap.DebugLevel, "New UDP NAT session"); ce != nil {
   408  				ce.Write(
   409  					zap.Stringer("clientAddress", clientAddrPort),
   410  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   411  				)
   412  			}
   413  		}
   414  
   415  		select {
   416  		case entry.natConnSendCh <- queuedPacket:
   417  		default:
   418  			if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil {
   419  				ce.Write(
   420  					zap.Stringer("clientAddress", clientAddrPort),
   421  					zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   422  				)
   423  			}
   424  
   425  			s.putQueuedPacket(queuedPacket)
   426  		}
   427  
   428  		s.mu.Unlock()
   429  	}
   430  
   431  	lnc.logger.Info("Finished receiving from serverConn",
   432  		zap.Uint64("packetsReceived", packetsReceived),
   433  		zap.Uint64("payloadBytesReceived", payloadBytesReceived),
   434  	)
   435  }
   436  
   437  func (s *UDPNATRelay) relayServerConnToNatConnGeneric(ctx context.Context, uplink natUplinkGeneric) {
   438  	var (
   439  		destAddrPort     netip.AddrPort
   440  		packetStart      int
   441  		packetLength     int
   442  		err              error
   443  		packetsSent      uint64
   444  		payloadBytesSent uint64
   445  	)
   446  
   447  	for queuedPacket := range uplink.natConnSendCh {
   448  		destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length)
   449  		if err != nil {
   450  			uplink.logger.Warn("Failed to pack packet for natConn",
   451  				zap.Stringer("clientAddress", uplink.clientAddrPort),
   452  				zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   453  				zap.String("client", uplink.clientName),
   454  				zap.Int("payloadLength", queuedPacket.length),
   455  				zap.Error(err),
   456  			)
   457  
   458  			s.putQueuedPacket(queuedPacket)
   459  			continue
   460  		}
   461  
   462  		_, err = uplink.natConn.WriteToUDPAddrPort(queuedPacket.buf[packetStart:packetStart+packetLength], destAddrPort)
   463  		if err != nil {
   464  			uplink.logger.Warn("Failed to write packet to natConn",
   465  				zap.Stringer("clientAddress", uplink.clientAddrPort),
   466  				zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   467  				zap.String("client", uplink.clientName),
   468  				zap.Stringer("writeDestAddress", destAddrPort),
   469  				zap.Int("packetLength", packetLength),
   470  				zap.Error(err),
   471  			)
   472  		}
   473  
   474  		err = uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout))
   475  		if err != nil {
   476  			uplink.logger.Warn("Failed to set read deadline on natConn",
   477  				zap.Stringer("clientAddress", uplink.clientAddrPort),
   478  				zap.Stringer("targetAddress", &queuedPacket.targetAddr),
   479  				zap.String("client", uplink.clientName),
   480  				zap.Stringer("writeDestAddress", destAddrPort),
   481  				zap.Duration("natTimeout", uplink.natTimeout),
   482  				zap.Error(err),
   483  			)
   484  		}
   485  
   486  		s.putQueuedPacket(queuedPacket)
   487  		packetsSent++
   488  		payloadBytesSent += uint64(queuedPacket.length)
   489  	}
   490  
   491  	uplink.logger.Info("Finished relay serverConn -> natConn",
   492  		zap.Stringer("clientAddress", uplink.clientAddrPort),
   493  		zap.String("client", uplink.clientName),
   494  		zap.Stringer("lastWriteDestAddress", destAddrPort),
   495  		zap.Uint64("packetsSent", packetsSent),
   496  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   497  	)
   498  
   499  	s.collector.CollectUDPSessionUplink("", packetsSent, payloadBytesSent)
   500  }
   501  
   502  func (s *UDPNATRelay) relayNatConnToServerConnGeneric(downlink natDownlinkGeneric) {
   503  	maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, downlink.clientAddrPort.Addr())
   504  
   505  	serverConnPackerInfo := downlink.serverConnPacker.ServerPackerInfo()
   506  	natConnUnpackerInfo := downlink.natConnUnpacker.ClientUnpackerInfo()
   507  	headroom := zerocopy.UDPRelayHeadroom(serverConnPackerInfo.Headroom, natConnUnpackerInfo.Headroom)
   508  
   509  	var (
   510  		clientPktinfo    []byte
   511  		clientPktinfop   *[]byte
   512  		packetsSent      uint64
   513  		payloadBytesSent uint64
   514  	)
   515  
   516  	packetBuf := make([]byte, headroom.Front+downlink.natConnRecvBufSize+headroom.Rear)
   517  	recvBuf := packetBuf[headroom.Front : headroom.Front+downlink.natConnRecvBufSize]
   518  
   519  	for {
   520  		n, _, flags, packetSourceAddrPort, err := downlink.natConn.ReadMsgUDPAddrPort(recvBuf, nil)
   521  		if err != nil {
   522  			if errors.Is(err, os.ErrDeadlineExceeded) {
   523  				break
   524  			}
   525  
   526  			downlink.logger.Warn("Failed to read packet from natConn",
   527  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   528  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   529  				zap.String("client", downlink.clientName),
   530  				zap.Int("packetLength", n),
   531  				zap.Error(err),
   532  			)
   533  			continue
   534  		}
   535  		err = conn.ParseFlagsForError(flags)
   536  		if err != nil {
   537  			downlink.logger.Warn("Failed to read packet from natConn",
   538  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   539  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   540  				zap.String("client", downlink.clientName),
   541  				zap.Int("packetLength", n),
   542  				zap.Error(err),
   543  			)
   544  			continue
   545  		}
   546  
   547  		payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, headroom.Front, n)
   548  		if err != nil {
   549  			downlink.logger.Warn("Failed to unpack packet from natConn",
   550  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   551  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   552  				zap.String("client", downlink.clientName),
   553  				zap.Int("packetLength", n),
   554  				zap.Error(err),
   555  			)
   556  			continue
   557  		}
   558  
   559  		packetStart, packetLength, err := downlink.serverConnPacker.PackInPlace(packetBuf, payloadSourceAddrPort, payloadStart, payloadLength, maxClientPacketSize)
   560  		if err != nil {
   561  			downlink.logger.Warn("Failed to pack packet for serverConn",
   562  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   563  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   564  				zap.String("client", downlink.clientName),
   565  				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   566  				zap.Int("payloadLength", payloadLength),
   567  				zap.Int("maxClientPacketSize", maxClientPacketSize),
   568  				zap.Error(err),
   569  			)
   570  			continue
   571  		}
   572  
   573  		if cpp := downlink.clientPktinfo.Load(); cpp != clientPktinfop {
   574  			clientPktinfo = *cpp
   575  			clientPktinfop = cpp
   576  		}
   577  
   578  		_, _, err = downlink.serverConn.WriteMsgUDPAddrPort(packetBuf[packetStart:packetStart+packetLength], clientPktinfo, downlink.clientAddrPort)
   579  		if err != nil {
   580  			downlink.logger.Warn("Failed to write packet to serverConn",
   581  				zap.Stringer("clientAddress", downlink.clientAddrPort),
   582  				zap.Stringer("packetSourceAddress", packetSourceAddrPort),
   583  				zap.String("client", downlink.clientName),
   584  				zap.Stringer("payloadSourceAddress", payloadSourceAddrPort),
   585  				zap.Int("packetLength", packetLength),
   586  				zap.Error(err),
   587  			)
   588  		}
   589  
   590  		packetsSent++
   591  		payloadBytesSent += uint64(payloadLength)
   592  	}
   593  
   594  	downlink.logger.Info("Finished relay serverConn <- natConn",
   595  		zap.Stringer("clientAddress", downlink.clientAddrPort),
   596  		zap.String("client", downlink.clientName),
   597  		zap.Uint64("packetsSent", packetsSent),
   598  		zap.Uint64("payloadBytesSent", payloadBytesSent),
   599  	)
   600  
   601  	s.collector.CollectUDPSessionDownlink("", packetsSent, payloadBytesSent)
   602  }
   603  
   604  // getQueuedPacket retrieves a queued packet from the pool.
   605  func (s *UDPNATRelay) getQueuedPacket() *natQueuedPacket {
   606  	return s.queuedPacketPool.Get().(*natQueuedPacket)
   607  }
   608  
   609  // putQueuedPacket puts the queued packet back into the pool.
   610  func (s *UDPNATRelay) putQueuedPacket(queuedPacket *natQueuedPacket) {
   611  	s.queuedPacketPool.Put(queuedPacket)
   612  }
   613  
// Stop implements the Service Stop method.
//
// Shutdown proceeds in a fixed order:
//  1. Set a past read deadline on every serverConn so the receive loops exit,
//     then wait for them (s.mwg). After that, no new sessions can be added.
//  2. For every table entry, swap serverConn into entry.state.
//     - If a natConn was swapped out, the session is running; a past read
//       deadline on natConn ends its downlink relay.
//     - If nil was swapped out, the session is still initializing; it will see
//       the non-nil state and abort on its own.
//  3. Wait for all session goroutines (s.wg), then close every serverConn.
//
// Errors from setting deadlines or closing sockets are logged, not returned;
// Stop always returns nil.
func (s *UDPNATRelay) Stop() error {
	for i := range s.listeners {
		lnc := &s.listeners[i]
		if err := lnc.serverConn.SetReadDeadline(conn.ALongTimeAgo); err != nil {
			lnc.logger.Warn("Failed to set read deadline on serverConn", zap.Error(err))
		}
	}

	// Wait for serverConn receive goroutines to exit,
	// so there won't be any new sessions added to the table.
	s.mwg.Wait()

	s.mu.Lock()
	for clientAddrPort, entry := range s.table {
		// Mark the entry as shutting down. A nil old value means the session has
		// not finished initializing and will abort itself on its own Swap.
		natConn := entry.state.Swap(entry.serverConn)
		if natConn == nil {
			continue
		}

		// Force the downlink relay's pending read to fail with a deadline error.
		if err := natConn.SetReadDeadline(conn.ALongTimeAgo); err != nil {
			entry.logger.Warn("Failed to set read deadline on natConn",
				zap.Stringer("clientAddress", clientAddrPort),
				zap.Error(err),
			)
		}
	}
	s.mu.Unlock()

	// Wait for all relay goroutines to exit before closing serverConn,
	// so in-flight packets can be written out.
	s.wg.Wait()

	for i := range s.listeners {
		lnc := &s.listeners[i]
		if err := lnc.serverConn.Close(); err != nil {
			lnc.logger.Warn("Failed to close serverConn", zap.Error(err))
		}
	}

	return nil
}