github.com/quic-go/quic-go@v0.44.0/internal/protocol/packet_number_test.go (about)

     1  package protocol
     2  
     3  import (
     4  	"fmt"
     5  
     6  	. "github.com/onsi/ginkgo/v2"
     7  	. "github.com/onsi/gomega"
     8  )
     9  
    10  // Tests taken and extended from chrome
    11  var _ = Describe("packet number calculation", func() {
    12  	It("InvalidPacketNumber is smaller than all valid packet numbers", func() {
    13  		Expect(InvalidPacketNumber).To(BeNumerically("<", 0))
    14  	})
    15  
    16  	It("works with the example from the draft", func() {
    17  		Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32)))
    18  	})
    19  
    20  	It("works with the examples from the draft", func() {
    21  		Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2))
    22  		Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3))
    23  	})
    24  
    25  	getEpoch := func(len PacketNumberLen) uint64 {
    26  		if len > 4 {
    27  			Fail("invalid packet number len")
    28  		}
    29  		return uint64(1) << (len * 8)
    30  	}
    31  
    32  	check := func(length PacketNumberLen, expected, last uint64) {
    33  		epoch := getEpoch(length)
    34  		epochMask := epoch - 1
    35  		wirePacketNumber := expected & epochMask
    36  		ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
    37  	}
    38  
    39  	for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} {
    40  		length := l
    41  
    42  		Context(fmt.Sprintf("with %d bytes", length), func() {
    43  			epoch := getEpoch(length)
    44  			epochMask := epoch - 1
    45  
    46  			It("works near epoch start", func() {
    47  				// A few quick manual sanity check
    48  				check(length, 1, 0)
    49  				check(length, epoch+1, epochMask)
    50  				check(length, epoch, epochMask)
    51  
    52  				// Cases where the last number was close to the start of the range.
    53  				for last := uint64(0); last < 10; last++ {
    54  					// Small numbers should not wrap (even if they're out of order).
    55  					for j := uint64(0); j < 10; j++ {
    56  						check(length, j, last)
    57  					}
    58  
    59  					// Large numbers should not wrap either (because we're near 0 already).
    60  					for j := uint64(0); j < 10; j++ {
    61  						check(length, epoch-1-j, last)
    62  					}
    63  				}
    64  			})
    65  
    66  			It("works near epoch end", func() {
    67  				// Cases where the last number was close to the end of the range
    68  				for i := uint64(0); i < 10; i++ {
    69  					last := epoch - i
    70  
    71  					// Small numbers should wrap.
    72  					for j := uint64(0); j < 10; j++ {
    73  						check(length, epoch+j, last)
    74  					}
    75  
    76  					// Large numbers should not (even if they're out of order).
    77  					for j := uint64(0); j < 10; j++ {
    78  						check(length, epoch-1-j, last)
    79  					}
    80  				}
    81  			})
    82  
    83  			// Next check where we're in a non-zero epoch to verify we handle
    84  			// reverse wrapping, too.
    85  			It("works near previous epoch", func() {
    86  				prevEpoch := 1 * epoch
    87  				curEpoch := 2 * epoch
    88  				// Cases where the last number was close to the start of the range
    89  				for i := uint64(0); i < 10; i++ {
    90  					last := curEpoch + i
    91  					// Small number should not wrap (even if they're out of order).
    92  					for j := uint64(0); j < 10; j++ {
    93  						check(length, curEpoch+j, last)
    94  					}
    95  
    96  					// But large numbers should reverse wrap.
    97  					for j := uint64(0); j < 10; j++ {
    98  						num := epoch - 1 - j
    99  						check(length, prevEpoch+num, last)
   100  					}
   101  				}
   102  			})
   103  
   104  			It("works near next epoch", func() {
   105  				curEpoch := 2 * epoch
   106  				nextEpoch := 3 * epoch
   107  				// Cases where the last number was close to the end of the range
   108  				for i := uint64(0); i < 10; i++ {
   109  					last := nextEpoch - 1 - i
   110  
   111  					// Small numbers should wrap.
   112  					for j := uint64(0); j < 10; j++ {
   113  						check(length, nextEpoch+j, last)
   114  					}
   115  
   116  					// but large numbers should not (even if they're out of order).
   117  					for j := uint64(0); j < 10; j++ {
   118  						num := epoch - 1 - j
   119  						check(length, curEpoch+num, last)
   120  					}
   121  				}
   122  			})
   123  
   124  			Context("shortening a packet number for the header", func() {
   125  				Context("shortening", func() {
   126  					It("sends out low packet numbers as 2 byte", func() {
   127  						length := GetPacketNumberLengthForHeader(4, 2)
   128  						Expect(length).To(Equal(PacketNumberLen2))
   129  					})
   130  
   131  					It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
   132  						length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1)
   133  						Expect(length).To(Equal(PacketNumberLen2))
   134  					})
   135  
   136  					It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() {
   137  						length := GetPacketNumberLengthForHeader(40000, 2)
   138  						Expect(length).To(Equal(PacketNumberLen3))
   139  					})
   140  
   141  					It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
   142  						length := GetPacketNumberLengthForHeader(40000000, 2)
   143  						Expect(length).To(Equal(PacketNumberLen4))
   144  					})
   145  				})
   146  
   147  				Context("self-consistency", func() {
   148  					It("works for small packet numbers", func() {
   149  						for i := uint64(1); i < 10000; i++ {
   150  							packetNumber := PacketNumber(i)
   151  							leastUnacked := PacketNumber(1)
   152  							length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
   153  							wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
   154  
   155  							decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
   156  							Expect(decodedPacketNumber).To(Equal(packetNumber))
   157  						}
   158  					})
   159  
   160  					It("works for small packet numbers and increasing ACKed packets", func() {
   161  						for i := uint64(1); i < 10000; i++ {
   162  							packetNumber := PacketNumber(i)
   163  							leastUnacked := PacketNumber(i / 2)
   164  							length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
   165  							epochMask := getEpoch(length) - 1
   166  							wirePacketNumber := uint64(packetNumber) & epochMask
   167  
   168  							decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
   169  							Expect(decodedPacketNumber).To(Equal(packetNumber))
   170  						}
   171  					})
   172  
   173  					It("also works for larger packet numbers", func() {
   174  						var increment uint64
   175  						for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment {
   176  							packetNumber := PacketNumber(i)
   177  							leastUnacked := PacketNumber(1)
   178  							length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
   179  							epochMask := getEpoch(length) - 1
   180  							wirePacketNumber := uint64(packetNumber) & epochMask
   181  
   182  							decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
   183  							Expect(decodedPacketNumber).To(Equal(packetNumber))
   184  
   185  							increment = getEpoch(length) / 8
   186  						}
   187  					})
   188  
   189  					It("works for packet numbers larger than 2^48", func() {
   190  						for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
   191  							packetNumber := PacketNumber(i)
   192  							leastUnacked := PacketNumber(i - 1000)
   193  							length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
   194  							wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
   195  
   196  							decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
   197  							Expect(decodedPacketNumber).To(Equal(packetNumber))
   198  						}
   199  					})
   200  				})
   201  			})
   202  		})
   203  	}
   204  })