gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/tcp/cubic.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"math"
    19  	"time"
    20  
    21  	"gvisor.dev/gvisor/pkg/tcpip"
    22  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    23  )
    24  
// effectivelyInfinity is an initialization value used for round-trip times
// that are then set using min. It is the largest value a time.Duration can
// hold (math.MaxInt64 nanoseconds, approximately 292 years): large enough that
// it will always be greater than a real TCP round-trip time, while still
// fitting in time.Duration.
const effectivelyInfinity = time.Duration(math.MaxInt64)
    30  
const (
	// RTT = round-trip time.

	// The delay increase sensitivity is determined by minRTTThresh and
	// maxRTTThresh, which are the lower and upper clamps applied to the delay
	// threshold computed in updateHyStart. Smaller values of minRTTThresh may
	// cause spurious exits from slow start. Larger values of maxRTTThresh may
	// result in slow start not exiting until loss is encountered for
	// connections on large RTT paths.
	minRTTThresh = 4 * time.Millisecond
	maxRTTThresh = 16 * time.Millisecond

	// minRTTDivisor divides the previous round's minimum RTT to compute the
	// delay threshold. A smaller value would mean a larger threshold and thus
	// less sensitivity to delay increase, and vice versa.
	minRTTDivisor = 8

	// nRTTSample is the minimum number of RTT samples in the round before
	// considering whether to exit slow start due to increased RTT.
	nRTTSample = 8

	// ackDelta bounds the time between ACKs for them to be considered part
	// of the same ACK train during HyStart: the gap must be strictly less
	// than ackDelta.
	ackDelta = 2 * time.Millisecond
)
    54  
