github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/link/fdbased/packet_dispatchers.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  package fdbased
    19  
    20  import (
    21  	"golang.org/x/sys/unix"
    22  	"github.com/sagernet/gvisor/pkg/buffer"
    23  	"github.com/sagernet/gvisor/pkg/tcpip"
    24  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    25  	"github.com/sagernet/gvisor/pkg/tcpip/link/rawfile"
    26  	"github.com/sagernet/gvisor/pkg/tcpip/link/stopfd"
    27  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    28  	"github.com/sagernet/gvisor/pkg/tcpip/stack/gro"
    29  )
    30  
    31  // BufConfig defines the shape of the buffer used to read packets from the NIC.
    32  var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
    33  
// iovecBuffer owns the memory that one readv/recvmmsg message is read into.
// It keeps one view per entry in sizes; pullBuffer hands a prefix of those
// views off to the caller and leaves their slots nil until nextIovecs
// allocates replacements.
type iovecBuffer struct {
	// views holds the storage that receives packet contents, one view per
	// entry in sizes. Views that a packet did not need are reused across
	// calls to pullBuffer; a nil slot has been handed off and is reallocated
	// by the next call to nextIovecs.
	views []*buffer.View

	// iovecs are initialized with base pointers/len of the corresponding
	// entries in views, except when GSO is enabled (skipsVnetHdr): then the
	// first iovec points to a buffer for the vnet header, which is never
	// part of the views passed up the stack for further processing.
	iovecs []unix.Iovec

	// sizes is an array of buffer sizes for the underlying views. sizes is
	// immutable.
	sizes []int

	// skipsVnetHdr is true if the virtioNetHdr preceding each packet is to be
	// skipped.
	skipsVnetHdr bool

	// pulledIndex was documented as the index of the last view pulled by
	// pullBuffer, or -1 if no view was pulled.
	// NOTE(review): nothing in this file reads or writes pulledIndex, and
	// newIovecBuffer leaves it at 0 rather than -1; confirm whether it is
	// still used anywhere before relying on it.
	pulledIndex int
}
    59  
    60  func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer {
    61  	b := &iovecBuffer{
    62  		views:        make([]*buffer.View, len(sizes)),
    63  		sizes:        sizes,
    64  		skipsVnetHdr: skipsVnetHdr,
    65  	}
    66  	niov := len(b.views)
    67  	if b.skipsVnetHdr {
    68  		niov++
    69  	}
    70  	b.iovecs = make([]unix.Iovec, niov)
    71  	return b
    72  }
    73  
    74  func (b *iovecBuffer) nextIovecs() []unix.Iovec {
    75  	vnetHdrOff := 0
    76  	if b.skipsVnetHdr {
    77  		var vnetHdr [virtioNetHdrSize]byte
    78  		// The kernel adds virtioNetHdr before each packet, but
    79  		// we don't use it, so we allocate a buffer for it,
    80  		// add it in iovecs but don't add it in a view.
    81  		b.iovecs[0] = unix.Iovec{Base: &vnetHdr[0]}
    82  		b.iovecs[0].SetLen(virtioNetHdrSize)
    83  		vnetHdrOff++
    84  	}
    85  
    86  	for i := range b.views {
    87  		if b.views[i] != nil {
    88  			break
    89  		}
    90  		v := buffer.NewViewSize(b.sizes[i])
    91  		b.views[i] = v
    92  		b.iovecs[i+vnetHdrOff] = unix.Iovec{Base: v.BasePtr()}
    93  		b.iovecs[i+vnetHdrOff].SetLen(v.Size())
    94  	}
    95  	return b.iovecs
    96  }
    97  
    98  // pullBuffer extracts the enough underlying storage from b.buffer to hold n
    99  // bytes. It removes this storage from b.buffer, returns a new buffer
   100  // that holds the storage, and updates pulledIndex to indicate which part
   101  // of b.buffer's storage must be reallocated during the next call to
   102  // nextIovecs.
   103  func (b *iovecBuffer) pullBuffer(n int) buffer.Buffer {
   104  	var views []*buffer.View
   105  	c := 0
   106  	if b.skipsVnetHdr {
   107  		c += virtioNetHdrSize
   108  		if c >= n {
   109  			// Nothing in the packet.
   110  			return buffer.Buffer{}
   111  		}
   112  	}
   113  	// Remove the used views from the buffer.
   114  	for i, v := range b.views {
   115  		c += v.Size()
   116  		if c >= n {
   117  			b.views[i].CapLength(v.Size() - (c - n))
   118  			views = append(views, b.views[:i+1]...)
   119  			break
   120  		}
   121  	}
   122  	for i := range views {
   123  		b.views[i] = nil
   124  	}
   125  	if b.skipsVnetHdr {
   126  		// Exclude the size of the vnet header.
   127  		n -= virtioNetHdrSize
   128  	}
   129  	pulled := buffer.Buffer{}
   130  	for _, v := range views {
   131  		pulled.Append(v)
   132  	}
   133  	pulled.Truncate(int64(n))
   134  	return pulled
   135  }
   136  
   137  func (b *iovecBuffer) release() {
   138  	for _, v := range b.views {
   139  		if v != nil {
   140  			v.Release()
   141  			v = nil
   142  		}
   143  	}
   144  }
   145  
