github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/handshake_drop_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	mrand "math/rand"
     9  	"net"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/MerlinKodo/quic-go/quicvarint"
    15  
    16  	"github.com/MerlinKodo/quic-go"
    17  	quicproxy "github.com/MerlinKodo/quic-go/integrationtests/tools/proxy"
    18  	"github.com/MerlinKodo/quic-go/internal/wire"
    19  
    20  	. "github.com/onsi/ginkgo/v2"
    21  	. "github.com/onsi/gomega"
    22  	"github.com/onsi/gomega/gbytes"
    23  )
    24  
    25  var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
    26  
    27  type applicationProtocol struct {
    28  	name string
    29  	run  func()
    30  }
    31  
    32  var _ = Describe("Handshake drop tests", func() {
    33  	var (
    34  		proxy *quicproxy.QuicProxy
    35  		ln    *quic.Listener
    36  	)
    37  
    38  	data := GeneratePRData(5000)
    39  	const timeout = 2 * time.Minute
    40  
    41  	startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) {
    42  		conf := getQuicConfig(&quic.Config{
    43  			MaxIdleTimeout:           timeout,
    44  			HandshakeIdleTimeout:     timeout,
    45  			RequireAddressValidation: func(net.Addr) bool { return doRetry },
    46  		})
    47  		var tlsConf *tls.Config
    48  		if longCertChain {
    49  			tlsConf = getTLSConfigWithLongCertChain()
    50  		} else {
    51  			tlsConf = getTLSConfig()
    52  		}
    53  		var err error
    54  		ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
    55  		Expect(err).ToNot(HaveOccurred())
    56  		serverPort := ln.Addr().(*net.UDPAddr).Port
    57  		proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    58  			RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
    59  			DropPacket: dropCallback,
    60  			DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
    61  				return 10 * time.Millisecond
    62  			},
    63  		})
    64  		Expect(err).ToNot(HaveOccurred())
    65  	}
    66  
    67  	clientSpeaksFirst := &applicationProtocol{
    68  		name: "client speaks first",
    69  		run: func() {
    70  			serverConnChan := make(chan quic.Connection)
    71  			go func() {
    72  				defer GinkgoRecover()
    73  				conn, err := ln.Accept(context.Background())
    74  				Expect(err).ToNot(HaveOccurred())
    75  				defer conn.CloseWithError(0, "")
    76  				str, err := conn.AcceptStream(context.Background())
    77  				Expect(err).ToNot(HaveOccurred())
    78  				b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
    79  				Expect(err).ToNot(HaveOccurred())
    80  				Expect(b).To(Equal(data))
    81  				serverConnChan <- conn
    82  			}()
    83  			conn, err := quic.DialAddr(
    84  				context.Background(),
    85  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
    86  				getTLSClientConfig(),
    87  				getQuicConfig(&quic.Config{
    88  					MaxIdleTimeout:       timeout,
    89  					HandshakeIdleTimeout: timeout,
    90  				}),
    91  			)
    92  			Expect(err).ToNot(HaveOccurred())
    93  			str, err := conn.OpenStream()
    94  			Expect(err).ToNot(HaveOccurred())
    95  			_, err = str.Write(data)
    96  			Expect(err).ToNot(HaveOccurred())
    97  			Expect(str.Close()).To(Succeed())
    98  
    99  			var serverConn quic.Connection
   100  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   101  			conn.CloseWithError(0, "")
   102  			serverConn.CloseWithError(0, "")
   103  		},
   104  	}
   105  
   106  	serverSpeaksFirst := &applicationProtocol{
   107  		name: "server speaks first",
   108  		run: func() {
   109  			serverConnChan := make(chan quic.Connection)
   110  			go func() {
   111  				defer GinkgoRecover()
   112  				conn, err := ln.Accept(context.Background())
   113  				Expect(err).ToNot(HaveOccurred())
   114  				str, err := conn.OpenStream()
   115  				Expect(err).ToNot(HaveOccurred())
   116  				_, err = str.Write(data)
   117  				Expect(err).ToNot(HaveOccurred())
   118  				Expect(str.Close()).To(Succeed())
   119  				serverConnChan <- conn
   120  			}()
   121  			conn, err := quic.DialAddr(
   122  				context.Background(),
   123  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   124  				getTLSClientConfig(),
   125  				getQuicConfig(&quic.Config{
   126  					MaxIdleTimeout:       timeout,
   127  					HandshakeIdleTimeout: timeout,
   128  				}),
   129  			)
   130  			Expect(err).ToNot(HaveOccurred())
   131  			str, err := conn.AcceptStream(context.Background())
   132  			Expect(err).ToNot(HaveOccurred())
   133  			b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
   134  			Expect(err).ToNot(HaveOccurred())
   135  			Expect(b).To(Equal(data))
   136  
   137  			var serverConn quic.Connection
   138  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   139  			conn.CloseWithError(0, "")
   140  			serverConn.CloseWithError(0, "")
   141  		},
   142  	}
   143  
   144  	nobodySpeaks := &applicationProtocol{
   145  		name: "nobody speaks",
   146  		run: func() {
   147  			serverConnChan := make(chan quic.Connection)
   148  			go func() {
   149  				defer GinkgoRecover()
   150  				conn, err := ln.Accept(context.Background())
   151  				Expect(err).ToNot(HaveOccurred())
   152  				serverConnChan <- conn
   153  			}()
   154  			conn, err := quic.DialAddr(
   155  				context.Background(),
   156  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   157  				getTLSClientConfig(),
   158  				getQuicConfig(&quic.Config{
   159  					MaxIdleTimeout:       timeout,
   160  					HandshakeIdleTimeout: timeout,
   161  				}),
   162  			)
   163  			Expect(err).ToNot(HaveOccurred())
   164  			var serverConn quic.Connection
   165  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   166  			// both server and client accepted a connection. Close now.
   167  			conn.CloseWithError(0, "")
   168  			serverConn.CloseWithError(0, "")
   169  		},
   170  	}
   171  
   172  	AfterEach(func() {
   173  		Expect(ln.Close()).To(Succeed())
   174  		Expect(proxy.Close()).To(Succeed())
   175  	})
   176  
   177  	for _, d := range directions {
   178  		direction := d
   179  
   180  		for _, dr := range []bool{true, false} {
   181  			doRetry := dr
   182  			desc := "when using Retry"
   183  			if !dr {
   184  				desc = "when not using Retry"
   185  			}
   186  
   187  			Context(desc, func() {
   188  				for _, lcc := range []bool{false, true} {
   189  					longCertChain := lcc
   190  
   191  					Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
   192  						for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
   193  							app := a
   194  
   195  							Context(app.name, func() {
   196  								It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
   197  									var incoming, outgoing int32
   198  									startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   199  										var p int32
   200  										//nolint:exhaustive
   201  										switch d {
   202  										case quicproxy.DirectionIncoming:
   203  											p = atomic.AddInt32(&incoming, 1)
   204  										case quicproxy.DirectionOutgoing:
   205  											p = atomic.AddInt32(&outgoing, 1)
   206  										}
   207  										return p == 1 && d.Is(direction)
   208  									}, doRetry, longCertChain)
   209  									app.run()
   210  								})
   211  
   212  								It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
   213  									var incoming, outgoing int32
   214  									startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   215  										var p int32
   216  										//nolint:exhaustive
   217  										switch d {
   218  										case quicproxy.DirectionIncoming:
   219  											p = atomic.AddInt32(&incoming, 1)
   220  										case quicproxy.DirectionOutgoing:
   221  											p = atomic.AddInt32(&outgoing, 1)
   222  										}
   223  										return p == 2 && d.Is(direction)
   224  									}, doRetry, longCertChain)
   225  									app.run()
   226  								})
   227  
   228  								It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
   229  									const maxSequentiallyDropped = 10
   230  									var mx sync.Mutex
   231  									var incoming, outgoing int
   232  
   233  									startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   234  										drop := mrand.Int63n(int64(3)) == 0
   235  
   236  										mx.Lock()
   237  										defer mx.Unlock()
   238  										// never drop more than 10 consecutive packets
   239  										if d.Is(quicproxy.DirectionIncoming) {
   240  											if drop {
   241  												incoming++
   242  												if incoming > maxSequentiallyDropped {
   243  													drop = false
   244  												}
   245  											}
   246  											if !drop {
   247  												incoming = 0
   248  											}
   249  										}
   250  										if d.Is(quicproxy.DirectionOutgoing) {
   251  											if drop {
   252  												outgoing++
   253  												if outgoing > maxSequentiallyDropped {
   254  													drop = false
   255  												}
   256  											}
   257  											if !drop {
   258  												outgoing = 0
   259  											}
   260  										}
   261  										return drop
   262  									}, doRetry, longCertChain)
   263  									app.run()
   264  								})
   265  							})
   266  						}
   267  					})
   268  				}
   269  			})
   270  		}
   271  
   272  		It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() {
   273  			origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
   274  			defer func() {
   275  				wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient
   276  			}()
   277  			b := make([]byte, 2500) // the ClientHello will now span across 3 packets
   278  			mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b)
   279  			wire.AdditionalTransportParametersClient = map[uint64][]byte{
   280  				// Avoid random collisions with the greased transport parameters.
   281  				uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
   282  			}
   283  
   284  			startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   285  				if d == quicproxy.DirectionOutgoing {
   286  					return false
   287  				}
   288  				return mrand.Intn(3) == 0
   289  			}, false, false)
   290  			clientSpeaksFirst.run()
   291  		})
   292  	}
   293  })