github.com/kelleygo/clashcore@v1.0.2/transport/tuic/congestion/cubic_sender.go (about)

     1  package congestion
     2  
     3  import (
     4  	"fmt"
     5  	"time"
     6  
     7  	"github.com/metacubex/quic-go/congestion"
     8  )
     9  
    10  const (
    11  	maxBurstPackets            = 3
    12  	renoBeta                   = 0.7 // Reno backoff factor.
    13  	minCongestionWindowPackets = 2
    14  	initialCongestionWindow    = 32
    15  )
    16  
    17  const InvalidPacketNumber congestion.PacketNumber = -1
    18  const MaxCongestionWindowPackets = 20000
    19  const MaxByteCount = congestion.ByteCount(1<<62 - 1)
    20  
    21  type cubicSender struct {
    22  	hybridSlowStart HybridSlowStart
    23  	rttStats        congestion.RTTStatsProvider
    24  	cubic           *Cubic
    25  	pacer           *pacer
    26  	clock           Clock
    27  
    28  	reno bool
    29  
    30  	// Track the largest packet that has been sent.
    31  	largestSentPacketNumber congestion.PacketNumber
    32  
    33  	// Track the largest packet that has been acked.
    34  	largestAckedPacketNumber congestion.PacketNumber
    35  
    36  	// Track the largest packet number outstanding when a CWND cutback occurs.
    37  	largestSentAtLastCutback congestion.PacketNumber
    38  
    39  	// Whether the last loss event caused us to exit slowstart.
    40  	// Used for stats collection of slowstartPacketsLost
    41  	lastCutbackExitedSlowstart bool
    42  
    43  	// Congestion window in bytes.
    44  	congestionWindow congestion.ByteCount
    45  
    46  	// Slow start congestion window in bytes, aka ssthresh.
    47  	slowStartThreshold congestion.ByteCount
    48  
    49  	// ACK counter for the Reno implementation.
    50  	numAckedPackets uint64
    51  
    52  	initialCongestionWindow    congestion.ByteCount
    53  	initialMaxCongestionWindow congestion.ByteCount
    54  
    55  	maxDatagramSize congestion.ByteCount
    56  }
    57  
    58  var (
    59  	_ congestion.CongestionControl = &cubicSender{}
    60  )
    61  
    62  // NewCubicSender makes a new cubic sender
    63  func NewCubicSender(
    64  	clock Clock,
    65  	initialMaxDatagramSize congestion.ByteCount,
    66  	reno bool,
    67  ) *cubicSender {
    68  	return newCubicSender(
    69  		clock,
    70  		reno,
    71  		initialMaxDatagramSize,
    72  		initialCongestionWindow*initialMaxDatagramSize,
    73  		MaxCongestionWindowPackets*initialMaxDatagramSize,
    74  	)
    75  }
    76  
    77  func newCubicSender(
    78  	clock Clock,
    79  	reno bool,
    80  	initialMaxDatagramSize,
    81  	initialCongestionWindow,
    82  	initialMaxCongestionWindow congestion.ByteCount,
    83  ) *cubicSender {
    84  	c := &cubicSender{
    85  		largestSentPacketNumber:    InvalidPacketNumber,
    86  		largestAckedPacketNumber:   InvalidPacketNumber,
    87  		largestSentAtLastCutback:   InvalidPacketNumber,
    88  		initialCongestionWindow:    initialCongestionWindow,
    89  		initialMaxCongestionWindow: initialMaxCongestionWindow,
    90  		congestionWindow:           initialCongestionWindow,
    91  		slowStartThreshold:         MaxByteCount,
    92  		cubic:                      NewCubic(clock),
    93  		clock:                      clock,
    94  		reno:                       reno,
    95  		maxDatagramSize:            initialMaxDatagramSize,
    96  	}
    97  	c.pacer = newPacer(c.BandwidthEstimate)
    98  	return c
    99  }
   100  
   101  func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
   102  	c.rttStats = provider
   103  }
   104  
   105  // TimeUntilSend returns when the next packet should be sent.
   106  func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
   107  	return c.pacer.TimeUntilSend()
   108  }
   109  
   110  func (c *cubicSender) HasPacingBudget(now time.Time) bool {
   111  	return c.pacer.Budget(now) >= c.maxDatagramSize
   112  }
   113  
   114  func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
   115  	return c.maxDatagramSize * MaxCongestionWindowPackets
   116  }
   117  
   118  func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
   119  	return c.maxDatagramSize * minCongestionWindowPackets
   120  }
   121  
   122  func (c *cubicSender) OnPacketSent(
   123  	sentTime time.Time,
   124  	_ congestion.ByteCount,
   125  	packetNumber congestion.PacketNumber,
   126  	bytes congestion.ByteCount,
   127  	isRetransmittable bool,
   128  ) {
   129  	c.pacer.SentPacket(sentTime, bytes)
   130  	if !isRetransmittable {
   131  		return
   132  	}
   133  	c.largestSentPacketNumber = packetNumber
   134  	c.hybridSlowStart.OnPacketSent(packetNumber)
   135  }
   136  
   137  func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
   138  	return bytesInFlight < c.GetCongestionWindow()
   139  }
   140  
   141  func (c *cubicSender) InRecovery() bool {
   142  	return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
   143  }
   144  
   145  func (c *cubicSender) InSlowStart() bool {
   146  	return c.GetCongestionWindow() < c.slowStartThreshold
   147  }
   148  
   149  func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
   150  	return c.congestionWindow
   151  }
   152  
   153  func (c *cubicSender) MaybeExitSlowStart() {
   154  	if c.InSlowStart() &&
   155  		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
   156  		// exit slow start
   157  		c.slowStartThreshold = c.congestionWindow
   158  	}
   159  }
   160  
   161  func (c *cubicSender) OnPacketAcked(
   162  	ackedPacketNumber congestion.PacketNumber,
   163  	ackedBytes congestion.ByteCount,
   164  	priorInFlight congestion.ByteCount,
   165  	eventTime time.Time,
   166  ) {
   167  	c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
   168  	if c.InRecovery() {
   169  		return
   170  	}
   171  	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
   172  	if c.InSlowStart() {
   173  		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
   174  	}
   175  }
   176  
   177  func (c *cubicSender) OnCongestionEvent(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
   178  	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
   179  	// already sent should be treated as a single loss event, since it's expected.
   180  	if packetNumber <= c.largestSentAtLastCutback {
   181  		return
   182  	}
   183  	c.lastCutbackExitedSlowstart = c.InSlowStart()
   184  
   185  	if c.reno {
   186  		c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
   187  	} else {
   188  		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
   189  	}
   190  	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
   191  		c.congestionWindow = minCwnd
   192  	}
   193  	c.slowStartThreshold = c.congestionWindow
   194  	c.largestSentAtLastCutback = c.largestSentPacketNumber
   195  	// reset packet count from congestion avoidance mode. We start
   196  	// counting again when we're out of recovery.
   197  	c.numAckedPackets = 0
   198  }
   199  
   200  func (b *cubicSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
   201  	// Stub
   202  }
   203  
   204  // Called when we receive an ack. Normal TCP tracks how many packets one ack
   205  // represents, but quic has a separate ack for each packet.
   206  func (c *cubicSender) maybeIncreaseCwnd(
   207  	_ congestion.PacketNumber,
   208  	ackedBytes congestion.ByteCount,
   209  	priorInFlight congestion.ByteCount,
   210  	eventTime time.Time,
   211  ) {
   212  	// Do not increase the congestion window unless the sender is close to using
   213  	// the current window.
   214  	if !c.isCwndLimited(priorInFlight) {
   215  		c.cubic.OnApplicationLimited()
   216  		return
   217  	}
   218  	if c.congestionWindow >= c.maxCongestionWindow() {
   219  		return
   220  	}
   221  	if c.InSlowStart() {
   222  		// TCP slow start, exponential growth, increase by one for each ACK.
   223  		c.congestionWindow += c.maxDatagramSize
   224  		return
   225  	}
   226  	// Congestion avoidance
   227  	if c.reno {
   228  		// Classic Reno congestion avoidance.
   229  		c.numAckedPackets++
   230  		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
   231  			c.congestionWindow += c.maxDatagramSize
   232  			c.numAckedPackets = 0
   233  		}
   234  	} else {
   235  		c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
   236  	}
   237  }
   238  
   239  func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
   240  	congestionWindow := c.GetCongestionWindow()
   241  	if bytesInFlight >= congestionWindow {
   242  		return true
   243  	}
   244  	availableBytes := congestionWindow - bytesInFlight
   245  	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
   246  	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
   247  }
   248  
   249  // BandwidthEstimate returns the current bandwidth estimate
   250  func (c *cubicSender) BandwidthEstimate() Bandwidth {
   251  	if c.rttStats == nil {
   252  		return infBandwidth
   253  	}
   254  	srtt := c.rttStats.SmoothedRTT()
   255  	if srtt == 0 {
   256  		// If we haven't measured an rtt, the bandwidth estimate is unknown.
   257  		return infBandwidth
   258  	}
   259  	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
   260  }
   261  
   262  // OnRetransmissionTimeout is called on an retransmission timeout
   263  func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
   264  	c.largestSentAtLastCutback = InvalidPacketNumber
   265  	if !packetsRetransmitted {
   266  		return
   267  	}
   268  	c.hybridSlowStart.Restart()
   269  	c.cubic.Reset()
   270  	c.slowStartThreshold = c.congestionWindow / 2
   271  	c.congestionWindow = c.minCongestionWindow()
   272  }
   273  
   274  // OnConnectionMigration is called when the connection is migrated (?)
   275  func (c *cubicSender) OnConnectionMigration() {
   276  	c.hybridSlowStart.Restart()
   277  	c.largestSentPacketNumber = InvalidPacketNumber
   278  	c.largestAckedPacketNumber = InvalidPacketNumber
   279  	c.largestSentAtLastCutback = InvalidPacketNumber
   280  	c.lastCutbackExitedSlowstart = false
   281  	c.cubic.Reset()
   282  	c.numAckedPackets = 0
   283  	c.congestionWindow = c.initialCongestionWindow
   284  	c.slowStartThreshold = c.initialMaxCongestionWindow
   285  }
   286  
   287  func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
   288  	if s < c.maxDatagramSize {
   289  		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
   290  	}
   291  	cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
   292  	c.maxDatagramSize = s
   293  	if cwndIsMinCwnd {
   294  		c.congestionWindow = c.minCongestionWindow()
   295  	}
   296  	c.pacer.SetMaxDatagramSize(s)
   297  }