github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/handshake_drop_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 mrand "math/rand" 9 "net" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/MerlinKodo/quic-go/quicvarint" 15 16 "github.com/MerlinKodo/quic-go" 17 quicproxy "github.com/MerlinKodo/quic-go/integrationtests/tools/proxy" 18 "github.com/MerlinKodo/quic-go/internal/wire" 19 20 . "github.com/onsi/ginkgo/v2" 21 . "github.com/onsi/gomega" 22 "github.com/onsi/gomega/gbytes" 23 ) 24 25 var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} 26 27 type applicationProtocol struct { 28 name string 29 run func() 30 } 31 32 var _ = Describe("Handshake drop tests", func() { 33 var ( 34 proxy *quicproxy.QuicProxy 35 ln *quic.Listener 36 ) 37 38 data := GeneratePRData(5000) 39 const timeout = 2 * time.Minute 40 41 startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) { 42 conf := getQuicConfig(&quic.Config{ 43 MaxIdleTimeout: timeout, 44 HandshakeIdleTimeout: timeout, 45 RequireAddressValidation: func(net.Addr) bool { return doRetry }, 46 }) 47 var tlsConf *tls.Config 48 if longCertChain { 49 tlsConf = getTLSConfigWithLongCertChain() 50 } else { 51 tlsConf = getTLSConfig() 52 } 53 var err error 54 ln, err = quic.ListenAddr("localhost:0", tlsConf, conf) 55 Expect(err).ToNot(HaveOccurred()) 56 serverPort := ln.Addr().(*net.UDPAddr).Port 57 proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 58 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 59 DropPacket: dropCallback, 60 DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { 61 return 10 * time.Millisecond 62 }, 63 }) 64 Expect(err).ToNot(HaveOccurred()) 65 } 66 67 clientSpeaksFirst := &applicationProtocol{ 68 name: "client speaks first", 69 run: func() { 70 serverConnChan := make(chan quic.Connection) 71 go func() { 72 defer GinkgoRecover() 73 conn, err := ln.Accept(context.Background()) 74 Expect(err).ToNot(HaveOccurred()) 75 defer conn.CloseWithError(0, "") 76 str, err := conn.AcceptStream(context.Background()) 77 Expect(err).ToNot(HaveOccurred()) 78 b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) 79 Expect(err).ToNot(HaveOccurred()) 80 Expect(b).To(Equal(data)) 81 serverConnChan <- conn 82 }() 83 conn, err := quic.DialAddr( 84 context.Background(), 85 fmt.Sprintf("localhost:%d", proxy.LocalPort()), 86 getTLSClientConfig(), 87 getQuicConfig(&quic.Config{ 88 MaxIdleTimeout: timeout, 89 HandshakeIdleTimeout: timeout, 90 }), 91 ) 92 Expect(err).ToNot(HaveOccurred()) 93 str, err := conn.OpenStream() 94 Expect(err).ToNot(HaveOccurred()) 95 _, err = str.Write(data) 96 Expect(err).ToNot(HaveOccurred()) 97 Expect(str.Close()).To(Succeed()) 98 99 var serverConn quic.Connection 100 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 101 conn.CloseWithError(0, "") 102 serverConn.CloseWithError(0, "") 103 }, 104 } 105 106 serverSpeaksFirst := &applicationProtocol{ 107 name: "server speaks first", 108 run: func() { 109 serverConnChan := make(chan quic.Connection) 110 go func() { 111 defer GinkgoRecover() 112 conn, err := ln.Accept(context.Background()) 113 Expect(err).ToNot(HaveOccurred()) 114 str, err := conn.OpenStream() 115 Expect(err).ToNot(HaveOccurred()) 116 _, err = str.Write(data) 117 Expect(err).ToNot(HaveOccurred()) 118 Expect(str.Close()).To(Succeed()) 119 serverConnChan <- conn 120 }() 121 conn, err := quic.DialAddr( 122 context.Background(), 123 fmt.Sprintf("localhost:%d", proxy.LocalPort()), 124 getTLSClientConfig(), 125 getQuicConfig(&quic.Config{ 126 MaxIdleTimeout: timeout, 127 HandshakeIdleTimeout: timeout, 128 }), 129 ) 130 Expect(err).ToNot(HaveOccurred()) 131 str, err := conn.AcceptStream(context.Background()) 132 Expect(err).ToNot(HaveOccurred()) 133 b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) 134 Expect(err).ToNot(HaveOccurred()) 135 Expect(b).To(Equal(data)) 136 137 var serverConn quic.Connection 138 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 139 conn.CloseWithError(0, "") 140 serverConn.CloseWithError(0, "") 141 }, 142 } 143 144 nobodySpeaks := &applicationProtocol{ 145 name: "nobody speaks", 146 run: func() { 147 serverConnChan := make(chan quic.Connection) 148 go func() { 149 defer GinkgoRecover() 150 conn, err := ln.Accept(context.Background()) 151 Expect(err).ToNot(HaveOccurred()) 152 serverConnChan <- conn 153 }() 154 conn, err := quic.DialAddr( 155 context.Background(), 156 fmt.Sprintf("localhost:%d", proxy.LocalPort()), 157 getTLSClientConfig(), 158 getQuicConfig(&quic.Config{ 159 MaxIdleTimeout: timeout, 160 HandshakeIdleTimeout: timeout, 161 }), 162 ) 163 Expect(err).ToNot(HaveOccurred()) 164 var serverConn quic.Connection 165 Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) 166 // both server and client accepted a connection. Close now. 167 conn.CloseWithError(0, "") 168 serverConn.CloseWithError(0, "") 169 }, 170 } 171 172 AfterEach(func() { 173 Expect(ln.Close()).To(Succeed()) 174 Expect(proxy.Close()).To(Succeed()) 175 }) 176 177 for _, d := range directions { 178 direction := d 179 180 for _, dr := range []bool{true, false} { 181 doRetry := dr 182 desc := "when using Retry" 183 if !dr { 184 desc = "when not using Retry" 185 } 186 187 Context(desc, func() { 188 for _, lcc := range []bool{false, true} { 189 longCertChain := lcc 190 191 Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { 192 for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { 193 app := a 194 195 Context(app.name, func() { 196 It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { 197 var incoming, outgoing int32 198 startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 199 var p int32 200 //nolint:exhaustive 201 switch d { 202 case quicproxy.DirectionIncoming: 203 p = atomic.AddInt32(&incoming, 1) 204 case quicproxy.DirectionOutgoing: 205 p = atomic.AddInt32(&outgoing, 1) 206 } 207 return p == 1 && d.Is(direction) 208 }, doRetry, longCertChain) 209 app.run() 210 }) 211 212 It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { 213 var incoming, outgoing int32 214 startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 215 var p int32 216 //nolint:exhaustive 217 switch d { 218 case quicproxy.DirectionIncoming: 219 p = atomic.AddInt32(&incoming, 1) 220 case quicproxy.DirectionOutgoing: 221 p = atomic.AddInt32(&outgoing, 1) 222 } 223 return p == 2 && d.Is(direction) 224 }, doRetry, longCertChain) 225 app.run() 226 }) 227 228 It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { 229 const maxSequentiallyDropped = 10 230 var mx sync.Mutex 231 var incoming, outgoing int 232 233 startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 234 drop := mrand.Int63n(int64(3)) == 0 235 236 mx.Lock() 237 defer mx.Unlock() 238 // never drop more than 10 consecutive packets 239 if d.Is(quicproxy.DirectionIncoming) { 240 if drop { 241 incoming++ 242 if incoming > maxSequentiallyDropped { 243 drop = false 244 } 245 } 246 if !drop { 247 incoming = 0 248 } 249 } 250 if d.Is(quicproxy.DirectionOutgoing) { 251 if drop { 252 outgoing++ 253 if outgoing > maxSequentiallyDropped { 254 drop = false 255 } 256 } 257 if !drop { 258 outgoing = 0 259 } 260 } 261 return drop 262 }, doRetry, longCertChain) 263 app.run() 264 }) 265 }) 266 } 267 }) 268 } 269 }) 270 } 271 272 It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() { 273 origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient 274 defer func() { 275 wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient 276 }() 277 b := make([]byte, 2500) // the ClientHello will now span across 3 packets 278 mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b) 279 wire.AdditionalTransportParametersClient = map[uint64][]byte{ 280 // Avoid random collisions with the greased transport parameters. 281 uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b, 282 } 283 284 startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { 285 if d == quicproxy.DirectionOutgoing { 286 return false 287 } 288 return mrand.Intn(3) == 0 289 }, false, false) 290 clientSpeaksFirst.run() 291 }) 292 } 293 })