github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/self/handshake_drop_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	mrand "math/rand"
     9  	"net"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/apernet/quic-go"
    15  	quicproxy "github.com/apernet/quic-go/integrationtests/tools/proxy"
    16  	"github.com/apernet/quic-go/internal/wire"
    17  	"github.com/apernet/quic-go/quicvarint"
    18  
    19  	. "github.com/onsi/ginkgo/v2"
    20  	. "github.com/onsi/gomega"
    21  	"github.com/onsi/gomega/gbytes"
    22  )
    23  
    24  var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
    25  
// applicationProtocol describes one application-level traffic pattern that the
// drop tests run over a connection established through the lossy proxy.
type applicationProtocol struct {
	// name is used as the description of the Ginkgo Context wrapping the specs.
	name string
	// run drives one connection. In this file every implementation accepts the
	// server side on ln and dials the proxy at localhost:port for the client side.
	run func(ln *quic.Listener, port int)
}
    30  
    31  var _ = Describe("Handshake drop tests", func() {
    32  	data := GeneratePRData(5000)
    33  	const timeout = 2 * time.Minute
    34  
    35  	startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) {
    36  		conf := getQuicConfig(&quic.Config{
    37  			MaxIdleTimeout:       timeout,
    38  			HandshakeIdleTimeout: timeout,
    39  		})
    40  		var tlsConf *tls.Config
    41  		if longCertChain {
    42  			tlsConf = getTLSConfigWithLongCertChain()
    43  		} else {
    44  			tlsConf = getTLSConfig()
    45  		}
    46  		laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
    47  		Expect(err).ToNot(HaveOccurred())
    48  		conn, err := net.ListenUDP("udp", laddr)
    49  		Expect(err).ToNot(HaveOccurred())
    50  		tr := &quic.Transport{Conn: conn}
    51  		if doRetry {
    52  			tr.VerifySourceAddress = func(net.Addr) bool { return true }
    53  		}
    54  		ln, err = tr.Listen(tlsConf, conf)
    55  		Expect(err).ToNot(HaveOccurred())
    56  		serverPort := ln.Addr().(*net.UDPAddr).Port
    57  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    58  			RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
    59  			DropPacket: dropCallback,
    60  			DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
    61  				return 10 * time.Millisecond
    62  			},
    63  		})
    64  		Expect(err).ToNot(HaveOccurred())
    65  
    66  		return ln, proxy.LocalPort(), func() {
    67  			ln.Close()
    68  			tr.Close()
    69  			conn.Close()
    70  			proxy.Close()
    71  		}
    72  	}
    73  
    74  	clientSpeaksFirst := &applicationProtocol{
    75  		name: "client speaks first",
    76  		run: func(ln *quic.Listener, port int) {
    77  			serverConnChan := make(chan quic.Connection)
    78  			go func() {
    79  				defer GinkgoRecover()
    80  				conn, err := ln.Accept(context.Background())
    81  				Expect(err).ToNot(HaveOccurred())
    82  				defer conn.CloseWithError(0, "")
    83  				str, err := conn.AcceptStream(context.Background())
    84  				Expect(err).ToNot(HaveOccurred())
    85  				b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
    86  				Expect(err).ToNot(HaveOccurred())
    87  				Expect(b).To(Equal(data))
    88  				serverConnChan <- conn
    89  			}()
    90  			conn, err := quic.DialAddr(
    91  				context.Background(),
    92  				fmt.Sprintf("localhost:%d", port),
    93  				getTLSClientConfig(),
    94  				getQuicConfig(&quic.Config{
    95  					MaxIdleTimeout:       timeout,
    96  					HandshakeIdleTimeout: timeout,
    97  				}),
    98  			)
    99  			Expect(err).ToNot(HaveOccurred())
   100  			str, err := conn.OpenStream()
   101  			Expect(err).ToNot(HaveOccurred())
   102  			_, err = str.Write(data)
   103  			Expect(err).ToNot(HaveOccurred())
   104  			Expect(str.Close()).To(Succeed())
   105  
   106  			var serverConn quic.Connection
   107  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   108  			conn.CloseWithError(0, "")
   109  			serverConn.CloseWithError(0, "")
   110  		},
   111  	}
   112  
   113  	serverSpeaksFirst := &applicationProtocol{
   114  		name: "server speaks first",
   115  		run: func(ln *quic.Listener, port int) {
   116  			serverConnChan := make(chan quic.Connection)
   117  			go func() {
   118  				defer GinkgoRecover()
   119  				conn, err := ln.Accept(context.Background())
   120  				Expect(err).ToNot(HaveOccurred())
   121  				str, err := conn.OpenStream()
   122  				Expect(err).ToNot(HaveOccurred())
   123  				_, err = str.Write(data)
   124  				Expect(err).ToNot(HaveOccurred())
   125  				Expect(str.Close()).To(Succeed())
   126  				serverConnChan <- conn
   127  			}()
   128  			conn, err := quic.DialAddr(
   129  				context.Background(),
   130  				fmt.Sprintf("localhost:%d", port),
   131  				getTLSClientConfig(),
   132  				getQuicConfig(&quic.Config{
   133  					MaxIdleTimeout:       timeout,
   134  					HandshakeIdleTimeout: timeout,
   135  				}),
   136  			)
   137  			Expect(err).ToNot(HaveOccurred())
   138  			str, err := conn.AcceptStream(context.Background())
   139  			Expect(err).ToNot(HaveOccurred())
   140  			b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
   141  			Expect(err).ToNot(HaveOccurred())
   142  			Expect(b).To(Equal(data))
   143  
   144  			var serverConn quic.Connection
   145  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   146  			conn.CloseWithError(0, "")
   147  			serverConn.CloseWithError(0, "")
   148  		},
   149  	}
   150  
   151  	nobodySpeaks := &applicationProtocol{
   152  		name: "nobody speaks",
   153  		run: func(ln *quic.Listener, port int) {
   154  			serverConnChan := make(chan quic.Connection)
   155  			go func() {
   156  				defer GinkgoRecover()
   157  				conn, err := ln.Accept(context.Background())
   158  				Expect(err).ToNot(HaveOccurred())
   159  				serverConnChan <- conn
   160  			}()
   161  			conn, err := quic.DialAddr(
   162  				context.Background(),
   163  				fmt.Sprintf("localhost:%d", port),
   164  				getTLSClientConfig(),
   165  				getQuicConfig(&quic.Config{
   166  					MaxIdleTimeout:       timeout,
   167  					HandshakeIdleTimeout: timeout,
   168  				}),
   169  			)
   170  			Expect(err).ToNot(HaveOccurred())
   171  			var serverConn quic.Connection
   172  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   173  			// both server and client accepted a connection. Close now.
   174  			conn.CloseWithError(0, "")
   175  			serverConn.CloseWithError(0, "")
   176  		},
   177  	}
   178  
   179  	for _, d := range directions {
   180  		direction := d
   181  
   182  		for _, dr := range []bool{true, false} {
   183  			doRetry := dr
   184  			desc := "when using Retry"
   185  			if !dr {
   186  				desc = "when not using Retry"
   187  			}
   188  
   189  			Context(desc, func() {
   190  				for _, lcc := range []bool{false, true} {
   191  					longCertChain := lcc
   192  
   193  					Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
   194  						for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
   195  							app := a
   196  
   197  							Context(app.name, func() {
   198  								It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
   199  									var incoming, outgoing atomic.Int32
   200  									ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   201  										var p int32
   202  										//nolint:exhaustive
   203  										switch d {
   204  										case quicproxy.DirectionIncoming:
   205  											p = incoming.Add(1)
   206  										case quicproxy.DirectionOutgoing:
   207  											p = outgoing.Add(1)
   208  										}
   209  										return p == 1 && d.Is(direction)
   210  									}, doRetry, longCertChain)
   211  									defer closeFn()
   212  									app.run(ln, proxyPort)
   213  								})
   214  
   215  								It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
   216  									var incoming, outgoing atomic.Int32
   217  									ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   218  										var p int32
   219  										//nolint:exhaustive
   220  										switch d {
   221  										case quicproxy.DirectionIncoming:
   222  											p = incoming.Add(1)
   223  										case quicproxy.DirectionOutgoing:
   224  											p = outgoing.Add(1)
   225  										}
   226  										return p == 2 && d.Is(direction)
   227  									}, doRetry, longCertChain)
   228  									defer closeFn()
   229  									app.run(ln, proxyPort)
   230  								})
   231  
   232  								It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
   233  									const maxSequentiallyDropped = 10
   234  									var mx sync.Mutex
   235  									var incoming, outgoing int
   236  
   237  									ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   238  										drop := mrand.Int63n(int64(3)) == 0
   239  
   240  										mx.Lock()
   241  										defer mx.Unlock()
   242  										// never drop more than 10 consecutive packets
   243  										if d.Is(quicproxy.DirectionIncoming) {
   244  											if drop {
   245  												incoming++
   246  												if incoming > maxSequentiallyDropped {
   247  													drop = false
   248  												}
   249  											}
   250  											if !drop {
   251  												incoming = 0
   252  											}
   253  										}
   254  										if d.Is(quicproxy.DirectionOutgoing) {
   255  											if drop {
   256  												outgoing++
   257  												if outgoing > maxSequentiallyDropped {
   258  													drop = false
   259  												}
   260  											}
   261  											if !drop {
   262  												outgoing = 0
   263  											}
   264  										}
   265  										return drop
   266  									}, doRetry, longCertChain)
   267  									defer closeFn()
   268  									app.run(ln, proxyPort)
   269  								})
   270  							})
   271  						}
   272  					})
   273  				}
   274  			})
   275  		}
   276  
   277  		It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() {
   278  			origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
   279  			defer func() {
   280  				wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient
   281  			}()
   282  			b := make([]byte, 2500) // the ClientHello will now span across 3 packets
   283  			mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b)
   284  			wire.AdditionalTransportParametersClient = map[uint64][]byte{
   285  				// Avoid random collisions with the greased transport parameters.
   286  				uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
   287  			}
   288  
   289  			ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   290  				if d == quicproxy.DirectionOutgoing {
   291  					return false
   292  				}
   293  				return mrand.Intn(3) == 0
   294  			}, false, false)
   295  			defer closeFn()
   296  			clientSpeaksFirst.run(ln, proxyPort)
   297  		})
   298  	}
   299  })