github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/transport/tcp/cubic.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"math"
    19  	"time"
    20  
    21  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    22  )
    23  
// cubicState stores the variables related to TCP CUBIC congestion
// control algorithm state.
//
// The CUBIC window parameters (T, K, Beta, C, WMax, WLastMax, WC, WEst) are
// promoted from the embedded stack.TCPCubicState; this type adds bookkeeping
// that is private to the tcp package.
//
// See: https://tools.ietf.org/html/rfc8312.
// +stateify savable
type cubicState struct {
	stack.TCPCubicState

	// numCongestionEvents tracks the number of congestion events since last
	// RTO. It is incremented on every detected loss and reset to zero when
	// the RTO fires; while it is zero, enterCongestionAvoidance starts a new
	// CUBIC epoch when slow start reaches ssthresh.
	numCongestionEvents int

	// s is the sender whose congestion window (SndCwnd) and slow-start
	// threshold (Ssthresh) this state reads and updates.
	s *sender
}
    38  
    39  // newCubicCC returns a partially initialized cubic state with the constants
    40  // beta and c set and t set to current time.
    41  func newCubicCC(s *sender) *cubicState {
    42  	return &cubicState{
    43  		TCPCubicState: stack.TCPCubicState{
    44  			T:    s.ep.stack.Clock().NowMonotonic(),
    45  			Beta: 0.7,
    46  			C:    0.4,
    47  		},
    48  		s: s,
    49  	}
    50  }
    51  
    52  // enterCongestionAvoidance is used to initialize cubic in cases where we exit
    53  // SlowStart without a real congestion event taking place. This can happen when
    54  // a connection goes back to slow start due to a retransmit and we exceed the
    55  // previously lowered ssThresh without experiencing packet loss.
    56  //
    57  // Refer: https://tools.ietf.org/html/rfc8312#section-4.8
    58  func (c *cubicState) enterCongestionAvoidance() {
    59  	// See: https://tools.ietf.org/html/rfc8312#section-4.7 &
    60  	// https://tools.ietf.org/html/rfc8312#section-4.8
    61  	if c.numCongestionEvents == 0 {
    62  		c.K = 0
    63  		c.T = c.s.ep.stack.Clock().NowMonotonic()
    64  		c.WLastMax = c.WMax
    65  		c.WMax = float64(c.s.SndCwnd)
    66  	}
    67  }
    68  
    69  // updateSlowStart will update the congestion window as per the slow-start
    70  // algorithm used by NewReno. If after adjusting the congestion window we cross
    71  // the ssThresh then it will return the number of packets that must be consumed
    72  // in congestion avoidance mode.
    73  func (c *cubicState) updateSlowStart(packetsAcked int) int {
    74  	// Don't let the congestion window cross into the congestion
    75  	// avoidance range.
    76  	newcwnd := c.s.SndCwnd + packetsAcked
    77  	enterCA := false
    78  	if newcwnd >= c.s.Ssthresh {
    79  		newcwnd = c.s.Ssthresh
    80  		c.s.SndCAAckCount = 0
    81  		enterCA = true
    82  	}
    83  
    84  	packetsAcked -= newcwnd - c.s.SndCwnd
    85  	c.s.SndCwnd = newcwnd
    86  	if enterCA {
    87  		c.enterCongestionAvoidance()
    88  	}
    89  	return packetsAcked
    90  }
    91  
    92  // Update updates cubic's internal state variables. It must be called on every
    93  // ACK received.
    94  // Refer: https://tools.ietf.org/html/rfc8312#section-4
    95  func (c *cubicState) Update(packetsAcked int) {
    96  	if c.s.SndCwnd < c.s.Ssthresh {
    97  		packetsAcked = c.updateSlowStart(packetsAcked)
    98  		if packetsAcked == 0 {
    99  			return
   100  		}
   101  	} else {
   102  		c.s.rtt.Lock()
   103  		srtt := c.s.rtt.TCPRTTState.SRTT
   104  		c.s.rtt.Unlock()
   105  		c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt)
   106  	}
   107  }
   108  
   109  // cubicCwnd computes the CUBIC congestion window after t seconds from last
   110  // congestion event.
   111  func (c *cubicState) cubicCwnd(t float64) float64 {
   112  	return c.C*math.Pow(t, 3.0) + c.WMax
   113  }
   114  
   115  // getCwnd returns the current congestion window as computed by CUBIC.
   116  // Refer: https://tools.ietf.org/html/rfc8312#section-4
   117  func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
   118  	elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T)
   119  	elapsedSeconds := elapsed.Seconds()
   120  
   121  	// Compute the window as per Cubic after 'elapsed' time
   122  	// since last congestion event.
   123  	c.WC = c.cubicCwnd(elapsedSeconds - c.K)
   124  
   125  	// Compute the TCP friendly estimate of the congestion window.
   126  	c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds())
   127  
   128  	// Make sure in the TCP friendly region CUBIC performs at least
   129  	// as well as Reno.
   130  	if c.WC < c.WEst && float64(sndCwnd) < c.WEst {
   131  		// TCP Friendly region of cubic.
   132  		return int(c.WEst)
   133  	}
   134  
   135  	// In Concave/Convex region of CUBIC, calculate what CUBIC window
   136  	// will be after 1 RTT and use that to grow congestion window
   137  	// for every ack.
   138  	tEst := (elapsed + srtt).Seconds()
   139  	wtRtt := c.cubicCwnd(tEst - c.K)
   140  	// As per 4.3 for each received ACK cwnd must be incremented
   141  	// by (w_cubic(t+RTT) - cwnd/cwnd.
   142  	cwnd := float64(sndCwnd)
   143  	for i := 0; i < packetsAcked; i++ {
   144  		// Concave/Convex regions of cubic have the same formulas.
   145  		// See: https://tools.ietf.org/html/rfc8312#section-4.3
   146  		cwnd += (wtRtt - cwnd) / cwnd
   147  	}
   148  	return int(cwnd)
   149  }
   150  
   151  // HandleLossDetected implements congestionControl.HandleLossDetected.
   152  func (c *cubicState) HandleLossDetected() {
   153  	// See: https://tools.ietf.org/html/rfc8312#section-4.5
   154  	c.numCongestionEvents++
   155  	c.T = c.s.ep.stack.Clock().NowMonotonic()
   156  	c.WLastMax = c.WMax
   157  	c.WMax = float64(c.s.SndCwnd)
   158  
   159  	c.fastConvergence()
   160  	c.reduceSlowStartThreshold()
   161  }
   162  
   163  // HandleRTOExpired implements congestionContrl.HandleRTOExpired.
   164  func (c *cubicState) HandleRTOExpired() {
   165  	// See: https://tools.ietf.org/html/rfc8312#section-4.6
   166  	c.T = c.s.ep.stack.Clock().NowMonotonic()
   167  	c.numCongestionEvents = 0
   168  	c.WLastMax = c.WMax
   169  	c.WMax = float64(c.s.SndCwnd)
   170  
   171  	c.fastConvergence()
   172  
   173  	// We lost a packet, so reduce ssthresh.
   174  	c.reduceSlowStartThreshold()
   175  
   176  	// Reduce the congestion window to 1, i.e., enter slow-start. Per
   177  	// RFC 5681, page 7, we must use 1 regardless of the value of the
   178  	// initial congestion window.
   179  	c.s.SndCwnd = 1
   180  }
   181  
   182  // fastConvergence implements the logic for Fast Convergence algorithm as
   183  // described in https://tools.ietf.org/html/rfc8312#section-4.6.
   184  func (c *cubicState) fastConvergence() {
   185  	if c.WMax < c.WLastMax {
   186  		c.WLastMax = c.WMax
   187  		c.WMax = c.WMax * (1.0 + c.Beta) / 2.0
   188  	} else {
   189  		c.WLastMax = c.WMax
   190  	}
   191  	// Recompute k as wMax may have changed.
   192  	c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C)
   193  }
   194  
   195  // PostRecovery implements congestionControl.PostRecovery.
   196  func (c *cubicState) PostRecovery() {
   197  	c.T = c.s.ep.stack.Clock().NowMonotonic()
   198  }
   199  
   200  // reduceSlowStartThreshold returns new SsThresh as described in
   201  // https://tools.ietf.org/html/rfc8312#section-4.7.
   202  func (c *cubicState) reduceSlowStartThreshold() {
   203  	c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0))
   204  }