github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/link/sharedmem/sharedmem.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  // Package sharedmem provides the implementation of data-link layer endpoints
    19  // backed by shared memory.
    20  //
    21  // Shared memory endpoints can be used in the networking stack by calling New()
    22  // to create a new endpoint, and then passing it as an argument to
    23  // Stack.CreateNIC().
    24  package sharedmem
    25  
    26  import (
    27  	"fmt"
    28  
    29  	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
    30  	"github.com/nicocha30/gvisor-ligolo/pkg/buffer"
    31  	"github.com/nicocha30/gvisor-ligolo/pkg/eventfd"
    32  	"github.com/nicocha30/gvisor-ligolo/pkg/log"
    33  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    34  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    35  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
    36  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/link/rawfile"
    37  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/link/sharedmem/queue"
    38  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack"
    39  )
    40  
// QueueConfig holds all the file descriptors needed to describe a tx or rx
// queue over shared memory. It is used when creating new shared memory
// endpoints to describe tx and rx queues.
//
// The fields are flattened by FDs() and restored by QueueConfigFromFDs() in
// declaration order: DataFD, EventFD, TxPipeFD, RxPipeFD, SharedDataFD.
type QueueConfig struct {
	// DataFD is a file descriptor for the file that contains the data to
	// be transmitted via this queue. Descriptors contain offsets within
	// this file.
	DataFD int

	// EventFD is a file descriptor for the event that is signaled when
	// data becomes available in this queue.
	EventFD eventfd.Eventfd

	// TxPipeFD is a file descriptor for the tx pipe associated with the
	// queue.
	TxPipeFD int

	// RxPipeFD is a file descriptor for the rx pipe associated with the
	// queue.
	RxPipeFD int

	// SharedDataFD is a file descriptor for the file that contains shared
	// state between the two ends of the queue. This data specifies, for
	// example, whether EventFD signaling is enabled or disabled.
	SharedDataFD int
}
    67  
    68  // FDs returns the FD's in the QueueConfig as a slice of ints. This must
    69  // be used in conjunction with QueueConfigFromFDs to ensure the order
    70  // of FDs matches when reconstructing the config when serialized or sent
    71  // as part of control messages.
    72  func (q *QueueConfig) FDs() []int {
    73  	return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
    74  }
    75  
    76  // QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each
    77  // entry represents an file descriptor. The order of FDs in the slice must be in
    78  // the order specified below for the config to be valid. QueueConfig.FDs()
    79  // should be used when the config needs to be serialized or sent as part of a
    80  // control message to ensure the correct order.
    81  func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
    82  	if len(fds) != 5 {
    83  		return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds))
    84  	}
    85  	return QueueConfig{
    86  		DataFD:       fds[0],
    87  		EventFD:      eventfd.Wrap(fds[1]),
    88  		TxPipeFD:     fds[2],
    89  		RxPipeFD:     fds[3],
    90  		SharedDataFD: fds[4],
    91  	}, nil
    92  }
    93  
// Options specify the details about the sharedmem endpoint to be created.
type Options struct {
	// MTU is the mtu to use for this endpoint.
	MTU uint32

	// BufferSize is the size of each scatter/gather buffer that will hold packet
	// data.
	//
	// NOTE: This directly determines the number of packets that can be held in
	// the ring buffer at any time. This does not have to be sized to the MTU as
	// the shared memory queue design allows more than one buffer to be used to
	// make up a given packet.
	BufferSize uint32

	// LinkAddress is the link address for this endpoint. When it is empty, New
	// does not reserve space for an ethernet header and does not set
	// CapabilityResolutionRequired.
	LinkAddress tcpip.LinkAddress

	// TX is the transmit queue configuration for this shared memory endpoint.
	TX QueueConfig

	// RX is the receive queue configuration for this shared memory endpoint.
	RX QueueConfig

	// PeerFD is the fd for the connected peer which can be used to detect
	// peer disconnects. Attach only monitors it when it is >= 0.
	PeerFD int

	// OnClosed is a function that is called when the endpoint is being closed
	// (probably due to the peer going away). It may be nil.
	OnClosed func(err tcpip.Error)

	// TXChecksumOffload, if true, indicates that this endpoint's capability
	// set should include CapabilityTXChecksumOffload.
	TXChecksumOffload bool

	// RXChecksumOffload, if true, indicates that this endpoint's capability
	// set should include CapabilityRXChecksumOffload.
	RXChecksumOffload bool

	// VirtioNetHeaderRequired, if true, indicates that all outbound packets
	// should have a virtio header and inbound packets should have a virtio
	// header as well.
	VirtioNetHeaderRequired bool

	// GSOMaxSize is the maximum GSO packet size. It is zero if GSO is
	// disabled. Note that only gVisor GSO is supported, not host GSO.
	GSOMaxSize uint32
}
   141  
   142  var _ stack.LinkEndpoint = (*endpoint)(nil)
   143  var _ stack.GSOEndpoint = (*endpoint)(nil)
   144  