// readVDispatcher uses the readv() system call to read inbound packets, one
// packet per call, and dispatches them.
type readVDispatcher struct {
	// StopFD supplies EFD, which dispatch passes to
	// rawfile.BlockingReadvUntilStopped alongside fd.
	stopfd.StopFD
	// fd is the file descriptor used to send and receive packets.
	fd int

	// e is the endpoint this dispatcher is attached to.
	e *endpoint

	// buf is the iovec buffer that contains the packet contents.
	buf *iovecBuffer
}
   159  
   160  func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
   161  	stopFD, err := stopfd.New()
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  	d := &readVDispatcher{
   166  		StopFD: stopFD,
   167  		fd:     fd,
   168  		e:      e,
   169  	}
   170  	skipsVnetHdr := d.e.gsoKind == stack.HostGSOSupported
   171  	d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
   172  	return d, nil
   173  }
   174  
   175  func (d *readVDispatcher) release() {
   176  	d.buf.release()
   177  }
   178  
   179  // dispatch reads one packet from the file descriptor and dispatches it.
   180  func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
   181  	n, err := rawfile.BlockingReadvUntilStopped(d.EFD, d.fd, d.buf.nextIovecs())
   182  	if n <= 0 || err != nil {
   183  		return false, err
   184  	}
   185  
   186  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   187  		Payload: d.buf.pullBuffer(n),
   188  	})
   189  	defer pkt.DecRef()
   190  
   191  	var p tcpip.NetworkProtocolNumber
   192  	if d.e.hdrSize > 0 {
   193  		if !d.e.parseHeader(pkt) {
   194  			return false, nil
   195  		}
   196  		p = header.Ethernet(pkt.LinkHeader().Slice()).Type()
   197  	} else {
   198  		// We don't get any indication of what the packet is, so try to guess
   199  		// if it's an IPv4 or IPv6 packet.
   200  		// IP version information is at the first octet, so pulling up 1 byte.
   201  		h, ok := pkt.Data().PullUp(1)
   202  		if !ok {
   203  			return true, nil
   204  		}
   205  		switch header.IPVersion(h) {
   206  		case header.IPv4Version:
   207  			p = header.IPv4ProtocolNumber
   208  		case header.IPv6Version:
   209  			p = header.IPv6ProtocolNumber
   210  		default:
   211  			return true, nil
   212  		}
   213  	}
   214  
   215  	d.e.mu.RLock()
   216  	dsp := d.e.dispatcher
   217  	d.e.mu.RUnlock()
   218  	dsp.DeliverNetworkPacket(p, pkt)
   219  
   220  	return true, nil
   221  }
   222  
// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
// dispatches them.
type recvMMsgDispatcher struct {
	// StopFD supplies EFD, which dispatch passes to
	// rawfile.BlockingRecvMMsgUntilStopped alongside fd.
	stopfd.StopFD
	// fd is the file descriptor used to send and receive packets.
	fd int

	// e is the endpoint this dispatcher is attached to.
	e *endpoint

	// bufs is an array of iovec buffers that contain packet contents, one per
	// entry in msgHdrs.
	bufs []*iovecBuffer

	// msgHdrs is an array of MMsgHdr objects. Each MMsgHdr references the
	// iovecs of the iovecBuffer at the same index in bufs. This array is
	// passed as the parameter to the recvmmsg call to retrieve potentially
	// more than one packet per syscall. dispatch sets Msg.Iovlen to 0 once a
	// header's packet has been consumed, and refills such headers before the
	// next read.
	msgHdrs []rawfile.MMsgHdr

	// pkts is reused to avoid allocations.
	pkts stack.PacketBufferList

	// gro coalesces incoming packets to increase throughput.
	gro gro.GRO
}
   248  
   249  const (
   250  	// MaxMsgsPerRecv is the maximum number of packets we want to retrieve
   251  	// in a single RecvMMsg call.
   252  	MaxMsgsPerRecv = 8
   253  )
   254  
   255  func newRecvMMsgDispatcher(fd int, e *endpoint, opts *Options) (linkDispatcher, error) {
   256  	stopFD, err := stopfd.New()
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  	d := &recvMMsgDispatcher{
   261  		StopFD:  stopFD,
   262  		fd:      fd,
   263  		e:       e,
   264  		bufs:    make([]*iovecBuffer, MaxMsgsPerRecv),
   265  		msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv),
   266  	}
   267  	skipsVnetHdr := d.e.gsoKind == stack.HostGSOSupported
   268  	for i := range d.bufs {
   269  		d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr)
   270  	}
   271  	d.gro.Init(opts.GRO)
   272  
   273  	return d, nil
   274  }
   275  
   276  func (d *recvMMsgDispatcher) release() {
   277  	for _, iov := range d.bufs {
   278  		iov.release()
   279  	}
   280  }
   281  
