github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/stack/pending_packets.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    21  )
    22  
    23  const (
    24  	// maxPendingResolutions is the maximum number of pending link-address
    25  	// resolutions.
    26  	maxPendingResolutions          = 64
    27  	maxPendingPacketsPerResolution = 256
    28  )
    29  
    30  type pendingPacket struct {
    31  	routeInfo RouteInfo
    32  	pkt       PacketBufferPtr
    33  }
    34  
    35  // packetsPendingLinkResolution is a queue of packets pending link resolution.
    36  //
    37  // Once link resolution completes successfully, the packets will be written.
    38  type packetsPendingLinkResolution struct {
    39  	nic *nic
    40  
    41  	mu struct {
    42  		packetsPendingLinkResolutionMutex
    43  
    44  		// The packets to send once the resolver completes.
    45  		//
    46  		// The link resolution channel is used as the key for this map.
    47  		packets map[<-chan struct{}][]pendingPacket
    48  
    49  		// FIFO of channels used to cancel the oldest goroutine waiting for
    50  		// link-address resolution.
    51  		//
    52  		// cancelChans holds the same channels that are used as keys to packets.
    53  		cancelChans []<-chan struct{}
    54  	}
    55  }
    56  
    57  func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(pkt PacketBufferPtr) {
    58  	f.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
    59  
    60  	if ipEndpointStats, ok := f.nic.getNetworkEndpoint(pkt.NetworkProtocolNumber).Stats().(IPNetworkEndpointStats); ok {
    61  		ipEndpointStats.IPStats().OutgoingPacketErrors.Increment()
    62  	}
    63  }
    64  
    65  func (f *packetsPendingLinkResolution) init(nic *nic) {
    66  	f.mu.Lock()
    67  	defer f.mu.Unlock()
    68  	f.nic = nic
    69  	f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
    70  }
    71  
    72  // cancel drains all pending packet queues and release all packet
    73  // references.
    74  func (f *packetsPendingLinkResolution) cancel() {
    75  	f.mu.Lock()
    76  	defer f.mu.Unlock()
    77  	for ch, pendingPackets := range f.mu.packets {
    78  		for _, p := range pendingPackets {
    79  			p.pkt.DecRef()
    80  		}
    81  		delete(f.mu.packets, ch)
    82  	}
    83  	f.mu.cancelChans = nil
    84  }
    85  
    86  // dequeue any pending packets associated with ch.
    87  //
    88  // If err is nil, packets will be written and sent to the given remote link
    89  // address.
    90  func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, err tcpip.Error) {
    91  	f.mu.Lock()
    92  	packets, ok := f.mu.packets[ch]
    93  	delete(f.mu.packets, ch)
    94  
    95  	if ok {
    96  		for i, cancelChan := range f.mu.cancelChans {
    97  			if cancelChan == ch {
    98  				f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
    99  				break
   100  			}
   101  		}
   102  	}
   103  
   104  	f.mu.Unlock()
   105  
   106  	if ok {
   107  		f.dequeuePackets(packets, linkAddr, err)
   108  	}
   109  }
   110  
   111  // enqueue a packet to be sent once link resolution completes.
   112  //
   113  // If the maximum number of pending resolutions is reached, the packets
   114  // associated with the oldest link resolution will be dequeued as if they failed
   115  // link resolution.
   116  func (f *packetsPendingLinkResolution) enqueue(r *Route, pkt PacketBufferPtr) tcpip.Error {
   117  	f.mu.Lock()
   118  	// Make sure we attempt resolution while holding f's lock so that we avoid
   119  	// a race where link resolution completes before we enqueue the packets.
   120  	//
   121  	//   A @ T1: Call ResolvedFields (get link resolution channel)
   122  	//   B @ T2: Complete link resolution, dequeue pending packets
   123  	//   C @ T1: Enqueue packet that already completed link resolution (which will
   124  	//       never dequeue)
   125  	//
   126  	// To make sure B does not interleave with A and C, we make sure A and C are
   127  	// done while holding the lock.
   128  	routeInfo, ch, err := r.resolvedFields(nil)
   129  	switch err.(type) {
   130  	case nil:
   131  		// The route resolved immediately, so we don't need to wait for link
   132  		// resolution to send the packet.
   133  		f.mu.Unlock()
   134  		pkt.EgressRoute = routeInfo
   135  		return f.nic.writePacket(pkt)
   136  	case *tcpip.ErrWouldBlock:
   137  		// We need to wait for link resolution to complete.
   138  	default:
   139  		f.mu.Unlock()
   140  		return err
   141  	}
   142  
   143  	defer f.mu.Unlock()
   144  
   145  	packets, ok := f.mu.packets[ch]
   146  	packets = append(packets, pendingPacket{
   147  		routeInfo: routeInfo,
   148  		pkt:       pkt.IncRef(),
   149  	})
   150  
   151  	if len(packets) > maxPendingPacketsPerResolution {
   152  		f.incrementOutgoingPacketErrors(packets[0].pkt)
   153  		packets[0].pkt.DecRef()
   154  		packets[0] = pendingPacket{}
   155  		packets = packets[1:]
   156  
   157  		if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
   158  			panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
   159  		}
   160  	}
   161  
   162  	f.mu.packets[ch] = packets
   163  
   164  	if ok {
   165  		return nil
   166  	}
   167  
   168  	cancelledPackets := f.newCancelChannelLocked(ch)
   169  
   170  	if len(cancelledPackets) != 0 {
   171  		// Dequeue the pending packets in a new goroutine to not hold up the current
   172  		// goroutine as handing link resolution failures may be a costly operation.
   173  		go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, &tcpip.ErrAborted{})
   174  	}
   175  
   176  	return nil
   177  }
   178  
   179  // newCancelChannelLocked appends the link resolution channel to a FIFO. If the
   180  // maximum number of pending resolutions is reached, the oldest channel will be
   181  // removed and its associated pending packets will be returned.
   182  func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
   183  	f.mu.cancelChans = append(f.mu.cancelChans, newCH)
   184  	if len(f.mu.cancelChans) <= maxPendingResolutions {
   185  		return nil
   186  	}
   187  
   188  	ch := f.mu.cancelChans[0]
   189  	f.mu.cancelChans[0] = nil
   190  	f.mu.cancelChans = f.mu.cancelChans[1:]
   191  	if l := len(f.mu.cancelChans); l > maxPendingResolutions {
   192  		panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
   193  	}
   194  
   195  	packets, ok := f.mu.packets[ch]
   196  	if !ok {
   197  		panic("must have a packet queue for an uncancelled channel")
   198  	}
   199  	delete(f.mu.packets, ch)
   200  
   201  	return packets
   202  }
   203  
   204  func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, err tcpip.Error) {
   205  	for _, p := range packets {
   206  		if err == nil {
   207  			p.routeInfo.RemoteLinkAddress = linkAddr
   208  			p.pkt.EgressRoute = p.routeInfo
   209  			_ = f.nic.writePacket(p.pkt)
   210  		} else {
   211  			f.incrementOutgoingPacketErrors(p.pkt)
   212  
   213  			if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.pkt.NetworkProtocolNumber).(LinkResolvableNetworkEndpoint); ok {
   214  				linkResolvableEP.HandleLinkResolutionFailure(p.pkt)
   215  			}
   216  		}
   217  		p.pkt.DecRef()
   218  	}
   219  }