github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/dispatcher.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"encoding/binary"
    19  	"math/rand"
    20  
    21  	"github.com/SagerNet/gvisor/pkg/sleep"
    22  	"github.com/SagerNet/gvisor/pkg/sync"
    23  	"github.com/SagerNet/gvisor/pkg/tcpip"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/hash/jenkins"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    27  )
    28  
// epQueue is a queue of endpoints.
//
// Endpoints are appended at the back by enqueue and removed from the front by
// dequeue, so they are handed out in the order they were queued. mu protects
// list; enqueue and dequeue also read and write each endpoint's
// pendingProcessing flag while holding mu.
type epQueue struct {
	mu   sync.Mutex
	list endpointList
}
    34  
    35  // enqueue adds e to the queue if the endpoint is not already on the queue.
    36  func (q *epQueue) enqueue(e *endpoint) {
    37  	q.mu.Lock()
    38  	if e.pendingProcessing {
    39  		q.mu.Unlock()
    40  		return
    41  	}
    42  	q.list.PushBack(e)
    43  	e.pendingProcessing = true
    44  	q.mu.Unlock()
    45  }
    46  
    47  // dequeue removes and returns the first element from the queue if available,
    48  // returns nil otherwise.
    49  func (q *epQueue) dequeue() *endpoint {
    50  	q.mu.Lock()
    51  	if e := q.list.Front(); e != nil {
    52  		q.list.Remove(e)
    53  		e.pendingProcessing = false
    54  		q.mu.Unlock()
    55  		return e
    56  	}
    57  	q.mu.Unlock()
    58  	return nil
    59  }
    60  
    61  // empty returns true if the queue is empty, false otherwise.
    62  func (q *epQueue) empty() bool {
    63  	q.mu.Lock()
    64  	v := q.list.Empty()
    65  	q.mu.Unlock()
    66  	return v
    67  }
    68  
// processor is responsible for processing packets queued to a tcp endpoint.
//
// epQ holds the endpoints waiting to be serviced. sleeper is what the
// processor goroutine (see start) waits on; newEndpointWaker is asserted when
// an endpoint is queued and closeWaker is asserted to make the goroutine exit.
type processor struct {
	epQ              epQueue
	sleeper          sleep.Sleeper
	newEndpointWaker sleep.Waker
	closeWaker       sleep.Waker
}
    76  
// close asserts closeWaker, which causes the processor goroutine running
// start to leave its loop. It does not wait for that goroutine to finish.
func (p *processor) close() {
	p.closeWaker.Assert()
}
    80  
// queueEndpoint places ep on this processor's endpoint queue (a no-op if ep is
// already queued) and then asserts newEndpointWaker to signal the processor
// goroutine that there is work to do.
func (p *processor) queueEndpoint(ep *endpoint) {
	// Queue an endpoint for processing by the processor goroutine.
	p.epQ.enqueue(ep)
	p.newEndpointWaker.Assert()
}
    86  
// IDs under which a processor's wakers are registered with its sleeper. The
// processor goroutine compares the ID returned by the sleeper against these
// values to tell which waker fired.
const (
	newEndpointWaker = 1
	closeWaker       = 2
)
    91  
