github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/self/handshake_drop_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	mrand "math/rand"
     9  	"net"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/tumi8/quic-go/quicvarint"
    15  	"github.com/tumi8/quic-go"
    16  	quicproxy "github.com/tumi8/quic-go/integrationtests/tools/proxy"
    17  	"github.com/tumi8/quic-go/noninternal/wire"
    18  
    19  	. "github.com/onsi/ginkgo/v2"
    20  	. "github.com/onsi/gomega"
    21  	"github.com/onsi/gomega/gbytes"
    22  )
    23  
    24  var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
    25  
    26  type applicationProtocol struct {
    27  	name string
    28  	run  func()
    29  }
    30  
    31  var _ = Describe("Handshake drop tests", func() {
    32  	var (
    33  		proxy *quicproxy.QuicProxy
    34  		ln    *quic.Listener
    35  	)
    36  
    37  	data := GeneratePRData(5000)
    38  	const timeout = 2 * time.Minute
    39  
    40  	startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) {
    41  		conf := getQuicConfig(&quic.Config{
    42  			MaxIdleTimeout:           timeout,
    43  			HandshakeIdleTimeout:     timeout,
    44  			RequireAddressValidation: func(net.Addr) bool { return doRetry },
    45  		})
    46  		var tlsConf *tls.Config
    47  		if longCertChain {
    48  			tlsConf = getTLSConfigWithLongCertChain()
    49  		} else {
    50  			tlsConf = getTLSConfig()
    51  		}
    52  		var err error
    53  		ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
    54  		Expect(err).ToNot(HaveOccurred())
    55  		serverPort := ln.Addr().(*net.UDPAddr).Port
    56  		proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    57  			RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
    58  			DropPacket: dropCallback,
    59  			DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
    60  				return 10 * time.Millisecond
    61  			},
    62  		})
    63  		Expect(err).ToNot(HaveOccurred())
    64  	}
    65  
    66  	clientSpeaksFirst := &applicationProtocol{
    67  		name: "client speaks first",
    68  		run: func() {
    69  			serverConnChan := make(chan quic.Connection)
    70  			go func() {
    71  				defer GinkgoRecover()
    72  				conn, err := ln.Accept(context.Background())
    73  				Expect(err).ToNot(HaveOccurred())
    74  				defer conn.CloseWithError(0, "")
    75  				str, err := conn.AcceptStream(context.Background())
    76  				Expect(err).ToNot(HaveOccurred())
    77  				b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
    78  				Expect(err).ToNot(HaveOccurred())
    79  				Expect(b).To(Equal(data))
    80  				serverConnChan <- conn
    81  			}()
    82  			conn, err := quic.DialAddr(
    83  				context.Background(),
    84  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
    85  				getTLSClientConfig(),
    86  				getQuicConfig(&quic.Config{
    87  					MaxIdleTimeout:       timeout,
    88  					HandshakeIdleTimeout: timeout,
    89  				}),
    90  			)
    91  			Expect(err).ToNot(HaveOccurred())
    92  			str, err := conn.OpenStream()
    93  			Expect(err).ToNot(HaveOccurred())
    94  			_, err = str.Write(data)
    95  			Expect(err).ToNot(HaveOccurred())
    96  			Expect(str.Close()).To(Succeed())
    97  
    98  			var serverConn quic.Connection
    99  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   100  			conn.CloseWithError(0, "")
   101  			serverConn.CloseWithError(0, "")
   102  		},
   103  	}
   104  
   105  	serverSpeaksFirst := &applicationProtocol{
   106  		name: "server speaks first",
   107  		run: func() {
   108  			serverConnChan := make(chan quic.Connection)
   109  			go func() {
   110  				defer GinkgoRecover()
   111  				conn, err := ln.Accept(context.Background())
   112  				Expect(err).ToNot(HaveOccurred())
   113  				str, err := conn.OpenStream()
   114  				Expect(err).ToNot(HaveOccurred())
   115  				_, err = str.Write(data)
   116  				Expect(err).ToNot(HaveOccurred())
   117  				Expect(str.Close()).To(Succeed())
   118  				serverConnChan <- conn
   119  			}()
   120  			conn, err := quic.DialAddr(
   121  				context.Background(),
   122  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   123  				getTLSClientConfig(),
   124  				getQuicConfig(&quic.Config{
   125  					MaxIdleTimeout:       timeout,
   126  					HandshakeIdleTimeout: timeout,
   127  				}),
   128  			)
   129  			Expect(err).ToNot(HaveOccurred())
   130  			str, err := conn.AcceptStream(context.Background())
   131  			Expect(err).ToNot(HaveOccurred())
   132  			b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
   133  			Expect(err).ToNot(HaveOccurred())
   134  			Expect(b).To(Equal(data))
   135  
   136  			var serverConn quic.Connection
   137  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   138  			conn.CloseWithError(0, "")
   139  			serverConn.CloseWithError(0, "")
   140  		},
   141  	}
   142  
   143  	nobodySpeaks := &applicationProtocol{
   144  		name: "nobody speaks",
   145  		run: func() {
   146  			serverConnChan := make(chan quic.Connection)
   147  			go func() {
   148  				defer GinkgoRecover()
   149  				conn, err := ln.Accept(context.Background())
   150  				Expect(err).ToNot(HaveOccurred())
   151  				serverConnChan <- conn
   152  			}()
   153  			conn, err := quic.DialAddr(
   154  				context.Background(),
   155  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   156  				getTLSClientConfig(),
   157  				getQuicConfig(&quic.Config{
   158  					MaxIdleTimeout:       timeout,
   159  					HandshakeIdleTimeout: timeout,
   160  				}),
   161  			)
   162  			Expect(err).ToNot(HaveOccurred())
   163  			var serverConn quic.Connection
   164  			Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
   165  			// both server and client accepted a connection. Close now.
   166  			conn.CloseWithError(0, "")
   167  			serverConn.CloseWithError(0, "")
   168  		},
   169  	}
   170  
   171  	AfterEach(func() {
   172  		Expect(ln.Close()).To(Succeed())
   173  		Expect(proxy.Close()).To(Succeed())
   174  	})
   175  
   176  	for _, d := range directions {
   177  		direction := d
   178  
   179  		for _, dr := range []bool{true, false} {
   180  			doRetry := dr
   181  			desc := "when using Retry"
   182  			if !dr {
   183  				desc = "when not using Retry"
   184  			}
   185  
   186  			Context(desc, func() {
   187  				for _, lcc := range []bool{false, true} {
   188  					longCertChain := lcc
   189  
   190  					Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
   191  						for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
   192  							app := a
   193  
   194  							Context(app.name, func() {
   195  								It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
   196  									var incoming, outgoing int32
   197  									startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   198  										var p int32
   199  										//nolint:exhaustive
   200  										switch d {
   201  										case quicproxy.DirectionIncoming:
   202  											p = atomic.AddInt32(&incoming, 1)
   203  										case quicproxy.DirectionOutgoing:
   204  											p = atomic.AddInt32(&outgoing, 1)
   205  										}
   206  										return p == 1 && d.Is(direction)
   207  									}, doRetry, longCertChain)
   208  									app.run()
   209  								})
   210  
   211  								It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
   212  									var incoming, outgoing int32
   213  									startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   214  										var p int32
   215  										//nolint:exhaustive
   216  										switch d {
   217  										case quicproxy.DirectionIncoming:
   218  											p = atomic.AddInt32(&incoming, 1)
   219  										case quicproxy.DirectionOutgoing:
   220  											p = atomic.AddInt32(&outgoing, 1)
   221  										}
   222  										return p == 2 && d.Is(direction)
   223  									}, doRetry, longCertChain)
   224  									app.run()
   225  								})
   226  
   227  								It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
   228  									const maxSequentiallyDropped = 10
   229  									var mx sync.Mutex
   230  									var incoming, outgoing int
   231  
   232  									startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   233  										drop := mrand.Int63n(int64(3)) == 0
   234  
   235  										mx.Lock()
   236  										defer mx.Unlock()
   237  										// never drop more than 10 consecutive packets
   238  										if d.Is(quicproxy.DirectionIncoming) {
   239  											if drop {
   240  												incoming++
   241  												if incoming > maxSequentiallyDropped {
   242  													drop = false
   243  												}
   244  											}
   245  											if !drop {
   246  												incoming = 0
   247  											}
   248  										}
   249  										if d.Is(quicproxy.DirectionOutgoing) {
   250  											if drop {
   251  												outgoing++
   252  												if outgoing > maxSequentiallyDropped {
   253  													drop = false
   254  												}
   255  											}
   256  											if !drop {
   257  												outgoing = 0
   258  											}
   259  										}
   260  										return drop
   261  									}, doRetry, longCertChain)
   262  									app.run()
   263  								})
   264  							})
   265  						}
   266  					})
   267  				}
   268  			})
   269  		}
   270  
   271  		It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() {
   272  			origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
   273  			defer func() {
   274  				wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient
   275  			}()
   276  			b := make([]byte, 2500) // the ClientHello will now span across 3 packets
   277  			mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b)
   278  			wire.AdditionalTransportParametersClient = map[uint64][]byte{
   279  				// Avoid random collisions with the greased transport parameters.
   280  				uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
   281  			}
   282  
   283  			startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
   284  				if d == quicproxy.DirectionOutgoing {
   285  					return false
   286  				}
   287  				return mrand.Intn(3) == 0
   288  			}, false, false)
   289  			clientSpeaksFirst.run()
   290  		})
   291  	}
   292  })