// cubicState stores the variables related to TCP CUBIC congestion
// control algorithm state.
//
// See: https://tools.ietf.org/html/rfc8312.
// +stateify savable
type cubicState struct {
	// TCPCubicState is embedded so that its fields (for example T, K, WMax,
	// WLastMax, LastRTT and CurrRTT) are accessed directly on cubicState.
	stack.TCPCubicState

	// numCongestionEvents tracks the number of congestion events since last
	// RTO. It is incremented in HandleLossDetected and reset to zero in
	// HandleRTOExpired.
	numCongestionEvents int

	// s is the sender whose congestion window (SndCwnd) and slow-start
	// threshold (Ssthresh) this state reads and updates.
	s *sender
}
    69  
    70  // newCubicCC returns a partially initialized cubic state with the constants
    71  // beta and c set and t set to current time.
    72  func newCubicCC(s *sender) *cubicState {
    73  	now := s.ep.stack.Clock().NowMonotonic()
    74  	return &cubicState{
    75  		TCPCubicState: stack.TCPCubicState{
    76  			T:    now,
    77  			Beta: 0.7,
    78  			C:    0.4,
    79  			// By this point, the sender has initialized it's initial sequence
    80  			// number.
    81  			EndSeq:     s.SndNxt,
    82  			LastRTT:    effectivelyInfinity,
    83  			CurrRTT:    effectivelyInfinity,
    84  			LastAck:    now,
    85  			RoundStart: now,
    86  		},
    87  		s: s,
    88  	}
    89  }
    90  
    91  // enterCongestionAvoidance is used to initialize cubic in cases where we exit
    92  // SlowStart without a real congestion event taking place. This can happen when
    93  // a connection goes back to slow start due to a retransmit and we exceed the
    94  // previously lowered ssThresh without experiencing packet loss.
    95  //
    96  // Refer: https://tools.ietf.org/html/rfc8312#section-4.8
    97  func (c *cubicState) enterCongestionAvoidance() {
    98  	// See: https://tools.ietf.org/html/rfc8312#section-4.7 &
    99  	// https://tools.ietf.org/html/rfc8312#section-4.8
   100  	if c.numCongestionEvents == 0 {
   101  		c.K = 0
   102  		c.T = c.s.ep.stack.Clock().NowMonotonic()
   103  		c.WLastMax = c.WMax
   104  		c.WMax = float64(c.s.SndCwnd)
   105  	}
   106  }
   107  
   108  // updateHyStart tracks packet round-trip time (rtt) to find a safe threshold
   109  // to exit slow start without triggering packet loss.  It updates the SSThresh
   110  // when it does.
   111  //
   112  // Implementation of HyStart follows the algorithm from the Linux kernel, rather
   113  // than RFC 9406 (https://www.rfc-editor.org/rfc/rfc9406.html). Briefly, the
   114  // Linux kernel algorithm is based directly on the original HyStart paper
   115  // (https://doi.org/10.1016/j.comnet.2011.01.014), and differs from the RFC in
   116  // that two detection algorithms run in parallel ('ACK train' and 'Delay
   117  // increase').  The RFC version includes only the latter algorithm and adds an
   118  // intermediate phase called Conservative Slow Start, which is not implemented
   119  // here.
   120  func (c *cubicState) updateHyStart(rtt time.Duration) {
   121  	if rtt < 0 {
   122  		// negative indicates unknown
   123  		return
   124  	}
   125  	now := c.s.ep.stack.Clock().NowMonotonic()
   126  	if c.EndSeq.LessThan(c.s.SndUna) {
   127  		c.beginHyStartRound(now)
   128  	}
   129  	// ACK train
   130  	if now.Sub(c.LastAck) < ackDelta && // ensures acks are part of the same "train"
   131  		c.LastRTT < effectivelyInfinity {
   132  		c.LastAck = now
   133  		if thresh := c.LastRTT / 2; now.Sub(c.RoundStart) > thresh {
   134  			c.s.Ssthresh = c.s.SndCwnd
   135  		}
   136  	}
   137  
   138  	// Delay increase
   139  	c.CurrRTT = min(c.CurrRTT, rtt)
   140  	c.SampleCount++
   141  
   142  	if c.SampleCount >= nRTTSample && c.LastRTT < effectivelyInfinity {
   143  		// i.e. LastRTT/minRTTDivisor, but clamped to minRTTThresh & maxRTTThresh
   144  		thresh := max(
   145  			minRTTThresh,
   146  			min(maxRTTThresh, c.LastRTT/minRTTDivisor),
   147  		)
   148  		if c.CurrRTT >= (c.LastRTT + thresh) {
   149  			// Triggered HyStart safe exit threshold
   150  			c.s.Ssthresh = c.s.SndCwnd
   151  		}
   152  	}
   153  }
   154  
   155  func (c *cubicState) beginHyStartRound(now tcpip.MonotonicTime) {
   156  	c.EndSeq = c.s.SndNxt
   157  	c.SampleCount = 0
   158  	c.LastRTT = c.CurrRTT
   159  	c.CurrRTT = effectivelyInfinity
   160  	c.LastAck = now
   161  	c.RoundStart = now
   162  }
   163  
   164  // updateSlowStart will update the congestion window as per the slow-start
   165  // algorithm used by NewReno. If after adjusting the congestion window we cross
   166  // the ssThresh then it will return the number of packets that must be consumed
   167  // in congestion avoidance mode.
   168  func (c *cubicState) updateSlowStart(packetsAcked int) int {
   169  	// Don't let the congestion window cross into the congestion
   170  	// avoidance range.
   171  	newcwnd := c.s.SndCwnd + packetsAcked
   172  	enterCA := false
   173  	if newcwnd >= c.s.Ssthresh {
   174  		newcwnd = c.s.Ssthresh
   175  		c.s.SndCAAckCount = 0
   176  		enterCA = true
   177  	}
   178  
   179  	packetsAcked -= newcwnd - c.s.SndCwnd
   180  	c.s.SndCwnd = newcwnd
   181  	if enterCA {
   182  		c.enterCongestionAvoidance()
   183  	}
   184  	return packetsAcked
   185  }
   186  
   187  // Update updates cubic's internal state variables. It must be called on every
   188  // ACK received.
   189  // Refer: https://tools.ietf.org/html/rfc8312#section-4
   190  func (c *cubicState) Update(packetsAcked int, rtt time.Duration) {
   191  	if c.s.Ssthresh == InitialSsthresh && c.s.SndCwnd < c.s.Ssthresh {
   192  		c.updateHyStart(rtt)
   193  	}
   194  	if c.s.SndCwnd < c.s.Ssthresh {
   195  		packetsAcked = c.updateSlowStart(packetsAcked)
   196  		if packetsAcked == 0 {
   197  			return
   198  		}
   199  	} else {
   200  		c.s.rtt.Lock()
   201  		srtt := c.s.rtt.TCPRTTState.SRTT
   202  		c.s.rtt.Unlock()
   203  		c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt)
   204  	}
   205  }
   206  
   207  // cubicCwnd computes the CUBIC congestion window after t seconds from last
   208  // congestion event.
   209  func (c *cubicState) cubicCwnd(t float64) float64 {
   210  	return c.C*math.Pow(t, 3.0) + c.WMax
   211  }
   212  
   213  // getCwnd returns the current congestion window as computed by CUBIC.
   214  // Refer: https://tools.ietf.org/html/rfc8312#section-4
   215  func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
   216  	elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T)
   217  	elapsedSeconds := elapsed.Seconds()
   218  
   219  	// Compute the window as per Cubic after 'elapsed' time
   220  	// since last congestion event.
   221  	c.WC = c.cubicCwnd(elapsedSeconds - c.K)
   222  
   223  	// Compute the TCP friendly estimate of the congestion window.
   224  	c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds())
   225  
   226  	// Make sure in the TCP friendly region CUBIC performs at least
   227  	// as well as Reno.
   228  	if c.WC < c.WEst && float64(sndCwnd) < c.WEst {
   229  		// TCP Friendly region of cubic.
   230  		return int(c.WEst)
   231  	}
   232  
   233  	// In Concave/Convex region of CUBIC, calculate what CUBIC window
   234  	// will be after 1 RTT and use that to grow congestion window
   235  	// for every ack.
   236  	tEst := (elapsed + srtt).Seconds()
   237  	wtRtt := c.cubicCwnd(tEst - c.K)
   238  	// As per 4.3 for each received ACK cwnd must be incremented
   239  	// by (w_cubic(t+RTT) - cwnd/cwnd.
   240  	cwnd := float64(sndCwnd)
   241  	for i := 0; i < packetsAcked; i++ {
   242  		// Concave/Convex regions of cubic have the same formulas.
   243  		// See: https://tools.ietf.org/html/rfc8312#section-4.3
   244  		cwnd += (wtRtt - cwnd) / cwnd
   245  	}
   246  	return int(cwnd)
   247  }
   248  
   249  // HandleLossDetected implements congestionControl.HandleLossDetected.
   250  func (c *cubicState) HandleLossDetected() {
   251  	// See: https://tools.ietf.org/html/rfc8312#section-4.5
   252  	c.numCongestionEvents++
   253  	c.T = c.s.ep.stack.Clock().NowMonotonic()
   254  	c.WLastMax = c.WMax
   255  	c.WMax = float64(c.s.SndCwnd)
   256  
   257  	c.fastConvergence()
   258  	c.reduceSlowStartThreshold()
   259  }
   260  
   261  // HandleRTOExpired implements congestionContrl.HandleRTOExpired.
   262  func (c *cubicState) HandleRTOExpired() {
   263  	// See: https://tools.ietf.org/html/rfc8312#section-4.6
   264  	c.T = c.s.ep.stack.Clock().NowMonotonic()
   265  	c.numCongestionEvents = 0
   266  	c.WLastMax = c.WMax
   267  	c.WMax = float64(c.s.SndCwnd)
   268  
   269  	c.fastConvergence()
   270  
   271  	// We lost a packet, so reduce ssthresh.
   272  	c.reduceSlowStartThreshold()
   273  
   274  	// Reduce the congestion window to 1, i.e., enter slow-start. Per
   275  	// RFC 5681, page 7, we must use 1 regardless of the value of the
   276  	// initial congestion window.
   277  	c.s.SndCwnd = 1
   278  }
   279  
   280  // fastConvergence implements the logic for Fast Convergence algorithm as
   281  // described in https://tools.ietf.org/html/rfc8312#section-4.6.
   282  func (c *cubicState) fastConvergence() {
   283  	if c.WMax < c.WLastMax {
   284  		c.WLastMax = c.WMax
   285  		c.WMax = c.WMax * (1.0 + c.Beta) / 2.0
   286  	} else {
   287  		c.WLastMax = c.WMax
   288  	}
   289  	// Recompute k as wMax may have changed.
   290  	c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C)
   291  }
   292  
   293  // PostRecovery implements congestionControl.PostRecovery.
   294  func (c *cubicState) PostRecovery() {
   295  	c.T = c.s.ep.stack.Clock().NowMonotonic()
   296  }
   297  
   298  // reduceSlowStartThreshold returns new SsThresh as described in
   299  // https://tools.ietf.org/html/rfc8312#section-4.7.
   300  func (c *cubicState) reduceSlowStartThreshold() {
   301  	c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0))
   302  }