github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/stack/pending_packets.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/SagerNet/gvisor/pkg/sync"
    21  	"github.com/SagerNet/gvisor/pkg/tcpip"
    22  )
    23  
const (
	// maxPendingResolutions is the maximum number of pending link-address
	// resolutions. When a new resolution would exceed it, the oldest pending
	// resolution is evicted (see newCancelChannelLocked).
	//
	// maxPendingPacketsPerResolution is the maximum number of queued entries
	// held for any single pending resolution. When enqueue would exceed it,
	// the oldest queued entry for that resolution is dropped and counted as an
	// outgoing packet error.
	maxPendingResolutions          = 64
	maxPendingPacketsPerResolution = 256
)
    30  
// pendingPacketBuffer is a pending packet buffer.
//
// It is implemented in this file by *PacketBuffer and *PacketBufferList so
// that either can be queued while link resolution is in progress.
//
// TODO(github.com/SagerNet/issue/5331): Drop this when we drop WritePacket and only use
// WritePackets so we can use a PacketBufferList everywhere.
type pendingPacketBuffer interface {
	// len returns the number of packets this buffer accounts for. The value
	// is used when updating packet statistics and as the count reported back
	// by enqueue.
	len() int
}
    38  
// len implements pendingPacketBuffer.
//
// A single PacketBuffer always counts as exactly one packet.
func (*PacketBuffer) len() int {
	return 1
}
    42  
// len implements pendingPacketBuffer by reporting the value of the list's
// exported Len method.
func (p *PacketBufferList) len() int {
	return p.Len()
}
    46  
