github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/integrationtests/self/handshake_drop_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 mrand "math/rand" 9 "net" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/metacubex/quic-go" 15 quicproxy "github.com/metacubex/quic-go/integrationtests/tools/proxy" 16 "github.com/metacubex/quic-go/internal/wire" 17 "github.com/metacubex/quic-go/quicvarint" 18 19 . "github.com/onsi/ginkgo/v2" 20 . "github.com/onsi/gomega" 21 "github.com/onsi/gomega/gbytes" 22 ) 23 24 var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} 25 26 type applicationProtocol struct { 27 name string 28 run func(ln *quic.Listener, port int) 29 } 30 31 var _ = Describe("Handshake drop tests", func() { 32 data := GeneratePRData(5000) 33 const timeout = 2 * time.Minute 34 35 startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) { 36 conf := getQuicConfig(&quic.Config{ 37 MaxIdleTimeout: timeout, 38 HandshakeIdleTimeout: timeout, 39 }) 40 var tlsConf *tls.Config 41 if longCertChain { 42 tlsConf = getTLSConfigWithLongCertChain() 43 } else { 44 tlsConf = getTLSConfig() 45 } 46 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 47 Expect(err).ToNot(HaveOccurred()) 48 conn, err := net.ListenUDP("udp", laddr) 49 Expect(err).ToNot(HaveOccurred()) 50 tr := &quic.Transport{Conn: conn} 51 if doRetry { 52 tr.VerifySourceAddress = func(net.Addr) bool { return true } 53 } 54 ln, err = tr.Listen(tlsConf, conf) 55 Expect(err).ToNot(HaveOccurred()) 56 serverPort := ln.Addr().(*net.UDPAddr).Port 57 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 58 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 59 DropPacket: dropCallback, 60 DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { 61 return 10 * time.Millisecond 62 }, 63 }) 64 Expect(err).ToNot(HaveOccurred()) 65 66 return ln, proxy.LocalPort(), func() { 67 ln.Close() 68 tr.Close() 69 conn.Close() 70 proxy.Close() 71 } 72 } 73 74 clientSpeaksFirst := &applicationProtocol{ 75 name: "client speaks first", 76 run: func(ln *quic.Listener, port int) { 77 serverConnChan := make(chan quic.Connection) 78 go func() { 79 defer GinkgoRecover() 80 conn, err := ln.Accept(context.Background()) 81 Expect(err).ToNot(HaveOccurred()) 82 defer conn.CloseWithError(0, "") 83 str, err := conn.AcceptStream(context.Background()) 84 Expect(err).ToNot(HaveOccurred()) 85 b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) 86 Expect(err).ToNot(HaveOccurred()) 87 Expect(b).To(Equal(data)) 88 serverConnChan <- conn 89 }() 90 conn, err := quic.DialAddr( 91 context.Background(), 92 fmt.Sprintf("localhost:%d", port), 93 getTLSClientConfig(), 94 getQuicConfig(&quic.Config{ 95 MaxIdleTimeout: timeout, 96 HandshakeIdleTimeout: timeout, 97 }), 98 ) 99 Expect(err).ToNot(HaveOccurred()) 100 str, err := conn.OpenStream() 101 Expect(err).ToNot(HaveOccurred()) 102 _, err = str.Write(data) 103 Expect(err).ToNot(HaveOccurred()) 104 Expect(str.Close()).To(Succeed()) 105 106 var serverConn quic.Connection 107 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 108 conn.CloseWithError(0, "") 109 serverConn.CloseWithError(0, "") 110 }, 111 } 112 113 serverSpeaksFirst := &applicationProtocol{ 114 name: "server speaks first", 115 run: func(ln *quic.Listener, port int) { 116 serverConnChan := make(chan quic.Connection) 117 go func() { 118 defer GinkgoRecover() 119 conn, err := ln.Accept(context.Background()) 120 Expect(err).ToNot(HaveOccurred()) 121 str, err := conn.OpenStream() 122 Expect(err).ToNot(HaveOccurred()) 123 _, err = str.Write(data) 124 Expect(err).ToNot(HaveOccurred()) 125 Expect(str.Close()).To(Succeed()) 126 serverConnChan <- conn 127 }() 128 conn, err := quic.DialAddr( 129 context.Background(), 130 fmt.Sprintf("localhost:%d", port), 131 getTLSClientConfig(), 132 getQuicConfig(&quic.Config{ 133 MaxIdleTimeout: timeout, 134 HandshakeIdleTimeout: timeout, 135 }), 136 ) 137 Expect(err).ToNot(HaveOccurred()) 138 str, err := conn.AcceptStream(context.Background()) 139 Expect(err).ToNot(HaveOccurred()) 140 b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) 141 Expect(err).ToNot(HaveOccurred()) 142 Expect(b).To(Equal(data)) 143 144 var serverConn quic.Connection 145 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 146 conn.CloseWithError(0, "") 147 serverConn.CloseWithError(0, "") 148 }, 149 } 150 151 nobodySpeaks := &applicationProtocol{ 152 name: "nobody speaks", 153 run: func(ln *quic.Listener, port int) { 154 serverConnChan := make(chan quic.Connection) 155 go func() { 156 defer GinkgoRecover() 157 conn, err := ln.Accept(context.Background()) 158 Expect(err).ToNot(HaveOccurred()) 159 serverConnChan <- conn 160 }() 161 conn, err := quic.DialAddr( 162 context.Background(), 163 fmt.Sprintf("localhost:%d", port), 164 getTLSClientConfig(), 165 getQuicConfig(&quic.Config{ 166 MaxIdleTimeout: timeout, 167 HandshakeIdleTimeout: timeout, 168 }), 169 ) 170 Expect(err).ToNot(HaveOccurred()) 171 var serverConn quic.Connection 172 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 173 // both server and client accepted a connection. Close now. 174 conn.CloseWithError(0, "") 175 serverConn.CloseWithError(0, "") 176 }, 177 } 178 179 for _, d := range directions { 180 direction := d 181 182 for _, dr := range []bool{true, false} { 183 doRetry := dr 184 desc := "when using Retry" 185 if !dr { 186 desc = "when not using Retry" 187 } 188 189 Context(desc, func() { 190 for _, lcc := range []bool{false, true} { 191 longCertChain := lcc 192 193 Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { 194 for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { 195 app := a 196 197 Context(app.name, func() { 198 It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { 199 var incoming, outgoing atomic.Int32 200 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 201 var p int32 202 switch d { 203 case quicproxy.DirectionIncoming: 204 p = incoming.Add(1) 205 case quicproxy.DirectionOutgoing: 206 p = outgoing.Add(1) 207 } 208 return p == 1 && d.Is(direction) 209 }, doRetry, longCertChain) 210 defer closeFn() 211 app.run(ln, proxyPort) 212 }) 213 214 It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { 215 var incoming, outgoing atomic.Int32 216 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 217 var p int32 218 switch d { 219 case quicproxy.DirectionIncoming: 220 p = incoming.Add(1) 221 case quicproxy.DirectionOutgoing: 222 p = outgoing.Add(1) 223 } 224 return p == 2 && d.Is(direction) 225 }, doRetry, longCertChain) 226 defer closeFn() 227 app.run(ln, proxyPort) 228 }) 229 230 It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { 231 const maxSequentiallyDropped = 10 232 var mx sync.Mutex 233 var incoming, outgoing int 234 235 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 236 drop := mrand.Int63n(int64(3)) == 0 237 238 mx.Lock() 239 defer mx.Unlock() 240 // never drop more than 10 consecutive packets 241 if d.Is(quicproxy.DirectionIncoming) { 242 if drop { 243 incoming++ 244 if incoming > maxSequentiallyDropped { 245 drop = false 246 } 247 } 248 if !drop { 249 incoming = 0 250 } 251 } 252 if d.Is(quicproxy.DirectionOutgoing) { 253 if drop { 254 outgoing++ 255 if outgoing > maxSequentiallyDropped { 256 drop = false 257 } 258 } 259 if !drop { 260 outgoing = 0 261 } 262 } 263 return drop 264 }, doRetry, longCertChain) 265 defer closeFn() 266 app.run(ln, proxyPort) 267 }) 268 }) 269 } 270 }) 271 } 272 }) 273 } 274 275 It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() { 276 origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient 277 defer func() { 278 wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient 279 }() 280 b := make([]byte, 2500) // the ClientHello will now span across 3 packets 281 mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b) 282 wire.AdditionalTransportParametersClient = map[uint64][]byte{ 283 // Avoid random collisions with the greased transport parameters. 284 uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b, 285 } 286 287 ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 288 if d == quicproxy.DirectionOutgoing { 289 return false 290 } 291 return mrand.Intn(3) == 0 292 }, false, false) 293 defer closeFn() 294 clientSpeaksFirst.run(ln, proxyPort) 295 }) 296 } 297 })