github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/integrationtests/self/timeout_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	mrand "math/rand"
     8  	"net"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/danielpfeifer02/quic-go-prio-packs"
    13  	quicproxy "github.com/danielpfeifer02/quic-go-prio-packs/integrationtests/tools/proxy"
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    16  
    17  	. "github.com/onsi/ginkgo/v2"
    18  	. "github.com/onsi/gomega"
    19  )
    20  
    21  type faultyConn struct {
    22  	net.PacketConn
    23  
    24  	MaxPackets int32
    25  	counter    atomic.Int32
    26  }
    27  
    28  func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) {
    29  	n, addr, err := c.PacketConn.ReadFrom(p)
    30  	counter := c.counter.Add(1)
    31  	if counter <= c.MaxPackets {
    32  		return n, addr, err
    33  	}
    34  	return 0, nil, io.ErrClosedPipe
    35  }
    36  
    37  func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) {
    38  	counter := c.counter.Add(1)
    39  	if counter <= c.MaxPackets {
    40  		return c.PacketConn.WriteTo(p, addr)
    41  	}
    42  	return 0, io.ErrClosedPipe
    43  }
    44  
    45  var _ = Describe("Timeout tests", func() {
    46  	checkTimeoutError := func(err error) {
    47  		ExpectWithOffset(1, err).To(MatchError(&quic.IdleTimeoutError{}))
    48  		nerr, ok := err.(net.Error)
    49  		ExpectWithOffset(1, ok).To(BeTrue())
    50  		ExpectWithOffset(1, nerr.Timeout()).To(BeTrue())
    51  	}
    52  
    53  	It("returns net.Error timeout errors when dialing", func() {
    54  		errChan := make(chan error)
    55  		go func() {
    56  			_, err := quic.DialAddr(
    57  				context.Background(),
    58  				"localhost:12345",
    59  				getTLSClientConfig(),
    60  				getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
    61  			)
    62  			errChan <- err
    63  		}()
    64  		var err error
    65  		Eventually(errChan).Should(Receive(&err))
    66  		checkTimeoutError(err)
    67  	})
    68  
    69  	It("returns the context error when the context expires", func() {
    70  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    71  		defer cancel()
    72  		errChan := make(chan error)
    73  		go func() {
    74  			_, err := quic.DialAddr(
    75  				ctx,
    76  				"localhost:12345",
    77  				getTLSClientConfig(),
    78  				getQuicConfig(nil),
    79  			)
    80  			errChan <- err
    81  		}()
    82  		var err error
    83  		Eventually(errChan).Should(Receive(&err))
    84  		// This is not a net.Error timeout error
    85  		Expect(err).To(MatchError(context.DeadlineExceeded))
    86  	})
    87  
    88  	It("returns the context error when the context expires with 0RTT enabled", func() {
    89  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    90  		defer cancel()
    91  		errChan := make(chan error)
    92  		go func() {
    93  			_, err := quic.DialAddrEarly(
    94  				ctx,
    95  				"localhost:12345",
    96  				getTLSClientConfig(),
    97  				getQuicConfig(nil),
    98  			)
    99  			errChan <- err
   100  		}()
   101  		var err error
   102  		Eventually(errChan).Should(Receive(&err))
   103  		// This is not a net.Error timeout error
   104  		Expect(err).To(MatchError(context.DeadlineExceeded))
   105  	})
   106  
   107  	It("returns net.Error timeout errors when an idle timeout occurs", func() {
   108  		const idleTimeout = 500 * time.Millisecond
   109  
   110  		server, err := quic.ListenAddr(
   111  			"localhost:0",
   112  			getTLSConfig(),
   113  			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   114  		)
   115  		Expect(err).ToNot(HaveOccurred())
   116  		defer server.Close()
   117  
   118  		go func() {
   119  			defer GinkgoRecover()
   120  			conn, err := server.Accept(context.Background())
   121  			Expect(err).ToNot(HaveOccurred())
   122  			str, err := conn.OpenStream()
   123  			Expect(err).ToNot(HaveOccurred())
   124  			_, err = str.Write([]byte("foobar"))
   125  			Expect(err).ToNot(HaveOccurred())
   126  		}()
   127  
   128  		var drop atomic.Bool
   129  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   130  			RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   131  			DropPacket: func(quicproxy.Direction, []byte) bool {
   132  				return drop.Load()
   133  			},
   134  		})
   135  		Expect(err).ToNot(HaveOccurred())
   136  		defer proxy.Close()
   137  
   138  		conn, err := quic.DialAddr(
   139  			context.Background(),
   140  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   141  			getTLSClientConfig(),
   142  			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
   143  		)
   144  		Expect(err).ToNot(HaveOccurred())
   145  		strIn, err := conn.AcceptStream(context.Background())
   146  		Expect(err).ToNot(HaveOccurred())
   147  		strOut, err := conn.OpenStream()
   148  		Expect(err).ToNot(HaveOccurred())
   149  		_, err = strIn.Read(make([]byte, 6))
   150  		Expect(err).ToNot(HaveOccurred())
   151  
   152  		drop.Store(true)
   153  		time.Sleep(2 * idleTimeout)
   154  		_, err = strIn.Write([]byte("test"))
   155  		checkTimeoutError(err)
   156  		_, err = strIn.Read([]byte{0})
   157  		checkTimeoutError(err)
   158  		_, err = strOut.Write([]byte("test"))
   159  		checkTimeoutError(err)
   160  		_, err = strOut.Read([]byte{0})
   161  		checkTimeoutError(err)
   162  		_, err = conn.OpenStream()
   163  		checkTimeoutError(err)
   164  		_, err = conn.OpenUniStream()
   165  		checkTimeoutError(err)
   166  		_, err = conn.AcceptStream(context.Background())
   167  		checkTimeoutError(err)
   168  		_, err = conn.AcceptUniStream(context.Background())
   169  		checkTimeoutError(err)
   170  	})
   171  
   172  	Context("timing out at the right time", func() {
   173  		var idleTimeout time.Duration
   174  
   175  		BeforeEach(func() {
   176  			idleTimeout = scaleDuration(500 * time.Millisecond)
   177  		})
   178  
   179  		It("times out after inactivity", func() {
   180  			server, err := quic.ListenAddr(
   181  				"localhost:0",
   182  				getTLSConfig(),
   183  				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   184  			)
   185  			Expect(err).ToNot(HaveOccurred())
   186  			defer server.Close()
   187  
   188  			serverConnChan := make(chan quic.Connection, 1)
   189  			serverConnClosed := make(chan struct{})
   190  			go func() {
   191  				defer GinkgoRecover()
   192  				conn, err := server.Accept(context.Background())
   193  				Expect(err).ToNot(HaveOccurred())
   194  				serverConnChan <- conn
   195  				conn.AcceptStream(context.Background()) // blocks until the connection is closed
   196  				close(serverConnClosed)
   197  			}()
   198  
   199  			counter, tr := newPacketTracer()
   200  			conn, err := quic.DialAddr(
   201  				context.Background(),
   202  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   203  				getTLSClientConfig(),
   204  				getQuicConfig(&quic.Config{
   205  					MaxIdleTimeout:          idleTimeout,
   206  					Tracer:                  newTracer(tr),
   207  					DisablePathMTUDiscovery: true,
   208  				}),
   209  			)
   210  			Expect(err).ToNot(HaveOccurred())
   211  			done := make(chan struct{})
   212  			go func() {
   213  				defer GinkgoRecover()
   214  				_, err := conn.AcceptStream(context.Background())
   215  				checkTimeoutError(err)
   216  				close(done)
   217  			}()
   218  			Eventually(done, 2*idleTimeout).Should(BeClosed())
   219  			var lastAckElicitingPacketSentAt time.Time
   220  			for _, p := range counter.getSentShortHeaderPackets() {
   221  				var hasAckElicitingFrame bool
   222  				for _, f := range p.frames {
   223  					if _, ok := f.(*logging.AckFrame); ok {
   224  						continue
   225  					}
   226  					hasAckElicitingFrame = true
   227  					break
   228  				}
   229  				if hasAckElicitingFrame {
   230  					lastAckElicitingPacketSentAt = p.time
   231  				}
   232  			}
   233  			rcvdPackets := counter.getRcvdShortHeaderPackets()
   234  			lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
   235  			// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
   236  			// This is ok since we're dealing with a lossless connection here,
   237  			// and we'd expect to receive an ACK for additional other ack-eliciting packet sent.
   238  			Expect(time.Since(utils.MaxTime(lastAckElicitingPacketSentAt, lastPacketRcvdAt))).To(And(
   239  				BeNumerically(">=", idleTimeout),
   240  				BeNumerically("<", idleTimeout*6/5),
   241  			))
   242  			Consistently(serverConnClosed).ShouldNot(BeClosed())
   243  
   244  			// make the go routine return
   245  			(<-serverConnChan).CloseWithError(0, "")
   246  			Eventually(serverConnClosed).Should(BeClosed())
   247  		})
   248  
   249  		It("times out after sending a packet", func() {
   250  			server, err := quic.ListenAddr(
   251  				"localhost:0",
   252  				getTLSConfig(),
   253  				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   254  			)
   255  			Expect(err).ToNot(HaveOccurred())
   256  			defer server.Close()
   257  
   258  			var drop atomic.Bool
   259  			proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   260  				RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   261  				DropPacket: func(dir quicproxy.Direction, _ []byte) bool {
   262  					if dir == quicproxy.DirectionOutgoing {
   263  						return drop.Load()
   264  					}
   265  					return false
   266  				},
   267  			})
   268  			Expect(err).ToNot(HaveOccurred())
   269  			defer proxy.Close()
   270  
   271  			serverConnChan := make(chan quic.Connection, 1)
   272  			serverConnClosed := make(chan struct{})
   273  			go func() {
   274  				defer GinkgoRecover()
   275  				conn, err := server.Accept(context.Background())
   276  				Expect(err).ToNot(HaveOccurred())
   277  				serverConnChan <- conn
   278  				<-conn.Context().Done() // block until the connection is closed
   279  				close(serverConnClosed)
   280  			}()
   281  
   282  			conn, err := quic.DialAddr(
   283  				context.Background(),
   284  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   285  				getTLSClientConfig(),
   286  				getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
   287  			)
   288  			Expect(err).ToNot(HaveOccurred())
   289  
   290  			// wait half the idle timeout, then send a packet
   291  			time.Sleep(idleTimeout / 2)
   292  			drop.Store(true)
   293  			str, err := conn.OpenUniStream()
   294  			Expect(err).ToNot(HaveOccurred())
   295  			_, err = str.Write([]byte("foobar"))
   296  			Expect(err).ToNot(HaveOccurred())
   297  
   298  			// now make sure that the idle timeout is based on this packet
   299  			startTime := time.Now()
   300  			done := make(chan struct{})
   301  			go func() {
   302  				defer GinkgoRecover()
   303  				_, err := conn.AcceptStream(context.Background())
   304  				checkTimeoutError(err)
   305  				close(done)
   306  			}()
   307  			Eventually(done, 2*idleTimeout).Should(BeClosed())
   308  			dur := time.Since(startTime)
   309  			Expect(dur).To(And(
   310  				BeNumerically(">=", idleTimeout),
   311  				BeNumerically("<", idleTimeout*12/10),
   312  			))
   313  			Consistently(serverConnClosed).ShouldNot(BeClosed())
   314  
   315  			// make the go routine return
   316  			(<-serverConnChan).CloseWithError(0, "")
   317  			Eventually(serverConnClosed).Should(BeClosed())
   318  		})
   319  	})
   320  
   321  	It("does not time out if keepalive is set", func() {
   322  		const idleTimeout = 500 * time.Millisecond
   323  
   324  		server, err := quic.ListenAddr(
   325  			"localhost:0",
   326  			getTLSConfig(),
   327  			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   328  		)
   329  		Expect(err).ToNot(HaveOccurred())
   330  		defer server.Close()
   331  
   332  		serverConnChan := make(chan quic.Connection, 1)
   333  		serverConnClosed := make(chan struct{})
   334  		go func() {
   335  			defer GinkgoRecover()
   336  			conn, err := server.Accept(context.Background())
   337  			Expect(err).ToNot(HaveOccurred())
   338  			serverConnChan <- conn
   339  			conn.AcceptStream(context.Background()) // blocks until the connection is closed
   340  			close(serverConnClosed)
   341  		}()
   342  
   343  		var drop atomic.Bool
   344  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   345  			RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   346  			DropPacket: func(quicproxy.Direction, []byte) bool {
   347  				return drop.Load()
   348  			},
   349  		})
   350  		Expect(err).ToNot(HaveOccurred())
   351  		defer proxy.Close()
   352  
   353  		conn, err := quic.DialAddr(
   354  			context.Background(),
   355  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   356  			getTLSClientConfig(),
   357  			getQuicConfig(&quic.Config{
   358  				MaxIdleTimeout:          idleTimeout,
   359  				KeepAlivePeriod:         idleTimeout / 2,
   360  				DisablePathMTUDiscovery: true,
   361  			}),
   362  		)
   363  		Expect(err).ToNot(HaveOccurred())
   364  
   365  		// wait longer than the idle timeout
   366  		time.Sleep(3 * idleTimeout)
   367  		str, err := conn.OpenUniStream()
   368  		Expect(err).ToNot(HaveOccurred())
   369  		_, err = str.Write([]byte("foobar"))
   370  		Expect(err).ToNot(HaveOccurred())
   371  		Consistently(serverConnClosed).ShouldNot(BeClosed())
   372  
   373  		// idle timeout will still kick in if pings are dropped
   374  		drop.Store(true)
   375  		time.Sleep(2 * idleTimeout)
   376  		_, err = str.Write([]byte("foobar"))
   377  		checkTimeoutError(err)
   378  
   379  		(<-serverConnChan).CloseWithError(0, "")
   380  		Eventually(serverConnClosed).Should(BeClosed())
   381  	})
   382  
   383  	Context("faulty packet conns", func() {
   384  		const handshakeTimeout = time.Second / 2
   385  
   386  		runServer := func(ln *quic.Listener) error {
   387  			conn, err := ln.Accept(context.Background())
   388  			if err != nil {
   389  				return err
   390  			}
   391  			str, err := conn.OpenUniStream()
   392  			if err != nil {
   393  				return err
   394  			}
   395  			defer str.Close()
   396  			_, err = str.Write(PRData)
   397  			return err
   398  		}
   399  
   400  		runClient := func(conn quic.Connection) error {
   401  			str, err := conn.AcceptUniStream(context.Background())
   402  			if err != nil {
   403  				return err
   404  			}
   405  			data, err := io.ReadAll(str)
   406  			if err != nil {
   407  				return err
   408  			}
   409  			Expect(data).To(Equal(PRData))
   410  			return conn.CloseWithError(0, "done")
   411  		}
   412  
   413  		It("deals with an erroring packet conn, on the server side", func() {
   414  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   415  			Expect(err).ToNot(HaveOccurred())
   416  			conn, err := net.ListenUDP("udp", addr)
   417  			Expect(err).ToNot(HaveOccurred())
   418  			maxPackets := mrand.Int31n(25)
   419  			fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
   420  			ln, err := quic.Listen(
   421  				&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
   422  				getTLSConfig(),
   423  				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   424  			)
   425  			Expect(err).ToNot(HaveOccurred())
   426  
   427  			serverErrChan := make(chan error, 1)
   428  			go func() {
   429  				defer GinkgoRecover()
   430  				serverErrChan <- runServer(ln)
   431  			}()
   432  
   433  			clientErrChan := make(chan error, 1)
   434  			go func() {
   435  				defer GinkgoRecover()
   436  				conn, err := quic.DialAddr(
   437  					context.Background(),
   438  					fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   439  					getTLSClientConfig(),
   440  					getQuicConfig(&quic.Config{
   441  						HandshakeIdleTimeout:    handshakeTimeout,
   442  						MaxIdleTimeout:          handshakeTimeout,
   443  						DisablePathMTUDiscovery: true,
   444  					}),
   445  				)
   446  				if err != nil {
   447  					clientErrChan <- err
   448  					return
   449  				}
   450  				clientErrChan <- runClient(conn)
   451  			}()
   452  
   453  			var clientErr error
   454  			Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
   455  			Expect(clientErr).To(HaveOccurred())
   456  			nErr, ok := clientErr.(net.Error)
   457  			Expect(ok).To(BeTrue())
   458  			Expect(nErr.Timeout()).To(BeTrue())
   459  
   460  			select {
   461  			case serverErr := <-serverErrChan:
   462  				Expect(serverErr).To(HaveOccurred())
   463  				Expect(serverErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
   464  				defer ln.Close()
   465  			default:
   466  				Expect(ln.Close()).To(Succeed())
   467  				Eventually(serverErrChan).Should(Receive())
   468  			}
   469  		})
   470  
   471  		It("deals with an erroring packet conn, on the client side", func() {
   472  			ln, err := quic.ListenAddr(
   473  				"localhost:0",
   474  				getTLSConfig(),
   475  				getQuicConfig(&quic.Config{
   476  					HandshakeIdleTimeout:    handshakeTimeout,
   477  					MaxIdleTimeout:          handshakeTimeout,
   478  					KeepAlivePeriod:         handshakeTimeout / 2,
   479  					DisablePathMTUDiscovery: true,
   480  				}),
   481  			)
   482  			Expect(err).ToNot(HaveOccurred())
   483  			defer ln.Close()
   484  
   485  			serverErrChan := make(chan error, 1)
   486  			go func() {
   487  				defer GinkgoRecover()
   488  				serverErrChan <- runServer(ln)
   489  			}()
   490  
   491  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   492  			Expect(err).ToNot(HaveOccurred())
   493  			conn, err := net.ListenUDP("udp", addr)
   494  			Expect(err).ToNot(HaveOccurred())
   495  			maxPackets := mrand.Int31n(25)
   496  			fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
   497  			clientErrChan := make(chan error, 1)
   498  			go func() {
   499  				defer GinkgoRecover()
   500  				conn, err := quic.Dial(
   501  					context.Background(),
   502  					&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
   503  					ln.Addr(),
   504  					getTLSClientConfig(),
   505  					getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   506  				)
   507  				if err != nil {
   508  					clientErrChan <- err
   509  					return
   510  				}
   511  				clientErrChan <- runClient(conn)
   512  			}()
   513  
   514  			var clientErr error
   515  			Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
   516  			Expect(clientErr).To(HaveOccurred())
   517  			Expect(clientErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
   518  			Eventually(areHandshakesRunning, 5*handshakeTimeout).Should(BeFalse())
   519  			select {
   520  			case serverErr := <-serverErrChan: // The handshake completed on the server side.
   521  				Expect(serverErr).To(HaveOccurred())
   522  				nErr, ok := serverErr.(net.Error)
   523  				Expect(ok).To(BeTrue())
   524  				Expect(nErr.Timeout()).To(BeTrue())
   525  			default: // The handshake didn't complete
   526  				Expect(ln.Close()).To(Succeed())
   527  				Eventually(serverErrChan).Should(Receive())
   528  			}
   529  		})
   530  	})
   531  })