// endpoint is the shared-memory-backed implementation of stack.LinkEndpoint
// and stack.GSOEndpoint. It is created by New.
type endpoint struct {
	// mtu (maximum transmission unit) is the maximum size of a packet.
	// mtu is immutable.
	mtu uint32

	// bufferSize is the size of each individual buffer.
	// bufferSize is immutable.
	bufferSize uint32

	// addr is the local address of this endpoint.
	// addr is immutable.
	addr tcpip.LinkAddress

	// peerFD is an fd to the peer that can be used to detect when the
	// peer is gone.
	// peerFD is immutable.
	peerFD int

	// caps holds the endpoint capabilities. It is populated in New.
	caps stack.LinkEndpointCapabilities

	// hdrSize is the size of the link layer header if any. New also adds the
	// virtio-net header size to it when a virtio header is required.
	// hdrSize is immutable.
	hdrSize uint32

	// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
	// disabled. Note that only gVisor GSO is supported, not host GSO.
	// gsoMaxSize is immutable.
	gsoMaxSize uint32

	// virtioNetHeaderRequired if true indicates that a virtio header is expected
	// in all inbound/outbound packets.
	virtioNetHeaderRequired bool

	// rx is the receive queue.
	rx rx

	// stopRequested determines whether the worker goroutines should stop.
	// Close sets it to 1.
	stopRequested atomicbitops.Uint32

	// completed is the wait group used to indicate that all workers have
	// stopped.
	completed sync.WaitGroup

	// onClosed is a function to be called when the FD's peer (if any) closes
	// its end of the communication pipe. It may be nil.
	onClosed func(tcpip.Error)

	// mu protects the following fields.
	mu sync.Mutex

	// tx is the transmit queue.
	// +checklocks:mu
	tx tx

	// workerStarted specifies whether the worker goroutine was started.
	// +checklocks:mu
	workerStarted bool
}
   203  
   204  // New creates a new shared-memory-based endpoint. Buffers will be broken up
   205  // into buffers of "bufferSize" bytes.
   206  //
   207  // In order to release all resources held by the returned endpoint, Close()
   208  // must be called followed by Wait().
   209  func New(opts Options) (stack.LinkEndpoint, error) {
   210  	e := &endpoint{
   211  		mtu:                     opts.MTU,
   212  		bufferSize:              opts.BufferSize,
   213  		addr:                    opts.LinkAddress,
   214  		peerFD:                  opts.PeerFD,
   215  		onClosed:                opts.OnClosed,
   216  		virtioNetHeaderRequired: opts.VirtioNetHeaderRequired,
   217  		gsoMaxSize:              opts.GSOMaxSize,
   218  	}
   219  
   220  	if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
   225  		e.tx.cleanup()
   226  		return nil, err
   227  	}
   228  
   229  	e.caps = stack.LinkEndpointCapabilities(0)
   230  	if opts.RXChecksumOffload {
   231  		e.caps |= stack.CapabilityRXChecksumOffload
   232  	}
   233  
   234  	if opts.TXChecksumOffload {
   235  		e.caps |= stack.CapabilityTXChecksumOffload
   236  	}
   237  
   238  	if opts.LinkAddress != "" {
   239  		e.hdrSize = header.EthernetMinimumSize
   240  		e.caps |= stack.CapabilityResolutionRequired
   241  	}
   242  
   243  	if opts.VirtioNetHeaderRequired {
   244  		e.hdrSize += header.VirtioNetHeaderSize
   245  	}
   246  
   247  	return e, nil
   248  }
   249  
   250  // Close frees most resources associated with the endpoint. Wait() must be
   251  // called after Close() in order to free the rest.
   252  func (e *endpoint) Close() {
   253  	// Tell dispatch goroutine to stop, then write to the eventfd so that
   254  	// it wakes up in case it's sleeping.
   255  	e.stopRequested.Store(1)
   256  	e.rx.eventFD.Notify()
   257  
   258  	// Cleanup the queues inline if the worker hasn't started yet; we also
   259  	// know it won't start from now on because stopRequested is set to 1.
   260  	e.mu.Lock()
   261  	defer e.mu.Unlock()
   262  	workerPresent := e.workerStarted
   263  
   264  	if !workerPresent {
   265  		e.tx.cleanup()
   266  		e.rx.cleanup()
   267  	}
   268  }
   269  
// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
// stopped after a Close() call, and then closes the rx eventfd.
func (e *endpoint) Wait() {
	// Block until every goroutine registered on e.completed has finished.
	e.completed.Wait()
	// Only now close the rx eventfd; Close() still notifies it to wake the
	// dispatcher, so it is kept open until the workers are done.
	e.rx.eventFD.Close()
}
   276  
   277  // Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
   278  // reads packets from the rx queue.
   279  func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
   280  	if dispatcher == nil {
   281  		e.Close()
   282  		return
   283  	}
   284  	e.mu.Lock()
   285  	if !e.workerStarted && e.stopRequested.Load() == 0 {
   286  		e.workerStarted = true
   287  		e.completed.Add(1)
   288  
   289  		// Spin up a goroutine to monitor for peer shutdown.
   290  		if e.peerFD >= 0 {
   291  			e.completed.Add(1)
   292  			go func() {
   293  				defer e.completed.Done()
   294  				b := make([]byte, 1)
   295  				// When sharedmem endpoint is in use the peerFD is never used for any data
   296  				// transfer and this Read should only return if the peer is shutting down.
   297  				_, err := rawfile.BlockingRead(e.peerFD, b)
   298  				if e.onClosed != nil {
   299  					e.onClosed(err)
   300  				}
   301  			}()
   302  		}
   303  
   304  		// Link endpoints are not savable. When transportation endpoints
   305  		// are saved, they stop sending outgoing packets and all
   306  		// incoming packets are rejected.
   307  		go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
   308  	}
   309  	e.mu.Unlock()
   310  }
   311  
   312  // IsAttached implements stack.LinkEndpoint.IsAttached.
   313  func (e *endpoint) IsAttached() bool {
   314  	e.mu.Lock()
   315  	defer e.mu.Unlock()
   316  	return e.workerStarted
   317  }
   318  
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction (Options.MTU as passed to New).
func (e *endpoint) MTU() uint32 {
	return e.mtu
}
   324  