// pendingPacket is one entry queued while link resolution is in progress,
// together with everything needed to write it out afterwards.
//
// routeInfo is the route information returned by the route when the entry was
// enqueued; its RemoteLinkAddress is overwritten with the resolved address
// before the entry is written. proto is the network protocol number passed to
// the NIC when writing and used to look up the network endpoint for statistics
// and failure handling. pkt is the queued buffer, either a *PacketBuffer or a
// *PacketBufferList.
type pendingPacket struct {
	routeInfo RouteInfo
	proto     tcpip.NetworkProtocolNumber
	pkt       pendingPacketBuffer
}
    52  
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
// If resolution fails, or a pending resolution is evicted to make room for a
// newer one, the queued packets are counted as outgoing packet errors instead.
type packetsPendingLinkResolution struct {
	// nic is the NIC that queued packets are written through and whose stack
	// and network endpoints receive the error statistics. It is set by init.
	nic *nic

	// mu groups the mutex with the fields that are accessed while holding it.
	mu struct {
		sync.Mutex

		// The packets to send once the resolver completes.
		//
		// The link resolution channel is used as the key for this map.
		packets map[<-chan struct{}][]pendingPacket

		// FIFO of channels used to cancel the oldest goroutine waiting for
		// link-address resolution.
		//
		// cancelChans holds the same channels that are used as keys to packets.
		cancelChans []<-chan struct{}
	}
}
    74  
    75  func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
    76  	n := uint64(pkt.len())
    77  	f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n)
    78  
    79  	if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
    80  		ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n)
    81  	}
    82  }
    83  
    84  func (f *packetsPendingLinkResolution) init(nic *nic) {
    85  	f.mu.Lock()
    86  	defer f.mu.Unlock()
    87  	f.nic = nic
    88  	f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
    89  }
    90  
    91  // dequeue any pending packets associated with ch.
    92  //
    93  // If err is nil, packets will be written and sent to the given remote link
    94  // address.
    95  func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, err tcpip.Error) {
    96  	f.mu.Lock()
    97  	packets, ok := f.mu.packets[ch]
    98  	delete(f.mu.packets, ch)
    99  
   100  	if ok {
   101  		for i, cancelChan := range f.mu.cancelChans {
   102  			if cancelChan == ch {
   103  				f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
   104  				break
   105  			}
   106  		}
   107  	}
   108  
   109  	f.mu.Unlock()
   110  
   111  	if ok {
   112  		f.dequeuePackets(packets, linkAddr, err)
   113  	}
   114  }
   115  
   116  // enqueue a packet to be sent once link resolution completes.
   117  //
   118  // If the maximum number of pending resolutions is reached, the packets
   119  // associated with the oldest link resolution will be dequeued as if they failed
   120  // link resolution.
   121  func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
   122  	f.mu.Lock()
   123  	// Make sure we attempt resolution while holding f's lock so that we avoid
   124  	// a race where link resolution completes before we enqueue the packets.
   125  	//
   126  	//   A @ T1: Call ResolvedFields (get link resolution channel)
   127  	//   B @ T2: Complete link resolution, dequeue pending packets
   128  	//   C @ T1: Enqueue packet that already completed link resolution (which will
   129  	//       never dequeue)
   130  	//
   131  	// To make sure B does not interleave with A and C, we make sure A and C are
   132  	// done while holding the lock.
   133  	routeInfo, ch, err := r.resolvedFields(nil)
   134  	switch err.(type) {
   135  	case nil:
   136  		// The route resolved immediately, so we don't need to wait for link
   137  		// resolution to send the packet.
   138  		f.mu.Unlock()
   139  		return f.nic.writePacketBuffer(routeInfo, proto, pkt)
   140  	case *tcpip.ErrWouldBlock:
   141  		// We need to wait for link resolution to complete.
   142  	default:
   143  		f.mu.Unlock()
   144  		return 0, err
   145  	}
   146  
   147  	defer f.mu.Unlock()
   148  
   149  	packets, ok := f.mu.packets[ch]
   150  	packets = append(packets, pendingPacket{
   151  		routeInfo: routeInfo,
   152  		proto:     proto,
   153  		pkt:       pkt,
   154  	})
   155  
   156  	if len(packets) > maxPendingPacketsPerResolution {
   157  		f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt)
   158  		packets[0] = pendingPacket{}
   159  		packets = packets[1:]
   160  
   161  		if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
   162  			panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
   163  		}
   164  	}
   165  
   166  	f.mu.packets[ch] = packets
   167  
   168  	if ok {
   169  		return pkt.len(), nil
   170  	}
   171  
   172  	cancelledPackets := f.newCancelChannelLocked(ch)
   173  
   174  	if len(cancelledPackets) != 0 {
   175  		// Dequeue the pending packets in a new goroutine to not hold up the current
   176  		// goroutine as handing link resolution failures may be a costly operation.
   177  		go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, &tcpip.ErrAborted{})
   178  	}
   179  
   180  	return pkt.len(), nil
   181  }
   182  
   183  // newCancelChannelLocked appends the link resolution channel to a FIFO. If the
   184  // maximum number of pending resolutions is reached, the oldest channel will be
   185  // removed and its associated pending packets will be returned.
   186  func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
   187  	f.mu.cancelChans = append(f.mu.cancelChans, newCH)
   188  	if len(f.mu.cancelChans) <= maxPendingResolutions {
   189  		return nil
   190  	}
   191  
   192  	ch := f.mu.cancelChans[0]
   193  	f.mu.cancelChans[0] = nil
   194  	f.mu.cancelChans = f.mu.cancelChans[1:]
   195  	if l := len(f.mu.cancelChans); l > maxPendingResolutions {
   196  		panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
   197  	}
   198  
   199  	packets, ok := f.mu.packets[ch]
   200  	if !ok {
   201  		panic("must have a packet queue for an uncancelled channel")
   202  	}
   203  	delete(f.mu.packets, ch)
   204  
   205  	return packets
   206  }
   207  
   208  func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, err tcpip.Error) {
   209  	for _, p := range packets {
   210  		if err == nil {
   211  			p.routeInfo.RemoteLinkAddress = linkAddr
   212  			_, _ = f.nic.writePacketBuffer(p.routeInfo, p.proto, p.pkt)
   213  		} else {
   214  			f.incrementOutgoingPacketErrors(p.proto, p.pkt)
   215  
   216  			if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok {
   217  				switch pkt := p.pkt.(type) {
   218  				case *PacketBuffer:
   219  					linkResolvableEP.HandleLinkResolutionFailure(pkt)
   220  				case *PacketBufferList:
   221  					for pb := pkt.Front(); pb != nil; pb = pb.Next() {
   222  						linkResolvableEP.HandleLinkResolutionFailure(pb)
   223  					}
   224  				default:
   225  					panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
   226  				}
   227  			}
   228  		}
   229  	}
   230  }