github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/transport/tcp/dispatcher.go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tcp

import (
	"encoding/binary"
	"fmt"
	"math/rand"

	"github.com/nicocha30/gvisor-ligolo/pkg/sleep"
	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/hash/jenkins"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack"
	"github.com/nicocha30/gvisor-ligolo/pkg/waiter"
)

// epQueue is a queue of endpoints.
type epQueue struct {
	mu   sync.Mutex
	list endpointList
}

// enqueue adds e to the queue if the endpoint is not already on the queue.
func (q *epQueue) enqueue(e *endpoint) {
	q.mu.Lock()
	defer q.mu.Unlock()
	e.pendingProcessingMu.Lock()
	defer e.pendingProcessingMu.Unlock()

	if e.pendingProcessing {
		return
	}
	q.list.PushBack(e)
	e.pendingProcessing = true
}

// dequeue removes and returns the first element from the queue if available;
// it returns nil otherwise.
func (q *epQueue) dequeue() *endpoint {
	q.mu.Lock()
	if e := q.list.Front(); e != nil {
		q.list.Remove(e)
		e.pendingProcessingMu.Lock()
		e.pendingProcessing = false
		e.pendingProcessingMu.Unlock()
		q.mu.Unlock()
		return e
	}
	q.mu.Unlock()
	return nil
}

// empty returns true if the queue is empty, false otherwise.
func (q *epQueue) empty() bool {
	q.mu.Lock()
	v := q.list.Empty()
	q.mu.Unlock()
	return v
}

// processor is responsible for processing packets queued to a TCP endpoint.
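//
// Each processor runs a single goroutine (see start) that waits on its
// wakers: newEndpointWaker signals that endpoints have been queued on epQ,
// closeWaker terminates the goroutine, and pauseWaker, together with the
// pauseChan/resumeChan pair, lets the dispatcher pause and resume processing.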
type processor struct {
	epQ              epQueue
	sleeper          sleep.Sleeper
	newEndpointWaker sleep.Waker
	closeWaker       sleep.Waker
	pauseWaker       sleep.Waker
	pauseChan        chan struct{}
	resumeChan       chan struct{}
}

func (p *processor) close() {
	p.closeWaker.Assert()
}

func (p *processor) queueEndpoint(ep *endpoint) {
	// Queue an endpoint for processing by the processor goroutine.
	p.epQ.enqueue(ep)
	p.newEndpointWaker.Assert()
}

// deliverAccepted delivers a passively connected endpoint to the accept queue
// of its associated listening endpoint.
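//
// It returns false if the listening endpoint is no longer accepting
// connections (its accept queue capacity is zero); otherwise the endpoint is
// queued for accept, the listener is notified of a readable event, and true
// is returned.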
//
// +checklocks:ep.mu
func deliverAccepted(ep *endpoint) bool {
	lEP := ep.h.listenEP
	lEP.acceptMu.Lock()

	// Remove endpoint from list of pendingEndpoints as the handshake is now
	// complete.
	delete(lEP.acceptQueue.pendingEndpoints, ep)
	// Deliver this endpoint to the listening socket's accept queue.
	if lEP.acceptQueue.capacity == 0 {
		lEP.acceptMu.Unlock()
		return false
	}

	// NOTE: We always queue the endpoint and deliberately do not check if the
	// accept queue is full at this point. This is similar to Linux because
	// two racing incoming ACKs can both pass the acceptQueue.isFull check
	// and proceed to the ESTABLISHED state. In such a case it's better to
	// deliver both, even if it temporarily exceeds the queue limit, rather
	// than drop a connection that is fully established.
	//
	// For reference see:
	//    https://github.com/torvalds/linux/blob/169e77764adc041b1dacba84ea90516a895d43b2/net/ipv4/tcp_minisocks.c#L764
	//    https://github.com/torvalds/linux/blob/169e77764adc041b1dacba84ea90516a895d43b2/net/ipv4/tcp_ipv4.c#L1500
	lEP.acceptQueue.endpoints.PushBack(ep)
	lEP.acceptMu.Unlock()
	ep.h.listenEP.waiterQueue.Notify(waiter.ReadableEvents)

	return true
}

// handleConnecting is responsible for TCP processing for an endpoint in one of
// the connecting states.
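//
// It drives the handshake by processing any queued segments. If the handshake
// fails, the endpoint is removed from its listener's pending set (if any) and
// cleaned up. If it completes for a passively opened connection, the endpoint
// is delivered to the listener's accept queue via deliverAccepted.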
func (p *processor) handleConnecting(ep *endpoint) {
	if !ep.TryLock() {
		return
	}
	cleanup := func() {
		ep.mu.Unlock()
		ep.drainClosingSegmentQueue()
		ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
	}
	if !ep.EndpointState().connecting() {
		// If the endpoint has already transitioned out of a connecting
		// state then just return (only possible if it was closed or
		// timed out by the time we got around to processing the wakeup).
		ep.mu.Unlock()
		return
	}
	if err := ep.h.processSegments(); err != nil { // +checklocksforce:ep.h.ep.mu
		// The handshake failed. Clean up the TCP endpoint and handshake
		// state.
		if lEP := ep.h.listenEP; lEP != nil {
			lEP.acceptMu.Lock()
			delete(lEP.acceptQueue.pendingEndpoints, ep)
			lEP.acceptMu.Unlock()
		}
		ep.handshakeFailed(err)
		cleanup()
		return
	}

	if ep.EndpointState() == StateEstablished && ep.h.listenEP != nil {
		ep.isConnectNotified = true
		ep.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
		if !deliverAccepted(ep) {
			ep.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
			cleanup()
			return
		}
	}
	ep.mu.Unlock()
}

// handleConnected is responsible for TCP processing for an endpoint in one of
// the connected states (StateEstablished, StateFinWait1, etc.).
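//
// It processes queued segments and then handles the resulting state: on error
// the connection is reset, on a transition to StateClose the closing segment
// queue is drained and waiters are notified, and on a transition to
// StateTimeWait the TIME-WAIT timer is started.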
func (p *processor) handleConnected(ep *endpoint) {
	if !ep.TryLock() {
		return
	}

	if !ep.EndpointState().connected() {
		// If the endpoint has already transitioned out of a connected
		// state then just return (only possible if it was closed or
		// timed out by the time we got around to processing the wakeup).
		ep.mu.Unlock()
		return
	}

	// NOTE: We read this outside of the e.mu lock, which means that by the time
	// we get to handleSegments the endpoint may not be in ESTABLISHED. But
	// this should be fine as all normal shutdown states are handled by
	// handleSegmentsLocked.
	switch err := ep.handleSegmentsLocked(); {
	case err != nil:
		// Send any active resets if required.
		ep.resetConnectionLocked(err)
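		// Fall through so that the cleanup below (unlock, drain the closing
		// segment queue, notify waiters) also runs on error.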
		fallthrough
	case ep.EndpointState() == StateClose:
		ep.mu.Unlock()
		ep.drainClosingSegmentQueue()
		ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
		return
	case ep.EndpointState() == StateTimeWait:
		p.startTimeWait(ep)
	}
	ep.mu.Unlock()
}

// startTimeWait stops the FIN-WAIT-2 timer, notifies waiters, and arms the
// TIME-WAIT timer for the endpoint.
//
// +checklocks:ep.mu
func (p *processor) startTimeWait(ep *endpoint) {
	// Stop the FIN-WAIT-2 timer as we are now entering real TIME_WAIT.
	if ep.finWait2Timer != nil {
		ep.finWait2Timer.Stop()
	}
	// Wake up any waiters before we start TIME-WAIT.
	ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
	timeWaitDuration := ep.getTimeWaitDuration()
	ep.timeWaitTimer = ep.stack.Clock().AfterFunc(timeWaitDuration, ep.timeWaitTimerExpired)
}

// handleTimeWait is responsible for TCP processing for an endpoint in TIME-WAIT
// state.
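//
// Segments processed here may extend the TIME-WAIT timer or, when
// handleTimeWaitSegments returns a non-nil reuseTW callback, move the
// endpoint to closed so that the callback can reuse the connection tuple.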
func (p *processor) handleTimeWait(ep *endpoint) {
	if !ep.TryLock() {
		return
	}

	if ep.EndpointState() != StateTimeWait {
		// If the endpoint has already transitioned out of a TIME-WAIT
		// state then just return (only possible if it was closed or
		// timed out by the time we got around to processing the wakeup).
		ep.mu.Unlock()
		return
	}

	extendTimeWait, reuseTW := ep.handleTimeWaitSegments()
	if reuseTW != nil {
		ep.transitionToStateCloseLocked()
		ep.mu.Unlock()
		ep.drainClosingSegmentQueue()
		ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
		reuseTW()
		return
	}
	if extendTimeWait {
		ep.timeWaitTimer.Reset(ep.getTimeWaitDuration())
	}
	ep.mu.Unlock()
}

// handleListen is responsible for TCP processing for an endpoint in LISTEN
// state.
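//
// At most maxSegmentsPerWake queued segments are processed per wakeup to
// bound the work done in a single pass.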
func (p *processor) handleListen(ep *endpoint) {
	if !ep.TryLock() {
		return
	}
	defer ep.mu.Unlock()

	if ep.EndpointState() != StateListen {
		// If the endpoint has already transitioned out of a LISTEN
		// state then just return (only possible if it was closed or
		// shut down).
		return
	}

	for i := 0; i < maxSegmentsPerWake; i++ {
		s := ep.segmentQueue.dequeue()
		if s == nil {
			break
		}

		// TODO(gvisor.dev/issue/4690): Better handle errors instead of
		// silently dropping.
		_ = ep.handleListenSegment(ep.listenCtx, s)
		s.DecRef()
	}
}

// start runs the main loop for a processor, which is responsible for all TCP
// processing for the endpoints queued to it.
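//
// The loop blocks on the processor's sleeper until a waker is asserted:
// closeWaker terminates the loop, pauseWaker coordinates with pause/resume
// via pauseChan and resumeChan, and newEndpointWaker drains epQ, dispatching
// each endpoint to the handler for its current state.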
func (p *processor) start(wg *sync.WaitGroup) {
	defer wg.Done()
	defer p.sleeper.Done()

	for {
		switch w := p.sleeper.Fetch(true); {
		case w == &p.closeWaker:
			return
		case w == &p.pauseWaker:
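			// If endpoints are still queued, keep processing and re-assert
			// the pause waker so the pause is retried once the queue has
			// drained; otherwise signal the pauser and block until resumed.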
			if !p.epQ.empty() {
				p.newEndpointWaker.Assert()
				p.pauseWaker.Assert()
				continue
			} else {
				p.pauseChan <- struct{}{}
				<-p.resumeChan
			}
		case w == &p.newEndpointWaker:
			for {
				ep := p.epQ.dequeue()
				if ep == nil {
					break
				}
				if ep.segmentQueue.empty() {
					continue
				}
				switch state := ep.EndpointState(); {
				case state.connecting():
					p.handleConnecting(ep)
				case state.connected() && state != StateTimeWait:
					p.handleConnected(ep)
				case state == StateTimeWait:
					p.handleTimeWait(ep)
				case state == StateListen:
					p.handleListen(ep)
				case state == StateError || state == StateClose:
					// Try to redeliver any still-queued
					// packets to another endpoint or send an
					// RST if they can't be delivered.
					ep.mu.Lock()
					if st := ep.EndpointState(); st == StateError || st == StateClose {
						ep.drainClosingSegmentQueue()
					}
					ep.mu.Unlock()
				default:
					panic(fmt.Sprintf("unexpected tcp state in processor: %v", state))
				}
				// If there are more segments to process and the
				// endpoint lock is not held by the user, then
				// requeue this endpoint for processing.
				if !ep.segmentQueue.empty() && !ep.isOwnedByUser() {
					p.epQ.enqueue(ep)
				}
			}
		}
	}
}

// pause pauses the processor loop.
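// It returns pauseChan; the processor signals on this channel once it has
// actually paused, so callers should receive from the returned channel to
// wait for the pause to take effect (as dispatcher.pause does).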
func (p *processor) pause() chan struct{} {
	p.pauseWaker.Assert()
	return p.pauseChan
}

// resume resumes a previously paused loop.
//
// Precondition: pause must have been called previously.
func (p *processor) resume() {
	p.resumeChan <- struct{}{}
}

// dispatcher manages a pool of TCP endpoint processors which are responsible
// for the processing of inbound segments. This fixed pool of processor
// goroutines does full TCP processing. The processor is selected based on the
// hash of the endpoint ID to ensure that delivery for the same endpoint happens
// in-order.
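//
// For example, with len(processors) == 4, segments for a given connection
// are always handled by processors[hasher.hash(id)%4], i.e. by the same
// goroutine, which preserves their ordering (see selectProcessor).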
type dispatcher struct {
	processors []processor
	wg         sync.WaitGroup
	hasher     jenkinsHasher
	mu         sync.Mutex
	// +checklocks:mu
	paused bool
	// +checklocks:mu
	closed bool
}

// init initializes a dispatcher and starts the main loop for all the processors
// owned by this dispatcher.
func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
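	// Shut down any processors from a previous init and wait for their
	// goroutines to exit before building the new pool.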
	d.close()
	d.wait()

	d.mu.Lock()
	defer d.mu.Unlock()
	d.closed = false
	d.processors = make([]processor, nProcessors)
	d.hasher = jenkinsHasher{seed: rng.Uint32()}
	for i := range d.processors {
		p := &d.processors[i]
		p.sleeper.AddWaker(&p.newEndpointWaker)
		p.sleeper.AddWaker(&p.closeWaker)
		p.sleeper.AddWaker(&p.pauseWaker)
		p.pauseChan = make(chan struct{})
		p.resumeChan = make(chan struct{})
		d.wg.Add(1)
		// NB: sleeper-waker registration must happen synchronously to avoid races
		// with `close`.  It's possible to pull all this logic into `start`, but
		// that results in a heap-allocated function literal.
		go p.start(&d.wg)
	}
}

// close closes a dispatcher and its processors.
func (d *dispatcher) close() {
	d.mu.Lock()
	d.closed = true
	d.mu.Unlock()
	for i := range d.processors {
		d.processors[i].close()
	}
}

// wait waits for all processor goroutines to end.
func (d *dispatcher) wait() {
	d.wg.Wait()
}

// queuePacket queues an incoming packet to the matching TCP endpoint and
// also queues the endpoint to a processor queue for processing.
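//
// Packets are dropped if the dispatcher has been closed, if they cannot be
// parsed into a valid segment, or if the segment checksum is invalid. The
// endpoint is only queued to a processor when its lock is not held by a user
// goroutine; otherwise endpoint.UnlockUser is responsible for waking the
// processor.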
func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt stack.PacketBufferPtr) {
	d.mu.Lock()
	closed := d.closed
	d.mu.Unlock()

	if closed {
		return
	}

	ep := stackEP.(*endpoint)

	s, err := newIncomingSegment(id, clock, pkt)
	if err != nil {
		ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
		ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
		return
	}
	defer s.DecRef()

	if !s.csumValid {
		ep.stack.Stats().TCP.ChecksumErrors.Increment()
		ep.stats.ReceiveErrors.ChecksumErrors.Increment()
		return
	}

	ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
	ep.stats.SegmentsReceived.Increment()
	if (s.flags & header.TCPFlagRst) != 0 {
		ep.stack.Stats().TCP.ResetsReceived.Increment()
	}

	if !ep.enqueueSegment(s) {
		return
	}

	// Only wake up the processor if the endpoint lock is not held by a user
	// goroutine, as endpoint.UnlockUser will wake up the processor if the
	// segment queue is not empty.
	if !ep.isOwnedByUser() {
		d.selectProcessor(id).queueEndpoint(ep)
	}
}

// selectProcessor uses a hash of the transport endpoint ID to queue the
// endpoint to a specific processor. This is required to maintain TCP ordering, as
// queueing the same endpoint to multiple processors can *potentially* result in
// out-of-order processing of incoming segments. It also ensures that a dispatcher
// evenly loads the processor goroutines.
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
	return &d.processors[d.hasher.hash(id)%uint32(len(d.processors))]
}

// pause pauses a dispatcher and all its processor goroutines.
func (d *dispatcher) pause() {
	d.mu.Lock()
	d.paused = true
	d.mu.Unlock()
	for i := range d.processors {
		<-d.processors[i].pause()
	}
}

// resume resumes a previously paused dispatcher and its processor goroutines.
// Calling resume on a dispatcher that was never paused is a no-op.
func (d *dispatcher) resume() {
	d.mu.Lock()

	if !d.paused {
		// If this was a restore run, the stack is a new instance and
		// it was never paused, so just return as there is nothing to
		// resume.
		d.mu.Unlock()
		return
	}
	d.paused = false
	d.mu.Unlock()
	for i := range d.processors {
		d.processors[i].resume()
	}
}

// jenkinsHasher contains the state needed for a Jenkins hash.
type jenkinsHasher struct {
	seed uint32
}

// hash hashes the provided TransportEndpointID using the Jenkins hash
// algorithm.
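//
// The hash input is the local and remote ports (little-endian) followed by
// the local and remote addresses, mixed with the per-dispatcher random seed.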
func (j jenkinsHasher) hash(id stack.TransportEndpointID) uint32 {
	var payload [4]byte
	binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
	binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)

	h := jenkins.Sum32(j.seed)
	h.Write(payload[:])
	h.Write(id.LocalAddress.AsSlice())
	h.Write(id.RemoteAddress.AsSlice())
	return h.Sum32()
}