github.com/quic-go/quic-go@v0.44.0/internal/protocol/packet_number_test.go (about) 1 package protocol 2 3 import ( 4 "fmt" 5 6 . "github.com/onsi/ginkgo/v2" 7 . "github.com/onsi/gomega" 8 ) 9 10 // Tests taken and extended from chrome 11 var _ = Describe("packet number calculation", func() { 12 It("InvalidPacketNumber is smaller than all valid packet numbers", func() { 13 Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) 14 }) 15 16 It("works with the example from the draft", func() { 17 Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) 18 }) 19 20 It("works with the examples from the draft", func() { 21 Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) 22 Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) 23 }) 24 25 getEpoch := func(len PacketNumberLen) uint64 { 26 if len > 4 { 27 Fail("invalid packet number len") 28 } 29 return uint64(1) << (len * 8) 30 } 31 32 check := func(length PacketNumberLen, expected, last uint64) { 33 epoch := getEpoch(length) 34 epochMask := epoch - 1 35 wirePacketNumber := expected & epochMask 36 ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) 37 } 38 39 for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { 40 length := l 41 42 Context(fmt.Sprintf("with %d bytes", length), func() { 43 epoch := getEpoch(length) 44 epochMask := epoch - 1 45 46 It("works near epoch start", func() { 47 // A few quick manual sanity check 48 check(length, 1, 0) 49 check(length, epoch+1, epochMask) 50 check(length, epoch, epochMask) 51 52 // Cases where the last number was close to the start of the range. 53 for last := uint64(0); last < 10; last++ { 54 // Small numbers should not wrap (even if they're out of order). 55 for j := uint64(0); j < 10; j++ { 56 check(length, j, last) 57 } 58 59 // Large numbers should not wrap either (because we're near 0 already). 60 for j := uint64(0); j < 10; j++ { 61 check(length, epoch-1-j, last) 62 } 63 } 64 }) 65 66 It("works near epoch end", func() { 67 // Cases where the last number was close to the end of the range 68 for i := uint64(0); i < 10; i++ { 69 last := epoch - i 70 71 // Small numbers should wrap. 72 for j := uint64(0); j < 10; j++ { 73 check(length, epoch+j, last) 74 } 75 76 // Large numbers should not (even if they're out of order). 77 for j := uint64(0); j < 10; j++ { 78 check(length, epoch-1-j, last) 79 } 80 } 81 }) 82 83 // Next check where we're in a non-zero epoch to verify we handle 84 // reverse wrapping, too. 85 It("works near previous epoch", func() { 86 prevEpoch := 1 * epoch 87 curEpoch := 2 * epoch 88 // Cases where the last number was close to the start of the range 89 for i := uint64(0); i < 10; i++ { 90 last := curEpoch + i 91 // Small number should not wrap (even if they're out of order). 92 for j := uint64(0); j < 10; j++ { 93 check(length, curEpoch+j, last) 94 } 95 96 // But large numbers should reverse wrap. 97 for j := uint64(0); j < 10; j++ { 98 num := epoch - 1 - j 99 check(length, prevEpoch+num, last) 100 } 101 } 102 }) 103 104 It("works near next epoch", func() { 105 curEpoch := 2 * epoch 106 nextEpoch := 3 * epoch 107 // Cases where the last number was close to the end of the range 108 for i := uint64(0); i < 10; i++ { 109 last := nextEpoch - 1 - i 110 111 // Small numbers should wrap. 112 for j := uint64(0); j < 10; j++ { 113 check(length, nextEpoch+j, last) 114 } 115 116 // but large numbers should not (even if they're out of order). 117 for j := uint64(0); j < 10; j++ { 118 num := epoch - 1 - j 119 check(length, curEpoch+num, last) 120 } 121 } 122 }) 123 124 Context("shortening a packet number for the header", func() { 125 Context("shortening", func() { 126 It("sends out low packet numbers as 2 byte", func() { 127 length := GetPacketNumberLengthForHeader(4, 2) 128 Expect(length).To(Equal(PacketNumberLen2)) 129 }) 130 131 It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { 132 length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) 133 Expect(length).To(Equal(PacketNumberLen2)) 134 }) 135 136 It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() { 137 length := GetPacketNumberLengthForHeader(40000, 2) 138 Expect(length).To(Equal(PacketNumberLen3)) 139 }) 140 141 It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { 142 length := GetPacketNumberLengthForHeader(40000000, 2) 143 Expect(length).To(Equal(PacketNumberLen4)) 144 }) 145 }) 146 147 Context("self-consistency", func() { 148 It("works for small packet numbers", func() { 149 for i := uint64(1); i < 10000; i++ { 150 packetNumber := PacketNumber(i) 151 leastUnacked := PacketNumber(1) 152 length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) 153 wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) 154 155 decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) 156 Expect(decodedPacketNumber).To(Equal(packetNumber)) 157 } 158 }) 159 160 It("works for small packet numbers and increasing ACKed packets", func() { 161 for i := uint64(1); i < 10000; i++ { 162 packetNumber := PacketNumber(i) 163 leastUnacked := PacketNumber(i / 2) 164 length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) 165 epochMask := getEpoch(length) - 1 166 wirePacketNumber := uint64(packetNumber) & epochMask 167 168 decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) 169 Expect(decodedPacketNumber).To(Equal(packetNumber)) 170 } 171 }) 172 173 It("also works for larger packet numbers", func() { 174 var increment uint64 175 for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { 176 packetNumber := PacketNumber(i) 177 leastUnacked := PacketNumber(1) 178 length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) 179 epochMask := getEpoch(length) - 1 180 wirePacketNumber := uint64(packetNumber) & epochMask 181 182 decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) 183 Expect(decodedPacketNumber).To(Equal(packetNumber)) 184 185 increment = getEpoch(length) / 8 186 } 187 }) 188 189 It("works for packet numbers larger than 2^48", func() { 190 for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { 191 packetNumber := PacketNumber(i) 192 leastUnacked := PacketNumber(i - 1000) 193 length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) 194 wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) 195 196 decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) 197 Expect(decodedPacketNumber).To(Equal(packetNumber)) 198 } 199 }) 200 }) 201 }) 202 }) 203 } 204 })