// dispatch reads up to len(d.msgHdrs) packets (MaxMsgsPerRecv, as allocated
// by newRecvMMsgDispatcher) from the file descriptor with a single recvmmsg
// call, and dispatches them.
//
// When more than one packet arrives, each is tagged with its network protocol
// and RX checksum-offload state and enqueued into GRO, which is flushed after
// the batch. A batch of exactly one packet bypasses GRO and is delivered
// directly.
//
// It returns false, along with the read's error, if recvmmsg fails. It returns
// (false, nil) if a packet is too short to hold the link header. Otherwise it
// returns (true, nil).
func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
	// Fill message headers. Headers consumed by the previous call had their
	// Iovlen reset to 0 below, in order from index 0, so they form a prefix.
	// The first header that is still populated ends the refill.
	for k := range d.msgHdrs {
		if d.msgHdrs[k].Msg.Iovlen > 0 {
			break
		}
		iovecs := d.bufs[k].nextIovecs()
		iovLen := len(iovecs)
		d.msgHdrs[k].Len = 0
		d.msgHdrs[k].Msg.Iov = &iovecs[0]
		d.msgHdrs[k].Msg.SetIovlen(iovLen)
	}

	nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.EFD, d.fd, d.msgHdrs)
	if nMsgs == -1 || err != nil {
		return false, err
	}

	// Process each of received packets.

	// Snapshot the current network dispatcher under the endpoint's read lock.
	d.e.mu.RLock()
	dsp := d.e.dispatcher
	d.e.mu.RUnlock()

	d.gro.Dispatcher = dsp
	// Every packet built below is pushed onto d.pkts before any early return,
	// so this deferred Reset covers all of them on every return path.
	// Presumably it drops the list's references; confirm against
	// PacketBufferList.Reset.
	defer d.pkts.Reset()

	for k := 0; k < nMsgs; k++ {
		// Len is the number of bytes received for message k.
		n := int(d.msgHdrs[k].Len)
		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
			Payload: d.bufs[k].pullBuffer(n),
		})
		d.pkts.PushBack(pkt)

		// Mark that this iovec has been processed.
		d.msgHdrs[k].Msg.Iovlen = 0

		var p tcpip.NetworkProtocolNumber
		if d.e.hdrSize > 0 {
			hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
			if !ok {
				// NOTE(review): a single short frame makes dispatch return
				// false. This path also skips d.gro.Flush(), so packets from
				// this batch already enqueued into GRO are not flushed here.
				// Confirm both are intended.
				return false, nil
			}
			p = header.Ethernet(hdr).Type()
		} else {
			// We don't get any indication of what the packet is, so try to guess
			// if it's an IPv4 or IPv6 packet.
			// IP version information is at the first octet, so pulling up 1 byte.
			h, ok := pkt.Data().PullUp(1)
			if !ok {
				// Skip this packet.
				continue
			}
			switch header.IPVersion(h) {
			case header.IPv4Version:
				p = header.IPv4ProtocolNumber
			case header.IPv6Version:
				p = header.IPv6ProtocolNumber
			default:
				// Skip this packet.
				continue
			}
		}

		// Only use GRO if there's more than one packet.
		if nMsgs > 1 {
			pkt.NetworkProtocolNumber = p
			pkt.RXChecksumValidated = d.e.caps&stack.CapabilityRXChecksumOffload != 0
			d.gro.Enqueue(pkt)
		} else {
			// nMsgs == 1: nothing was enqueued into GRO, so no flush is needed.
			dsp.DeliverNetworkPacket(p, pkt)
			return true, nil
		}
	}
	// Deliver whatever GRO accumulated from this batch.
	d.gro.Flush()

	return true, nil
}