gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/link/fdbased/processors.go (about)

     1  // Copyright 2024 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  package fdbased
    19  
    20  import (
    21  	"encoding/binary"
    22  
    23  	"gvisor.dev/gvisor/pkg/rand"
    24  	"gvisor.dev/gvisor/pkg/sleep"
    25  	"gvisor.dev/gvisor/pkg/sync"
    26  	"gvisor.dev/gvisor/pkg/tcpip"
    27  	"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
    28  	"gvisor.dev/gvisor/pkg/tcpip/header"
    29  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    30  	"gvisor.dev/gvisor/pkg/tcpip/stack/gro"
    31  )
    32  
    33  type processor struct {
    34  	mu sync.Mutex
    35  	// +checklocks:mu
    36  	pkts stack.PacketBufferList
    37  
    38  	e           *endpoint
    39  	gro         gro.GRO
    40  	sleeper     sleep.Sleeper
    41  	packetWaker sleep.Waker
    42  	closeWaker  sleep.Waker
    43  }
    44  
    45  func (p *processor) start(wg *sync.WaitGroup) {
    46  	defer wg.Done()
    47  	defer p.sleeper.Done()
    48  	for {
    49  		switch w := p.sleeper.Fetch(true); {
    50  		case w == &p.packetWaker:
    51  			p.deliverPackets()
    52  		case w == &p.closeWaker:
    53  			p.mu.Lock()
    54  			p.pkts.Reset()
    55  			p.mu.Unlock()
    56  			return
    57  		}
    58  	}
    59  }
    60  
    61  func (p *processor) deliverPackets() {
    62  	p.e.mu.RLock()
    63  	p.gro.Dispatcher = p.e.dispatcher
    64  	p.e.mu.RUnlock()
    65  	if p.gro.Dispatcher == nil {
    66  		p.mu.Lock()
    67  		p.pkts.Reset()
    68  		p.mu.Unlock()
    69  		return
    70  	}
    71  
    72  	p.mu.Lock()
    73  	for p.pkts.Len() > 0 {
    74  		pkt := p.pkts.PopFront()
    75  		p.mu.Unlock()
    76  		p.gro.Enqueue(pkt)
    77  		pkt.DecRef()
    78  		p.mu.Lock()
    79  	}
    80  	p.mu.Unlock()
    81  	p.gro.Flush()
    82  }
    83  
    84  // processorManager handles starting, closing, and queuing packets on processor
    85  // goroutines.
    86  type processorManager struct {
    87  	processors []processor
    88  	seed       uint32
    89  	wg         sync.WaitGroup
    90  	e          *endpoint
    91  	ready      []bool
    92  }
    93  
    94  // newProcessorManager creates a new processor manager.
    95  func newProcessorManager(opts *Options, e *endpoint) *processorManager {
    96  	m := &processorManager{}
    97  	m.seed = rand.Uint32()
    98  	m.ready = make([]bool, opts.ProcessorsPerChannel)
    99  	m.processors = make([]processor, opts.ProcessorsPerChannel)
   100  	m.e = e
   101  	m.wg.Add(opts.ProcessorsPerChannel)
   102  
   103  	for i := range m.processors {
   104  		p := &m.processors[i]
   105  		p.sleeper.AddWaker(&p.packetWaker)
   106  		p.sleeper.AddWaker(&p.closeWaker)
   107  		p.gro.Init(opts.GRO)
   108  		p.e = e
   109  	}
   110  
   111  	return m
   112  }
   113  
   114  // start starts the processor goroutines if the processor manager is configured
   115  // with more than one processor.
   116  func (m *processorManager) start() {
   117  	for i := range m.processors {
   118  		p := &m.processors[i]
   119  		// Only start processor in a separate goroutine if we have multiple of them.
   120  		if len(m.processors) > 1 {
   121  			go p.start(&m.wg)
   122  		}
   123  	}
   124  }
   125  
   126  func (m *processorManager) connectionHash(cid *connectionID) uint32 {
   127  	var payload [4]byte
   128  	binary.LittleEndian.PutUint16(payload[0:], cid.srcPort)
   129  	binary.LittleEndian.PutUint16(payload[2:], cid.dstPort)
   130  
   131  	h := jenkins.Sum32(m.seed)
   132  	h.Write(payload[:])
   133  	h.Write(cid.srcAddr)
   134  	h.Write(cid.dstAddr)
   135  	return h.Sum32()
   136  }
   137  
   138  // queuePacket queues a packet to be delivered to the appropriate processor.
   139  func (m *processorManager) queuePacket(pkt *stack.PacketBuffer, hasEthHeader bool) {
   140  	var pIdx int
   141  	cid, nonConnectionPkt := tcpipConnectionID(pkt)
   142  	if !hasEthHeader {
   143  		if nonConnectionPkt {
   144  			// If there's no eth header this should be a standard tcpip packet. If
   145  			// it isn't the packet is invalid so drop it.
   146  			return
   147  		}
   148  		pkt.NetworkProtocolNumber = cid.proto
   149  	}
   150  	if len(m.processors) == 1 || nonConnectionPkt {
   151  		// If the packet is not associated with an active connection, use the
   152  		// first processor.
   153  		pIdx = 0
   154  	} else {
   155  		pIdx = int(m.connectionHash(&cid)) % len(m.processors)
   156  	}
   157  	p := &m.processors[pIdx]
   158  	p.mu.Lock()
   159  	defer p.mu.Unlock()
   160  	pkt.IncRef()
   161  	p.pkts.PushBack(pkt)
   162  	m.ready[pIdx] = true
   163  }
   164  
   165  type connectionID struct {
   166  	srcAddr, dstAddr []byte
   167  	srcPort, dstPort uint16
   168  	proto            tcpip.NetworkProtocolNumber
   169  }
   170  
   171  // tcpipConnectionID returns a tcpip connection id tuple based on the data found
   172  // in the packet. It returns true if the packet is not associated with an active
   173  // connection (e.g ARP, NDP, etc). The method assumes link headers have already
   174  // been processed if they were present.
   175  func tcpipConnectionID(pkt *stack.PacketBuffer) (connectionID, bool) {
   176  	var cid connectionID
   177  	h, ok := pkt.Data().PullUp(1)
   178  	if !ok {
   179  		// Skip this packet.
   180  		return cid, true
   181  	}
   182  
   183  	const tcpSrcDstPortLen = 4
   184  	switch header.IPVersion(h) {
   185  	case header.IPv4Version:
   186  		hdrLen := header.IPv4(h).HeaderLength()
   187  		h, ok = pkt.Data().PullUp(int(hdrLen) + tcpSrcDstPortLen)
   188  		if !ok {
   189  			return cid, true
   190  		}
   191  		ipHdr := header.IPv4(h[:hdrLen])
   192  		tcpHdr := header.TCP(h[hdrLen:][:tcpSrcDstPortLen])
   193  
   194  		cid.srcAddr = ipHdr.SourceAddressSlice()
   195  		cid.dstAddr = ipHdr.DestinationAddressSlice()
   196  		cid.srcPort = tcpHdr.SourcePort()
   197  		cid.dstPort = tcpHdr.DestinationPort()
   198  		cid.proto = header.IPv4ProtocolNumber
   199  	case header.IPv6Version:
   200  		h, ok = pkt.Data().PullUp(header.IPv6FixedHeaderSize + tcpSrcDstPortLen)
   201  		if !ok {
   202  			return cid, true
   203  		}
   204  		ipHdr := header.IPv6(h)
   205  
   206  		var tcpHdr header.TCP
   207  		if tcpip.TransportProtocolNumber(ipHdr.NextHeader()) == header.TCPProtocolNumber {
   208  			tcpHdr = header.TCP(h[header.IPv6FixedHeaderSize:][:tcpSrcDstPortLen])
   209  		} else {
   210  			// Slow path for IPv6 extension headers :(.
   211  			dataBuf := pkt.Data().ToBuffer()
   212  			dataBuf.TrimFront(header.IPv6MinimumSize)
   213  			it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataBuf)
   214  			defer it.Release()
   215  			for {
   216  				hdr, done, err := it.Next()
   217  				if done || err != nil {
   218  					break
   219  				}
   220  				hdr.Release()
   221  			}
   222  			h, ok = pkt.Data().PullUp(int(it.HeaderOffset()) + tcpSrcDstPortLen)
   223  			if !ok {
   224  				return cid, true
   225  			}
   226  			tcpHdr = header.TCP(h[it.HeaderOffset():][:tcpSrcDstPortLen])
   227  		}
   228  		cid.srcAddr = ipHdr.SourceAddressSlice()
   229  		cid.dstAddr = ipHdr.DestinationAddressSlice()
   230  		cid.srcPort = tcpHdr.SourcePort()
   231  		cid.dstPort = tcpHdr.DestinationPort()
   232  		cid.proto = header.IPv6ProtocolNumber
   233  	default:
   234  		return cid, true
   235  	}
   236  	return cid, false
   237  }
   238  
   239  func (m *processorManager) close() {
   240  	if len(m.processors) < 2 {
   241  		return
   242  	}
   243  	for i := range m.processors {
   244  		p := &m.processors[i]
   245  		p.closeWaker.Assert()
   246  	}
   247  }
   248  
   249  // wakeReady wakes up all processors that have a packet queued. If there is only
   250  // one processor, the method delivers the packet inline without waking a
   251  // goroutine.
   252  func (m *processorManager) wakeReady() {
   253  	for i, ready := range m.ready {
   254  		if !ready {
   255  			continue
   256  		}
   257  		p := &m.processors[i]
   258  		if len(m.processors) > 1 {
   259  			p.packetWaker.Assert()
   260  		} else {
   261  			p.deliverPackets()
   262  		}
   263  		m.ready[i] = false
   264  	}
   265  }