github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/self/handshake_drop_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 mrand "math/rand" 9 "net" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/apernet/quic-go" 15 quicproxy "github.com/apernet/quic-go/integrationtests/tools/proxy" 16 "github.com/apernet/quic-go/internal/wire" 17 "github.com/apernet/quic-go/quicvarint" 18 19 . "github.com/onsi/ginkgo/v2" 20 . "github.com/onsi/gomega" 21 "github.com/onsi/gomega/gbytes" 22 ) 23 24 var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} 25 26 type applicationProtocol struct { 27 name string 28 run func(ln *quic.Listener, port int) 29 } 30 31 var _ = Describe("Handshake drop tests", func() { 32 data := GeneratePRData(5000) 33 const timeout = 2 * time.Minute 34 35 startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) { 36 conf := getQuicConfig(&quic.Config{ 37 MaxIdleTimeout: timeout, 38 HandshakeIdleTimeout: timeout, 39 }) 40 var tlsConf *tls.Config 41 if longCertChain { 42 tlsConf = getTLSConfigWithLongCertChain() 43 } else { 44 tlsConf = getTLSConfig() 45 } 46 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 47 Expect(err).ToNot(HaveOccurred()) 48 conn, err := net.ListenUDP("udp", laddr) 49 Expect(err).ToNot(HaveOccurred()) 50 tr := &quic.Transport{Conn: conn} 51 if doRetry { 52 tr.VerifySourceAddress = func(net.Addr) bool { return true } 53 } 54 ln, err = tr.Listen(tlsConf, conf) 55 Expect(err).ToNot(HaveOccurred()) 56 serverPort := ln.Addr().(*net.UDPAddr).Port 57 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 58 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 59 DropPacket: dropCallback, 60 DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { 61 return 10 * time.Millisecond 62 }, 63 }) 64 Expect(err).ToNot(HaveOccurred()) 65 66 return ln, proxy.LocalPort(), func() { 67 ln.Close() 68 tr.Close() 69 conn.Close() 70 proxy.Close() 71 } 72 } 73 74 clientSpeaksFirst := &applicationProtocol{ 75 name: "client speaks first", 76 run: func(ln *quic.Listener, port int) { 77 serverConnChan := make(chan quic.Connection) 78 go func() { 79 defer GinkgoRecover() 80 conn, err := ln.Accept(context.Background()) 81 Expect(err).ToNot(HaveOccurred()) 82 defer conn.CloseWithError(0, "") 83 str, err := conn.AcceptStream(context.Background()) 84 Expect(err).ToNot(HaveOccurred()) 85 b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) 86 Expect(err).ToNot(HaveOccurred()) 87 Expect(b).To(Equal(data)) 88 serverConnChan <- conn 89 }() 90 conn, err := quic.DialAddr( 91 context.Background(), 92 fmt.Sprintf("localhost:%d", port), 93 getTLSClientConfig(), 94 getQuicConfig(&quic.Config{ 95 MaxIdleTimeout: timeout, 96 HandshakeIdleTimeout: timeout, 97 }), 98 ) 99 Expect(err).ToNot(HaveOccurred()) 100 str, err := conn.OpenStream() 101 Expect(err).ToNot(HaveOccurred()) 102 _, err = str.Write(data) 103 Expect(err).ToNot(HaveOccurred()) 104 Expect(str.Close()).To(Succeed()) 105 106 var serverConn quic.Connection 107 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 108 conn.CloseWithError(0, "") 109 serverConn.CloseWithError(0, "") 110 }, 111 } 112 113 serverSpeaksFirst := &applicationProtocol{ 114 name: "server speaks first", 115 run: func(ln *quic.Listener, port int) { 116 serverConnChan := make(chan quic.Connection) 117 go func() { 118 defer GinkgoRecover() 119 conn, err := ln.Accept(context.Background()) 120 Expect(err).ToNot(HaveOccurred()) 121 str, err := conn.OpenStream() 122 Expect(err).ToNot(HaveOccurred()) 123 _, err = str.Write(data) 124 Expect(err).ToNot(HaveOccurred()) 125 Expect(str.Close()).To(Succeed()) 126 serverConnChan <- conn 127 }() 128 conn, err := quic.DialAddr( 129 context.Background(), 130 fmt.Sprintf("localhost:%d", port), 131 getTLSClientConfig(), 132 getQuicConfig(&quic.Config{ 133 MaxIdleTimeout: timeout, 134 HandshakeIdleTimeout: timeout, 135 }), 136 ) 137 Expect(err).ToNot(HaveOccurred()) 138 str, err := conn.AcceptStream(context.Background()) 139 Expect(err).ToNot(HaveOccurred()) 140 b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) 141 Expect(err).ToNot(HaveOccurred()) 142 Expect(b).To(Equal(data)) 143 144 var serverConn quic.Connection 145 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 146 conn.CloseWithError(0, "") 147 serverConn.CloseWithError(0, "") 148 }, 149 } 150 151 nobodySpeaks := &applicationProtocol{ 152 name: "nobody speaks", 153 run: func(ln *quic.Listener, port int) { 154 serverConnChan := make(chan quic.Connection) 155 go func() { 156 defer GinkgoRecover() 157 conn, err := ln.Accept(context.Background()) 158 Expect(err).ToNot(HaveOccurred()) 159 serverConnChan <- conn 160 }() 161 conn, err := quic.DialAddr( 162 context.Background(), 163 fmt.Sprintf("localhost:%d", port), 164 getTLSClientConfig(), 165 getQuicConfig(&quic.Config{ 166 MaxIdleTimeout: timeout, 167 HandshakeIdleTimeout: timeout, 168 }), 169 ) 170 Expect(err).ToNot(HaveOccurred()) 171 var serverConn quic.Connection 172 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 173 // both server and client accepted a connection. Close now. 174 conn.CloseWithError(0, "") 175 serverConn.CloseWithError(0, "") 176 }, 177 } 178 179 for _, d := range directions { 180 direction := d 181 182 for _, dr := range []bool{true, false} { 183 doRetry := dr 184 desc := "when using Retry" 185 if !dr { 186 desc = "when not using Retry" 187 } 188 189 Context(desc, func() { 190 for _, lcc := range []bool{false, true} { 191 longCertChain := lcc 192 193 Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { 194 for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { 195 app := a 196 197 Context(app.name, func() { 198 It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { 199 var incoming, outgoing atomic.Int32 200 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 201 var p int32 202 //nolint:exhaustive 203 switch d { 204 case quicproxy.DirectionIncoming: 205 p = incoming.Add(1) 206 case quicproxy.DirectionOutgoing: 207 p = outgoing.Add(1) 208 } 209 return p == 1 && d.Is(direction) 210 }, doRetry, longCertChain) 211 defer closeFn() 212 app.run(ln, proxyPort) 213 }) 214 215 It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { 216 var incoming, outgoing atomic.Int32 217 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 218 var p int32 219 //nolint:exhaustive 220 switch d { 221 case quicproxy.DirectionIncoming: 222 p = incoming.Add(1) 223 case quicproxy.DirectionOutgoing: 224 p = outgoing.Add(1) 225 } 226 return p == 2 && d.Is(direction) 227 }, doRetry, longCertChain) 228 defer closeFn() 229 app.run(ln, proxyPort) 230 }) 231 232 It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { 233 const maxSequentiallyDropped = 10 234 var mx sync.Mutex 235 var incoming, outgoing int 236 237 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 238 drop := mrand.Int63n(int64(3)) == 0 239 240 mx.Lock() 241 defer mx.Unlock() 242 // never drop more than 10 consecutive packets 243 if d.Is(quicproxy.DirectionIncoming) { 244 if drop { 245 incoming++ 246 if incoming > maxSequentiallyDropped { 247 drop = false 248 } 249 } 250 if !drop { 251 incoming = 0 252 } 253 } 254 if d.Is(quicproxy.DirectionOutgoing) { 255 if drop { 256 outgoing++ 257 if outgoing > maxSequentiallyDropped { 258 drop = false 259 } 260 } 261 if !drop { 262 outgoing = 0 263 } 264 } 265 return drop 266 }, doRetry, longCertChain) 267 defer closeFn() 268 app.run(ln, proxyPort) 269 }) 270 }) 271 } 272 }) 273 } 274 }) 275 } 276 277 It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() { 278 origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient 279 defer func() { 280 wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient 281 }() 282 b := make([]byte, 2500) // the ClientHello will now span across 3 packets 283 mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b) 284 wire.AdditionalTransportParametersClient = map[uint64][]byte{ 285 // Avoid random collisions with the greased transport parameters. 286 uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b, 287 } 288 289 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 290 if d == quicproxy.DirectionOutgoing { 291 return false 292 } 293 return mrand.Intn(3) == 0 294 }, false, false) 295 defer closeFn() 296 clientSpeaksFirst.run(ln, proxyPort) 297 }) 298 } 299 })