// start is the body of the processor goroutine. It repeatedly waits on the
// sleeper and, each time it is woken by anything other than closeWaker, drains
// the endpoint queue, handling the queued segments of each endpoint. It
// returns once closeWaker fires.
//
// On return it releases the sleeper (p.sleeper.Done) and then signals wg
// (wg.Done), in that order, since deferred calls run last-in-first-out.
func (p *processor) start(wg *sync.WaitGroup) {
	defer wg.Done()
	defer p.sleeper.Done()

	for {
		// Wait for a waker; presumably the `true` argument makes Fetch block
		// until one fires (confirm against the sleep package). The second
		// result is ignored.
		if id, _ := p.sleeper.Fetch(true); id == closeWaker {
			break
		}
		// Drain every endpoint currently on the queue before sleeping again.
		for {
			ep := p.epQ.dequeue()
			if ep == nil {
				break
			}
			// Nothing to do for an endpoint with no queued segments.
			if ep.segmentQueue.empty() {
				continue
			}

			// If socket has transitioned out of connected state then just let the
			// worker handle the packet.
			//
			// NOTE: We read this outside of e.mu lock which means that by the time
			// we get to handleSegments the endpoint may not be in ESTABLISHED. But
			// this should be fine as all normal shutdown states are handled by
			// handleSegments and if the endpoint moves to a CLOSED/ERROR state
			// then handleSegments is a noop.
			if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
				// If the endpoint is in a connected state then we do direct delivery
				// to ensure low latency and avoid scheduler interactions.
				switch err := ep.handleSegmentsLocked(true /* fastPath */); {
				case err != nil:
					// Send any active resets if required.
					ep.resetConnectionLocked(err)
					// An error also tickles the worker, same as the StateClose case.
					fallthrough
				case ep.EndpointState() == StateClose:
					ep.notifyProtocolGoroutine(notifyTickleWorker)
				case !ep.segmentQueue.empty():
					// Segments remain after this pass; put the endpoint back on
					// the queue so the inner loop picks it up again.
					p.epQ.enqueue(ep)
				}
				ep.mu.Unlock() // +checklocksforce
			} else {
				// Not established, or the endpoint lock is held elsewhere: wake
				// the endpoint's own worker via newSegmentWaker instead.
				ep.newSegmentWaker.Assert()
			}
		}
	}
}
   137  
// dispatcher manages a pool of TCP endpoint processors which are responsible
// for the processing of inbound segments. This fixed pool of processor
// goroutines does full tcp processing. The processor is selected based on the
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
//
// processors is the pool itself, allocated by init. wg tracks the running
// processor goroutines: init adds one per goroutine and wait blocks on it.
type dispatcher struct {
	processors []processor
	// seed is a random secret for a jenkins hash.
	seed uint32
	wg   sync.WaitGroup
}
   149  
   150  func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
   151  	d.close()
   152  	d.wait()
   153  	d.processors = make([]processor, nProcessors)
   154  	d.seed = rng.Uint32()
   155  	for i := range d.processors {
   156  		p := &d.processors[i]
   157  		p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
   158  		p.sleeper.AddWaker(&p.closeWaker, closeWaker)
   159  		d.wg.Add(1)
   160  		// NB: sleeper-waker registration must happen synchronously to avoid races
   161  		// with `close`.  It's possible to pull all this logic into `start`, but
   162  		// that results in a heap-allocated function literal.
   163  		go p.start(&d.wg)
   164  	}
   165  }
   166  
   167  func (d *dispatcher) close() {
   168  	for i := range d.processors {
   169  		d.processors[i].close()
   170  	}
   171  }
   172  
// wait blocks until every processor goroutine started by init has returned.
// It does not itself ask them to stop; call close first for that.
func (d *dispatcher) wait() {
	d.wg.Wait()
}
   176  
   177  func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) {
   178  	ep := stackEP.(*endpoint)
   179  
   180  	s := newIncomingSegment(id, clock, pkt)
   181  	if !s.parse(pkt.RXTransportChecksumValidated) {
   182  		ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
   183  		ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
   184  		s.decRef()
   185  		return
   186  	}
   187  
   188  	if !s.csumValid {
   189  		ep.stack.Stats().TCP.ChecksumErrors.Increment()
   190  		ep.stats.ReceiveErrors.ChecksumErrors.Increment()
   191  		s.decRef()
   192  		return
   193  	}
   194  
   195  	ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
   196  	ep.stats.SegmentsReceived.Increment()
   197  	if (s.flags & header.TCPFlagRst) != 0 {
   198  		ep.stack.Stats().TCP.ResetsReceived.Increment()
   199  	}
   200  
   201  	if !ep.enqueueSegment(s) {
   202  		s.decRef()
   203  		return
   204  	}
   205  
   206  	// For sockets not in established state let the worker goroutine
   207  	// handle the packets.
   208  	if ep.EndpointState() != StateEstablished {
   209  		ep.newSegmentWaker.Assert()
   210  		return
   211  	}
   212  
   213  	d.selectProcessor(id).queueEndpoint(ep)
   214  }
   215  
   216  func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
   217  	var payload [4]byte
   218  	binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
   219  	binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
   220  
   221  	h := jenkins.Sum32(d.seed)
   222  	h.Write(payload[:])
   223  	h.Write([]byte(id.LocalAddress))
   224  	h.Write([]byte(id.RemoteAddress))
   225  
   226  	return &d.processors[h.Sum32()%uint32(len(d.processors))]
   227  }