// Capabilities implements stack.LinkEndpoint.Capabilities. It returns the
// capability set computed in New from the endpoint options.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
	return e.caps
}
   329  
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns
// hdrSize as computed in New: the ethernet header size when a link address is
// configured, plus the virtio-net header size when a virtio header is
// required.
func (e *endpoint) MaxHeaderLength() uint16 {
	return uint16(e.hdrSize)
}
   335  
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
// link address the endpoint was constructed with, which may be empty.
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
	return e.addr
}
   341  
   342  // AddHeader implements stack.LinkEndpoint.AddHeader.
   343  func (e *endpoint) AddHeader(pkt stack.PacketBufferPtr) {
   344  	// Add ethernet header if needed.
   345  	if len(e.addr) == 0 {
   346  		return
   347  	}
   348  
   349  	eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
   350  	eth.Encode(&header.EthernetFields{
   351  		SrcAddr: pkt.EgressRoute.LocalLinkAddress,
   352  		DstAddr: pkt.EgressRoute.RemoteLinkAddress,
   353  		Type:    pkt.NetworkProtocolNumber,
   354  	})
   355  }
   356  
   357  func (e *endpoint) parseHeader(pkt stack.PacketBufferPtr) bool {
   358  	_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
   359  	return ok
   360  }
   361  
   362  // ParseHeader implements stack.LinkEndpoint.ParseHeader.
   363  func (e *endpoint) ParseHeader(pkt stack.PacketBufferPtr) bool {
   364  	// Add ethernet header if needed.
   365  	if len(e.addr) == 0 {
   366  		return true
   367  	}
   368  
   369  	return e.parseHeader(pkt)
   370  }
   371  
   372  func (e *endpoint) AddVirtioNetHeader(pkt stack.PacketBufferPtr) {
   373  	virtio := header.VirtioNetHeader(pkt.VirtioNetHeader().Push(header.VirtioNetHeaderSize))
   374  	virtio.Encode(&header.VirtioNetHeaderFields{})
   375  }
   376  
   377  // +checklocks:e.mu
   378  func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) tcpip.Error {
   379  	if e.virtioNetHeaderRequired {
   380  		e.AddVirtioNetHeader(pkt)
   381  	}
   382  
   383  	// Transmit the packet.
   384  	b := pkt.ToBuffer()
   385  	defer b.Release()
   386  	ok := e.tx.transmit(b)
   387  	if !ok {
   388  		return &tcpip.ErrWouldBlock{}
   389  	}
   390  
   391  	return nil
   392  }
   393  
   394  // WritePackets implements stack.LinkEndpoint.WritePackets.
   395  func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
   396  	n := 0
   397  	var err tcpip.Error
   398  	e.mu.Lock()
   399  	defer e.mu.Unlock()
   400  	for _, pkt := range pkts.AsSlice() {
   401  		if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
   402  			break
   403  		}
   404  		n++
   405  	}
   406  	// WritePackets never returns an error if it successfully transmitted at least
   407  	// one packet.
   408  	if err != nil && n == 0 {
   409  		return 0, err
   410  	}
   411  	e.tx.notify()
   412  	return n, nil
   413  }
   414  
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
//
// It first posts an initial set of rx buffers, then repeatedly receives a
// packet, copies it out of the shared data area and delivers it to d until
// stopRequested becomes non-zero. On exit it cleans up both queues while
// holding e.mu and marks itself done on e.completed.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
	// Post initial set of buffers. The count is capped both by the queue's
	// posted-buffers limit and by how many bufferSize-sized buffers fit in the
	// rx data area. e.bufferSize must be non-zero, otherwise this division
	// panics.
	limit := e.rx.q.PostedBuffersLimit()
	if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l {
		limit = l
	}
	for i := uint64(0); i < limit; i++ {
		// Buffer i starts at byte offset i*bufferSize of the rx data area and
		// uses its index as ID.
		b := queue.RxBuffer{
			Offset: i * uint64(e.bufferSize),
			Size:   e.bufferSize,
			ID:     i,
		}
		if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) {
			log.Warningf("Unable to post %v-th buffer", i)
		}
	}

	// Read in a loop until a stop is requested.
	var rxb []queue.RxBuffer
	for e.stopRequested.Load() == 0 {
		// NOTE(review): presumably this reposts the buffers passed in and
		// returns the buffers holding the next packet along with its total
		// size n — confirm against rx.postAndReceive.
		var n uint32
		rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested)

		// Copy data from the shared area to its own buffer, then
		// prepare to repost the buffer.
		v := buffer.NewView(int(n))
		v.Grow(int(n))
		offset := uint32(0)
		for i := range rxb {
			v.WriteAt(e.rx.data[rxb[i].Offset:][:rxb[i].Size], int(offset))
			offset += rxb[i].Size

			// Reset the size to a full buffer before the slice is passed back
			// to postAndReceive on the next iteration.
			rxb[i].Size = e.bufferSize
		}

		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
			Payload: buffer.MakeWithView(v),
		})

		// Strip the virtio-net header if one is expected; drop the packet if
		// it cannot be consumed.
		if e.virtioNetHeaderRequired {
			_, ok := pkt.VirtioNetHeader().Consume(header.VirtioNetHeaderSize)
			if !ok {
				pkt.DecRef()
				continue
			}
		}

		var proto tcpip.NetworkProtocolNumber
		if len(e.addr) != 0 {
			// With a link address configured, packets carry an ethernet
			// header whose type field gives the network protocol.
			if !e.parseHeader(pkt) {
				pkt.DecRef()
				continue
			}
			proto = header.Ethernet(pkt.LinkHeader().Slice()).Type()
		} else {
			// We don't get any indication of what the packet is, so try to guess
			// if it's an IPv4 or IPv6 packet.
			// IP version information is at the first octet, so pulling up 1 byte.
			h, ok := pkt.Data().PullUp(1)
			if !ok {
				pkt.DecRef()
				continue
			}
			switch header.IPVersion(h) {
			case header.IPv4Version:
				proto = header.IPv4ProtocolNumber
			case header.IPv6Version:
				proto = header.IPv6ProtocolNumber
			default:
				// Neither IPv4 nor IPv6: drop the packet.
				pkt.DecRef()
				continue
			}
		}

		// Send packet up the stack, then drop this loop's reference to it.
		d.DeliverNetworkPacket(proto, pkt)
		pkt.DecRef()
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	// Clean state. Close() leaves this to the worker once it has been started.
	e.tx.cleanup()
	e.rx.cleanup()

	// Let Wait() know that this worker has finished.
	e.completed.Done()
}
   505  
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. It always
// returns header.ARPHardwareEther.
//
// NOTE(review): this is returned even when the endpoint has no link address
// and therefore adds no ethernet header — confirm that is intended.
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
	return header.ARPHardwareEther
}
   510  
// GSOMaxSize implements stack.GSOEndpoint. It returns the maximum GSO packet
// size the endpoint was constructed with (Options.GSOMaxSize).
func (e *endpoint) GSOMaxSize() uint32 {
	return e.gsoMaxSize
}
   515  
// SupportedGSO implements stack.GSOEndpoint. It always returns
// stack.GvisorGSOSupported.
func (e *endpoint) SupportedGSO() stack.SupportedGSO {
	return stack.GvisorGSOSupported
}