github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/tuic/congestion/cubic_sender.go (about)

     1  package congestion
     2  
     3  import (
     4  	"fmt"
     5  	"time"
     6  
     7  	"github.com/sagernet/quic-go/congestion"
     8  	"github.com/sagernet/quic-go/logging"
     9  )
    10  
    11  const (
    12  	maxBurstPackets            = 3
    13  	renoBeta                   = 0.7 // Reno backoff factor.
    14  	minCongestionWindowPackets = 2
    15  	initialCongestionWindow    = 32
    16  )
    17  
    18  const (
    19  	InvalidPacketNumber        congestion.PacketNumber = -1
    20  	MaxCongestionWindowPackets                         = 20000
    21  	MaxByteCount                                       = congestion.ByteCount(1<<62 - 1)
    22  )
    23  
    24  type cubicSender struct {
    25  	hybridSlowStart HybridSlowStart
    26  	rttStats        congestion.RTTStatsProvider
    27  	cubic           *Cubic
    28  	pacer           *pacer
    29  	clock           Clock
    30  
    31  	reno bool
    32  
    33  	// Track the largest packet that has been sent.
    34  	largestSentPacketNumber congestion.PacketNumber
    35  
    36  	// Track the largest packet that has been acked.
    37  	largestAckedPacketNumber congestion.PacketNumber
    38  
    39  	// Track the largest packet number outstanding when a CWND cutback occurs.
    40  	largestSentAtLastCutback congestion.PacketNumber
    41  
    42  	// Whether the last loss event caused us to exit slowstart.
    43  	// Used for stats collection of slowstartPacketsLost
    44  	lastCutbackExitedSlowstart bool
    45  
    46  	// Congestion window in bytes.
    47  	congestionWindow congestion.ByteCount
    48  
    49  	// Slow start congestion window in bytes, aka ssthresh.
    50  	slowStartThreshold congestion.ByteCount
    51  
    52  	// ACK counter for the Reno implementation.
    53  	numAckedPackets uint64
    54  
    55  	initialCongestionWindow    congestion.ByteCount
    56  	initialMaxCongestionWindow congestion.ByteCount
    57  
    58  	maxDatagramSize congestion.ByteCount
    59  
    60  	lastState logging.CongestionState
    61  	tracer    logging.ConnectionTracer
    62  }
    63  
    64  var _ congestion.CongestionControl = &cubicSender{}
    65  
    66  // NewCubicSender makes a new cubic sender
    67  func NewCubicSender(
    68  	clock Clock,
    69  	initialMaxDatagramSize congestion.ByteCount,
    70  	reno bool,
    71  	tracer logging.ConnectionTracer,
    72  ) *cubicSender {
    73  	return newCubicSender(
    74  		clock,
    75  		reno,
    76  		initialMaxDatagramSize,
    77  		initialCongestionWindow*initialMaxDatagramSize,
    78  		MaxCongestionWindowPackets*initialMaxDatagramSize,
    79  		tracer,
    80  	)
    81  }
    82  
    83  func newCubicSender(
    84  	clock Clock,
    85  	reno bool,
    86  	initialMaxDatagramSize,
    87  	initialCongestionWindow,
    88  	initialMaxCongestionWindow congestion.ByteCount,
    89  	tracer logging.ConnectionTracer,
    90  ) *cubicSender {
    91  	c := &cubicSender{
    92  		largestSentPacketNumber:    InvalidPacketNumber,
    93  		largestAckedPacketNumber:   InvalidPacketNumber,
    94  		largestSentAtLastCutback:   InvalidPacketNumber,
    95  		initialCongestionWindow:    initialCongestionWindow,
    96  		initialMaxCongestionWindow: initialMaxCongestionWindow,
    97  		congestionWindow:           initialCongestionWindow,
    98  		slowStartThreshold:         MaxByteCount,
    99  		cubic:                      NewCubic(clock),
   100  		clock:                      clock,
   101  		reno:                       reno,
   102  		tracer:                     tracer,
   103  		maxDatagramSize:            initialMaxDatagramSize,
   104  	}
   105  	c.pacer = newPacer(c.BandwidthEstimate)
   106  	if c.tracer != nil {
   107  		c.lastState = logging.CongestionStateSlowStart
   108  		c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
   109  	}
   110  	return c
   111  }
   112  
   113  func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
   114  	c.rttStats = provider
   115  }
   116  
   117  // TimeUntilSend returns when the next packet should be sent.
   118  func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
   119  	return c.pacer.TimeUntilSend()
   120  }
   121  
   122  func (c *cubicSender) HasPacingBudget(now time.Time) bool {
   123  	return c.pacer.Budget(now) >= c.maxDatagramSize
   124  }
   125  
   126  func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
   127  	return c.maxDatagramSize * MaxCongestionWindowPackets
   128  }
   129  
   130  func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
   131  	return c.maxDatagramSize * minCongestionWindowPackets
   132  }
   133  
   134  func (c *cubicSender) OnPacketSent(
   135  	sentTime time.Time,
   136  	_ congestion.ByteCount,
   137  	packetNumber congestion.PacketNumber,
   138  	bytes congestion.ByteCount,
   139  	isRetransmittable bool,
   140  ) {
   141  	c.pacer.SentPacket(sentTime, bytes)
   142  	if !isRetransmittable {
   143  		return
   144  	}
   145  	c.largestSentPacketNumber = packetNumber
   146  	c.hybridSlowStart.OnPacketSent(packetNumber)
   147  }
   148  
   149  func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
   150  	return bytesInFlight < c.GetCongestionWindow()
   151  }
   152  
   153  func (c *cubicSender) InRecovery() bool {
   154  	return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
   155  }
   156  
   157  func (c *cubicSender) InSlowStart() bool {
   158  	return c.GetCongestionWindow() < c.slowStartThreshold
   159  }
   160  
   161  func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
   162  	return c.congestionWindow
   163  }
   164  
   165  func (c *cubicSender) MaybeExitSlowStart() {
   166  	if c.InSlowStart() &&
   167  		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
   168  		// exit slow start
   169  		c.slowStartThreshold = c.congestionWindow
   170  		c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
   171  	}
   172  }
   173  
   174  func (c *cubicSender) OnPacketAcked(
   175  	ackedPacketNumber congestion.PacketNumber,
   176  	ackedBytes congestion.ByteCount,
   177  	priorInFlight congestion.ByteCount,
   178  	eventTime time.Time,
   179  ) {
   180  	c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
   181  	if c.InRecovery() {
   182  		return
   183  	}
   184  	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
   185  	if c.InSlowStart() {
   186  		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
   187  	}
   188  }
   189  
   190  func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
   191  	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
   192  	// already sent should be treated as a single loss event, since it's expected.
   193  	if packetNumber <= c.largestSentAtLastCutback {
   194  		return
   195  	}
   196  	c.lastCutbackExitedSlowstart = c.InSlowStart()
   197  	c.maybeTraceStateChange(logging.CongestionStateRecovery)
   198  
   199  	if c.reno {
   200  		c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
   201  	} else {
   202  		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
   203  	}
   204  	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
   205  		c.congestionWindow = minCwnd
   206  	}
   207  	c.slowStartThreshold = c.congestionWindow
   208  	c.largestSentAtLastCutback = c.largestSentPacketNumber
   209  	// reset packet count from congestion avoidance mode. We start
   210  	// counting again when we're out of recovery.
   211  	c.numAckedPackets = 0
   212  }
   213  
   214  // Called when we receive an ack. Normal TCP tracks how many packets one ack
   215  // represents, but quic has a separate ack for each packet.
   216  func (c *cubicSender) maybeIncreaseCwnd(
   217  	_ congestion.PacketNumber,
   218  	ackedBytes congestion.ByteCount,
   219  	priorInFlight congestion.ByteCount,
   220  	eventTime time.Time,
   221  ) {
   222  	// Do not increase the congestion window unless the sender is close to using
   223  	// the current window.
   224  	if !c.isCwndLimited(priorInFlight) {
   225  		c.cubic.OnApplicationLimited()
   226  		c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
   227  		return
   228  	}
   229  	if c.congestionWindow >= c.maxCongestionWindow() {
   230  		return
   231  	}
   232  	if c.InSlowStart() {
   233  		// TCP slow start, exponential growth, increase by one for each ACK.
   234  		c.congestionWindow += c.maxDatagramSize
   235  		c.maybeTraceStateChange(logging.CongestionStateSlowStart)
   236  		return
   237  	}
   238  	// Congestion avoidance
   239  	c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
   240  	if c.reno {
   241  		// Classic Reno congestion avoidance.
   242  		c.numAckedPackets++
   243  		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
   244  			c.congestionWindow += c.maxDatagramSize
   245  			c.numAckedPackets = 0
   246  		}
   247  	} else {
   248  		c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
   249  	}
   250  }
   251  
   252  func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
   253  	congestionWindow := c.GetCongestionWindow()
   254  	if bytesInFlight >= congestionWindow {
   255  		return true
   256  	}
   257  	availableBytes := congestionWindow - bytesInFlight
   258  	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
   259  	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
   260  }
   261  
   262  // BandwidthEstimate returns the current bandwidth estimate
   263  func (c *cubicSender) BandwidthEstimate() Bandwidth {
   264  	if c.rttStats == nil {
   265  		return infBandwidth
   266  	}
   267  	srtt := c.rttStats.SmoothedRTT()
   268  	if srtt == 0 {
   269  		// If we haven't measured an rtt, the bandwidth estimate is unknown.
   270  		return infBandwidth
   271  	}
   272  	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
   273  }
   274  
   275  // OnRetransmissionTimeout is called on an retransmission timeout
   276  func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
   277  	c.largestSentAtLastCutback = InvalidPacketNumber
   278  	if !packetsRetransmitted {
   279  		return
   280  	}
   281  	c.hybridSlowStart.Restart()
   282  	c.cubic.Reset()
   283  	c.slowStartThreshold = c.congestionWindow / 2
   284  	c.congestionWindow = c.minCongestionWindow()
   285  }
   286  
   287  // OnConnectionMigration is called when the connection is migrated (?)
   288  func (c *cubicSender) OnConnectionMigration() {
   289  	c.hybridSlowStart.Restart()
   290  	c.largestSentPacketNumber = InvalidPacketNumber
   291  	c.largestAckedPacketNumber = InvalidPacketNumber
   292  	c.largestSentAtLastCutback = InvalidPacketNumber
   293  	c.lastCutbackExitedSlowstart = false
   294  	c.cubic.Reset()
   295  	c.numAckedPackets = 0
   296  	c.congestionWindow = c.initialCongestionWindow
   297  	c.slowStartThreshold = c.initialMaxCongestionWindow
   298  }
   299  
   300  func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
   301  	if c.tracer == nil || new == c.lastState {
   302  		return
   303  	}
   304  	c.tracer.UpdatedCongestionState(new)
   305  	c.lastState = new
   306  }
   307  
   308  func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
   309  	if s < c.maxDatagramSize {
   310  		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
   311  	}
   312  	cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
   313  	c.maxDatagramSize = s
   314  	if cwndIsMinCwnd {
   315  		c.congestionWindow = c.minCongestionWindow()
   316  	}
   317  	c.pacer.SetMaxDatagramSize(s)
   318  }