github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/integrationtests/self/handshake_drop_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	mrand "math/rand"
     9  	"net"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/quicvarint"
    15  
    16  	"github.com/danielpfeifer02/quic-go-prio-packs"
    17  	quicproxy "github.com/danielpfeifer02/quic-go-prio-packs/integrationtests/tools/proxy"
    18  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    19  
    20  	. "github.com/onsi/ginkgo/v2"
    21  	. "github.com/onsi/gomega"
    22  	"github.com/onsi/gomega/gbytes"
    23  )
    24  
    25  var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
    26  
    27  type applicationProtocol struct {
    28  	name string
    29  	run  func(ln *quic.Listener, port int)
    30  }
    31  
    32  var _ = Describe("Handshake drop tests", func() {
    33  	data := GeneratePRData(5000)
    34  	const timeout = 2 * time.Minute
    35  
    36  	startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) {
    37  		conf := getQuicConfig(&quic.Config{
    38  			MaxIdleTimeout:       timeout,
    39  			HandshakeIdleTimeout: timeout,
    40  		})
    41  		var tlsConf *tls.Config
    42  		if longCertChain {
    43  			tlsConf = getTLSConfigWithLongCertChain()
    44  		} else {
    45  			tlsConf = getTLSConfig()
    46  		}
    47  		laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
    48  		Expect(err).ToNot(HaveOccurred())
    49  		conn, err := net.ListenUDP("udp", laddr)
    50  		Expect(err).ToNot(HaveOccurred())
    51  		tr := &quic.Transport{Conn: conn}
    52  		if doRetry {
    53  			tr.MaxUnvalidatedHandshakes = -1
    54  		}
    55  		ln, err = tr.Listen(tlsConf, conf)
    56  		Expect(err).ToNot(HaveOccurred())
    57  		serverPort := ln.Addr().(*net.UDPAddr).Port
    58  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    59  			RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
    60  			DropPacket: dropCallback,
    61  			DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
    62  				return 10 * time.Millisecond
    63  			},
    64  		})
    65  		Expect(err).ToNot(HaveOccurred())
    66  
    67  		return ln, proxy.LocalPort(), func() {
    68  			ln.Close()
    69  			tr.Close()
    70  			conn.Close()
    71  			proxy.Close()
    72  		}
    73  	}
    74  
    75  	clientSpeaksFirst := &applicationProtocol{
    76  		name: "client speaks first",
    77  		run: func(ln *quic.Listener, port int) {
    78  			serverConnChan := make(chan quic.Connection)
    79  			go func() {
    80  				defer GinkgoRecover()
    81  				conn, err := ln.Accept(context.Background())
    82  				Expect(err).ToNot(HaveOccurred())
    83  				defer conn.CloseWithError(0, "")
    84  				str, err := conn.AcceptStream(context.Background())
    85  				Expect(err).ToNot(HaveOccurred())
    86  				b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
    87  				Expect(err).ToNot(HaveOccurred())
    88  				Expect(b).To(Equal(data))
    89  				serverConnChan <- conn
    90  			}()
    91  			conn, err := quic.DialAddr(
    92  				context.Background(),
    93  				fmt.Sprintf("localhost:%d", port),
    94  				getTLSClientConfig(),
    95  				getQuicConfig(&quic.Config{
    96  					MaxIdleTimeout:       timeout,
    97  					HandshakeIdleTimeout: timeout,
    98  				}),
    99  			)
   100  			Expect(err).ToNot(HaveOccurred())
   101  			str, err := conn.OpenStream()
   102  			Expect(err).ToNot(HaveOccurred())
   103  			_, err = str.Write(data)
   104  			Expect(err).ToNot(HaveOccurred())
   105  			Expect(str.Close()).To(Succeed())
   106  
   107  			var serverConn quic.Connection
   108  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   109  			conn.CloseWithError(0, "")
   110  			serverConn.CloseWithError(0, "")
   111  		},
   112  	}
   113  
   114  	serverSpeaksFirst := &applicationProtocol{
   115  		name: "server speaks first",
   116  		run: func(ln *quic.Listener, port int) {
   117  			serverConnChan := make(chan quic.Connection)
   118  			go func() {
   119  				defer GinkgoRecover()
   120  				conn, err := ln.Accept(context.Background())
   121  				Expect(err).ToNot(HaveOccurred())
   122  				str, err := conn.OpenStream()
   123  				Expect(err).ToNot(HaveOccurred())
   124  				_, err = str.Write(data)
   125  				Expect(err).ToNot(HaveOccurred())
   126  				Expect(str.Close()).To(Succeed())
   127  				serverConnChan <- conn
   128  			}()
   129  			conn, err := quic.DialAddr(
   130  				context.Background(),
   131  				fmt.Sprintf("localhost:%d", port),
   132  				getTLSClientConfig(),
   133  				getQuicConfig(&quic.Config{
   134  					MaxIdleTimeout:       timeout,
   135  					HandshakeIdleTimeout: timeout,
   136  				}),
   137  			)
   138  			Expect(err).ToNot(HaveOccurred())
   139  			str, err := conn.AcceptStream(context.Background())
   140  			Expect(err).ToNot(HaveOccurred())
   141  			b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
   142  			Expect(err).ToNot(HaveOccurred())
   143  			Expect(b).To(Equal(data))
   144  
   145  			var serverConn quic.Connection
   146  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   147  			conn.CloseWithError(0, "")
   148  			serverConn.CloseWithError(0, "")
   149  		},
   150  	}
   151  
   152  	nobodySpeaks := &applicationProtocol{
   153  		name: "nobody speaks",
   154  		run: func(ln *quic.Listener, port int) {
   155  			serverConnChan := make(chan quic.Connection)
   156  			go func() {
   157  				defer GinkgoRecover()
   158  				conn, err := ln.Accept(context.Background())
   159  				Expect(err).ToNot(HaveOccurred())
   160  				serverConnChan <- conn
   161  			}()
   162  			conn, err := quic.DialAddr(
   163  				context.Background(),
   164  				fmt.Sprintf("localhost:%d", port),
   165  				getTLSClientConfig(),
   166  				getQuicConfig(&quic.Config{
   167  					MaxIdleTimeout:       timeout,
   168  					HandshakeIdleTimeout: timeout,
   169  				}),
   170  			)
   171  			Expect(err).ToNot(HaveOccurred())
   172  			var serverConn quic.Connection
   173  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   174  			// both server and client accepted a connection. Close now.
   175  			conn.CloseWithError(0, "")
   176  			serverConn.CloseWithError(0, "")
   177  		},
   178  	}
   179  
   180  	for _, d := range directions {
   181  		direction := d
   182  
   183  		for _, dr := range []bool{true, false} {
   184  			doRetry := dr
   185  			desc := "when using Retry"
   186  			if !dr {
   187  				desc = "when not using Retry"
   188  			}
   189  
   190  			Context(desc, func() {
   191  				for _, lcc := range []bool{false, true} {
   192  					longCertChain := lcc
   193  
   194  					Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
   195  						for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
   196  							app := a
   197  
   198  							Context(app.name, func() {
   199  								It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
   200  									var incoming, outgoing atomic.Int32
   201  									ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   202  										var p int32
   203  										//nolint:exhaustive
   204  										switch d {
   205  										case quicproxy.DirectionIncoming:
   206  											p = incoming.Add(1)
   207  										case quicproxy.DirectionOutgoing:
   208  											p = outgoing.Add(1)
   209  										}
   210  										return p == 1 && d.Is(direction)
   211  									}, doRetry, longCertChain)
   212  									defer closeFn()
   213  									app.run(ln, proxyPort)
   214  								})
   215  
   216  								It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
   217  									var incoming, outgoing atomic.Int32
   218  									ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   219  										var p int32
   220  										//nolint:exhaustive
   221  										switch d {
   222  										case quicproxy.DirectionIncoming:
   223  											p = incoming.Add(1)
   224  										case quicproxy.DirectionOutgoing:
   225  											p = outgoing.Add(1)
   226  										}
   227  										return p == 2 && d.Is(direction)
   228  									}, doRetry, longCertChain)
   229  									defer closeFn()
   230  									app.run(ln, proxyPort)
   231  								})
   232  
   233  								It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
   234  									const maxSequentiallyDropped = 10
   235  									var mx sync.Mutex
   236  									var incoming, outgoing int
   237  
   238  									ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   239  										drop := mrand.Int63n(int64(3)) == 0
   240  
   241  										mx.Lock()
   242  										defer mx.Unlock()
   243  										// never drop more than 10 consecutive packets
   244  										if d.Is(quicproxy.DirectionIncoming) {
   245  											if drop {
   246  												incoming++
   247  												if incoming > maxSequentiallyDropped {
   248  													drop = false
   249  												}
   250  											}
   251  											if !drop {
   252  												incoming = 0
   253  											}
   254  										}
   255  										if d.Is(quicproxy.DirectionOutgoing) {
   256  											if drop {
   257  												outgoing++
   258  												if outgoing > maxSequentiallyDropped {
   259  													drop = false
   260  												}
   261  											}
   262  											if !drop {
   263  												outgoing = 0
   264  											}
   265  										}
   266  										return drop
   267  									}, doRetry, longCertChain)
   268  									defer closeFn()
   269  									app.run(ln, proxyPort)
   270  								})
   271  							})
   272  						}
   273  					})
   274  				}
   275  			})
   276  		}
   277  
   278  		It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() {
   279  			origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
   280  			defer func() {
   281  				wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient
   282  			}()
   283  			b := make([]byte, 2500) // the ClientHello will now span across 3 packets
   284  			mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b)
   285  			wire.AdditionalTransportParametersClient = map[uint64][]byte{
   286  				// Avoid random collisions with the greased transport parameters.
   287  				uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
   288  			}
   289  
   290  			ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   291  				if d == quicproxy.DirectionOutgoing {
   292  					return false
   293  				}
   294  				return mrand.Intn(3) == 0
   295  			}, false, false)
   296  			defer closeFn()
   297  			clientSpeaksFirst.run(ln, proxyPort)
   298  		})
   299  	}
   300  })