github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/link/sharedmem/sharedmem_server.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  package sharedmem
    19  
    20  import (
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/buffer"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/link/rawfile"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack"
    28  )
    29  
    30  type serverEndpoint struct {
    31  	// mtu (maximum transmission unit) is the maximum size of a packet.
    32  	// mtu is immutable.
    33  	mtu uint32
    34  
    35  	// bufferSize is the size of each individual buffer.
    36  	// bufferSize is immutable.
    37  	bufferSize uint32
    38  
    39  	// addr is the local address of this endpoint.
    40  	// addr is immutable
    41  	addr tcpip.LinkAddress
    42  
    43  	// rx is the receive queue.
    44  	rx serverRx
    45  
    46  	// stopRequested determines whether the worker goroutines should stop.
    47  	stopRequested atomicbitops.Uint32
    48  
    49  	// Wait group used to indicate that all workers have stopped.
    50  	completed sync.WaitGroup
    51  
    52  	// peerFD is an fd to the peer that can be used to detect when the peer is
    53  	// gone.
    54  	// peerFD is immutable.
    55  	peerFD int
    56  
    57  	// caps holds the endpoint capabilities.
    58  	caps stack.LinkEndpointCapabilities
    59  
    60  	// hdrSize is the size of the link layer header if any.
    61  	// hdrSize is immutable.
    62  	hdrSize uint32
    63  
    64  	// virtioNetHeaderRequired if true indicates that a virtio header is expected
    65  	// in all inbound/outbound packets.
    66  	virtioNetHeaderRequired bool
    67  
    68  	// onClosed is a function to be called when the FD's peer (if any) closes its
    69  	// end of the communication pipe.
    70  	onClosed func(tcpip.Error)
    71  
    72  	// mu protects the following fields.
    73  	mu sync.Mutex
    74  
    75  	// tx is the transmit queue.
    76  	// +checklocks:mu
    77  	tx serverTx
    78  
    79  	// workerStarted specifies whether the worker goroutine was started.
    80  	// +checklocks:mu
    81  	workerStarted bool
    82  }
    83  
    84  // NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
    85  // broken up into buffers of "bufferSize" bytes.
    86  func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
    87  	e := &serverEndpoint{
    88  		mtu:        opts.MTU,
    89  		bufferSize: opts.BufferSize,
    90  		addr:       opts.LinkAddress,
    91  		peerFD:     opts.PeerFD,
    92  		onClosed:   opts.OnClosed,
    93  	}
    94  
    95  	if err := e.tx.init(&opts.RX); err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	if err := e.rx.init(&opts.TX); err != nil {
   100  		e.tx.cleanup()
   101  		return nil, err
   102  	}
   103  
   104  	e.caps = stack.LinkEndpointCapabilities(0)
   105  	if opts.RXChecksumOffload {
   106  		e.caps |= stack.CapabilityRXChecksumOffload
   107  	}
   108  
   109  	if opts.TXChecksumOffload {
   110  		e.caps |= stack.CapabilityTXChecksumOffload
   111  	}
   112  
   113  	if opts.LinkAddress != "" {
   114  		e.hdrSize = header.EthernetMinimumSize
   115  		e.caps |= stack.CapabilityResolutionRequired
   116  	}
   117  
   118  	return e, nil
   119  }
   120  
   121  // Close frees all resources associated with the endpoint.
   122  func (e *serverEndpoint) Close() {
   123  	// Tell dispatch goroutine to stop, then write to the eventfd so that it wakes
   124  	// up in case it's sleeping.
   125  	e.stopRequested.Store(1)
   126  	e.rx.eventFD.Notify()
   127  
   128  	// Cleanup the queues inline if the worker hasn't started yet; we also know it
   129  	// won't start from now on because stopRequested is set to 1.
   130  	e.mu.Lock()
   131  	defer e.mu.Unlock()
   132  	workerPresent := e.workerStarted
   133  
   134  	if !workerPresent {
   135  		e.tx.cleanup()
   136  		e.rx.cleanup()
   137  	}
   138  }
   139  
   140  // Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
   141  // stopped after a Close() call.
   142  func (e *serverEndpoint) Wait() {
   143  	e.completed.Wait()
   144  }
   145  
   146  // Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
   147  // reads packets from the rx queue.
   148  func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
   149  	e.mu.Lock()
   150  	if !e.workerStarted && e.stopRequested.Load() == 0 {
   151  		e.workerStarted = true
   152  		e.completed.Add(1)
   153  		if e.peerFD >= 0 {
   154  			e.completed.Add(1)
   155  			// Spin up a goroutine to monitor for peer shutdown.
   156  			go func() {
   157  				b := make([]byte, 1)
   158  				// When sharedmem endpoint is in use the peerFD is never used for any
   159  				// data transfer and this Read should only return if the peer is
   160  				// shutting down.
   161  				_, err := rawfile.BlockingRead(e.peerFD, b)
   162  				if e.onClosed != nil {
   163  					e.onClosed(err)
   164  				}
   165  				e.completed.Done()
   166  			}()
   167  		}
   168  		// Link endpoints are not savable. When transportation endpoints are saved,
   169  		// they stop sending outgoing packets and all incoming packets are rejected.
   170  		go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
   171  	}
   172  	e.mu.Unlock()
   173  }
   174  
   175  // IsAttached implements stack.LinkEndpoint.IsAttached.
   176  func (e *serverEndpoint) IsAttached() bool {
   177  	e.mu.Lock()
   178  	defer e.mu.Unlock()
   179  	return e.workerStarted
   180  }
   181  
   182  // MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
   183  // during construction.
   184  func (e *serverEndpoint) MTU() uint32 {
   185  	return e.mtu
   186  }
   187  
   188  // Capabilities implements stack.LinkEndpoint.Capabilities.
   189  func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
   190  	return e.caps
   191  }
   192  
   193  // MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
   194  // ethernet frame header size.
   195  func (e *serverEndpoint) MaxHeaderLength() uint16 {
   196  	return uint16(e.hdrSize)
   197  }
   198  
   199  // LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
   200  // link address.
   201  func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
   202  	return e.addr
   203  }
   204  
   205  // AddHeader implements stack.LinkEndpoint.AddHeader.
   206  func (e *serverEndpoint) AddHeader(pkt stack.PacketBufferPtr) {
   207  	// Add ethernet header if needed.
   208  	if len(e.addr) == 0 {
   209  		return
   210  	}
   211  
   212  	eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
   213  	eth.Encode(&header.EthernetFields{
   214  		SrcAddr: pkt.EgressRoute.LocalLinkAddress,
   215  		DstAddr: pkt.EgressRoute.RemoteLinkAddress,
   216  		Type:    pkt.NetworkProtocolNumber,
   217  	})
   218  }
   219  
   220  func (e *serverEndpoint) parseHeader(pkt stack.PacketBufferPtr) bool {
   221  	_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
   222  	return ok
   223  }
   224  
   225  // ParseHeader implements stack.LinkEndpoint.ParseHeader.
   226  func (e *serverEndpoint) ParseHeader(pkt stack.PacketBufferPtr) bool {
   227  	// Add ethernet header if needed.
   228  	if len(e.addr) == 0 {
   229  		return true
   230  	}
   231  
   232  	return e.parseHeader(pkt)
   233  }
   234  
   235  func (e *serverEndpoint) AddVirtioNetHeader(pkt stack.PacketBufferPtr) {
   236  	virtio := header.VirtioNetHeader(pkt.VirtioNetHeader().Push(header.VirtioNetHeaderSize))
   237  	virtio.Encode(&header.VirtioNetHeaderFields{})
   238  }
   239  
   240  // +checklocks:e.mu
   241  func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) tcpip.Error {
   242  	if e.virtioNetHeaderRequired {
   243  		e.AddVirtioNetHeader(pkt)
   244  	}
   245  
   246  	ok := e.tx.transmit(pkt)
   247  	if !ok {
   248  		return &tcpip.ErrWouldBlock{}
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  // WritePacket writes outbound packets to the file descriptor. If it is not
   255  // currently writable, the packet is dropped.
   256  // WritePacket implements stack.LinkEndpoint.WritePacket.
   257  func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) tcpip.Error {
   258  	// Transmit the packet.
   259  	e.mu.Lock()
   260  	defer e.mu.Unlock()
   261  	if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
   262  		return err
   263  	}
   264  	e.tx.notify()
   265  	return nil
   266  }
   267  
   268  // WritePackets implements stack.LinkEndpoint.WritePackets.
   269  func (e *serverEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
   270  	n := 0
   271  	var err tcpip.Error
   272  	e.mu.Lock()
   273  	defer e.mu.Unlock()
   274  	for _, pkt := range pkts.AsSlice() {
   275  		if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
   276  			break
   277  		}
   278  		n++
   279  	}
   280  	// WritePackets never returns an error if it successfully transmitted at least
   281  	// one packet.
   282  	if err != nil && n == 0 {
   283  		return 0, err
   284  	}
   285  	e.tx.notify()
   286  	return n, nil
   287  }
   288  
   289  // dispatchLoop reads packets from the rx queue in a loop and dispatches them
   290  // to the network stack.
   291  func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
   292  	for e.stopRequested.Load() == 0 {
   293  		b := e.rx.receive()
   294  		if b == nil {
   295  			e.rx.EnableNotification()
   296  			// Now pull again to make sure we didn't receive any packets
   297  			// while notifications were not enabled.
   298  			for {
   299  				b = e.rx.receive()
   300  				if b != nil {
   301  					// Disable notifications as we only need to be notified when we are going
   302  					// to block on eventFD. This should prevent the peer from needlessly
   303  					// writing to eventFD when this end is already awake and processing
   304  					// packets.
   305  					e.rx.DisableNotification()
   306  					break
   307  				}
   308  				e.rx.waitForPackets()
   309  			}
   310  		}
   311  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   312  			Payload: buffer.MakeWithView(b),
   313  		})
   314  		if e.virtioNetHeaderRequired {
   315  			_, ok := pkt.VirtioNetHeader().Consume(header.VirtioNetHeaderSize)
   316  			if !ok {
   317  				pkt.DecRef()
   318  				continue
   319  			}
   320  		}
   321  		var proto tcpip.NetworkProtocolNumber
   322  		if len(e.addr) != 0 {
   323  			if !e.parseHeader(pkt) {
   324  				pkt.DecRef()
   325  				continue
   326  			}
   327  			proto = header.Ethernet(pkt.LinkHeader().Slice()).Type()
   328  		} else {
   329  			// We don't get any indication of what the packet is, so try to guess
   330  			// if it's an IPv4 or IPv6 packet.
   331  			// IP version information is at the first octet, so pulling up 1 byte.
   332  			h, ok := pkt.Data().PullUp(1)
   333  			if !ok {
   334  				pkt.DecRef()
   335  				continue
   336  			}
   337  			switch header.IPVersion(h) {
   338  			case header.IPv4Version:
   339  				proto = header.IPv4ProtocolNumber
   340  			case header.IPv6Version:
   341  				proto = header.IPv6ProtocolNumber
   342  			default:
   343  				pkt.DecRef()
   344  				continue
   345  			}
   346  		}
   347  		// Send packet up the stack.
   348  		d.DeliverNetworkPacket(proto, pkt)
   349  		pkt.DecRef()
   350  	}
   351  
   352  	e.mu.Lock()
   353  	defer e.mu.Unlock()
   354  
   355  	// Clean state.
   356  	e.tx.cleanup()
   357  	e.rx.cleanup()
   358  
   359  	e.completed.Done()
   360  }
   361  
   362  // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
   363  func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
   364  	if e.hdrSize > 0 {
   365  		return header.ARPHardwareEther
   366  	}
   367  	return header.ARPHardwareNone
   368  }