github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/timeout_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	mrand "math/rand"
     8  	"net"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/MerlinKodo/quic-go"
    13  	quicproxy "github.com/MerlinKodo/quic-go/integrationtests/tools/proxy"
    14  	"github.com/MerlinKodo/quic-go/internal/utils"
    15  	"github.com/MerlinKodo/quic-go/logging"
    16  
    17  	. "github.com/onsi/ginkgo/v2"
    18  	. "github.com/onsi/gomega"
    19  )
    20  
    21  type faultyConn struct {
    22  	net.PacketConn
    23  
    24  	MaxPackets int32
    25  	counter    int32
    26  }
    27  
    28  func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) {
    29  	n, addr, err := c.PacketConn.ReadFrom(p)
    30  	counter := atomic.AddInt32(&c.counter, 1)
    31  	if counter <= c.MaxPackets {
    32  		return n, addr, err
    33  	}
    34  	return 0, nil, io.ErrClosedPipe
    35  }
    36  
    37  func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) {
    38  	counter := atomic.AddInt32(&c.counter, 1)
    39  	if counter <= c.MaxPackets {
    40  		return c.PacketConn.WriteTo(p, addr)
    41  	}
    42  	return 0, io.ErrClosedPipe
    43  }
    44  
    45  var _ = Describe("Timeout tests", func() {
    46  	checkTimeoutError := func(err error) {
    47  		ExpectWithOffset(1, err).To(MatchError(&quic.IdleTimeoutError{}))
    48  		nerr, ok := err.(net.Error)
    49  		ExpectWithOffset(1, ok).To(BeTrue())
    50  		ExpectWithOffset(1, nerr.Timeout()).To(BeTrue())
    51  	}
    52  
    53  	It("returns net.Error timeout errors when dialing", func() {
    54  		errChan := make(chan error)
    55  		go func() {
    56  			_, err := quic.DialAddr(
    57  				context.Background(),
    58  				"localhost:12345",
    59  				getTLSClientConfig(),
    60  				getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
    61  			)
    62  			errChan <- err
    63  		}()
    64  		var err error
    65  		Eventually(errChan).Should(Receive(&err))
    66  		checkTimeoutError(err)
    67  	})
    68  
    69  	It("returns the context error when the context expires", func() {
    70  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    71  		defer cancel()
    72  		errChan := make(chan error)
    73  		go func() {
    74  			_, err := quic.DialAddr(
    75  				ctx,
    76  				"localhost:12345",
    77  				getTLSClientConfig(),
    78  				getQuicConfig(nil),
    79  			)
    80  			errChan <- err
    81  		}()
    82  		var err error
    83  		Eventually(errChan).Should(Receive(&err))
    84  		// This is not a net.Error timeout error
    85  		Expect(err).To(MatchError(context.DeadlineExceeded))
    86  	})
    87  
    88  	It("returns the context error when the context expires with 0RTT enabled", func() {
    89  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    90  		defer cancel()
    91  		errChan := make(chan error)
    92  		go func() {
    93  			_, err := quic.DialAddrEarly(
    94  				ctx,
    95  				"localhost:12345",
    96  				getTLSClientConfig(),
    97  				getQuicConfig(nil),
    98  			)
    99  			errChan <- err
   100  		}()
   101  		var err error
   102  		Eventually(errChan).Should(Receive(&err))
   103  		// This is not a net.Error timeout error
   104  		Expect(err).To(MatchError(context.DeadlineExceeded))
   105  	})
   106  
   107  	It("returns net.Error timeout errors when an idle timeout occurs", func() {
   108  		const idleTimeout = 500 * time.Millisecond
   109  
   110  		server, err := quic.ListenAddr(
   111  			"localhost:0",
   112  			getTLSConfig(),
   113  			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   114  		)
   115  		Expect(err).ToNot(HaveOccurred())
   116  		defer server.Close()
   117  
   118  		go func() {
   119  			defer GinkgoRecover()
   120  			conn, err := server.Accept(context.Background())
   121  			Expect(err).ToNot(HaveOccurred())
   122  			str, err := conn.OpenStream()
   123  			Expect(err).ToNot(HaveOccurred())
   124  			_, err = str.Write([]byte("foobar"))
   125  			Expect(err).ToNot(HaveOccurred())
   126  		}()
   127  
   128  		var drop atomic.Bool
   129  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   130  			RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   131  			DropPacket: func(quicproxy.Direction, []byte) bool {
   132  				return drop.Load()
   133  			},
   134  		})
   135  		Expect(err).ToNot(HaveOccurred())
   136  		defer proxy.Close()
   137  
   138  		conn, err := quic.DialAddr(
   139  			context.Background(),
   140  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   141  			getTLSClientConfig(),
   142  			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
   143  		)
   144  		Expect(err).ToNot(HaveOccurred())
   145  		strIn, err := conn.AcceptStream(context.Background())
   146  		Expect(err).ToNot(HaveOccurred())
   147  		strOut, err := conn.OpenStream()
   148  		Expect(err).ToNot(HaveOccurred())
   149  		_, err = strIn.Read(make([]byte, 6))
   150  		Expect(err).ToNot(HaveOccurred())
   151  
   152  		drop.Store(true)
   153  		time.Sleep(2 * idleTimeout)
   154  		_, err = strIn.Write([]byte("test"))
   155  		checkTimeoutError(err)
   156  		_, err = strIn.Read([]byte{0})
   157  		checkTimeoutError(err)
   158  		_, err = strOut.Write([]byte("test"))
   159  		checkTimeoutError(err)
   160  		_, err = strOut.Read([]byte{0})
   161  		checkTimeoutError(err)
   162  		_, err = conn.OpenStream()
   163  		checkTimeoutError(err)
   164  		_, err = conn.OpenUniStream()
   165  		checkTimeoutError(err)
   166  		_, err = conn.AcceptStream(context.Background())
   167  		checkTimeoutError(err)
   168  		_, err = conn.AcceptUniStream(context.Background())
   169  		checkTimeoutError(err)
   170  	})
   171  
   172  	Context("timing out at the right time", func() {
   173  		var idleTimeout time.Duration
   174  
   175  		BeforeEach(func() {
   176  			idleTimeout = scaleDuration(500 * time.Millisecond)
   177  		})
   178  
   179  		It("times out after inactivity", func() {
   180  			server, err := quic.ListenAddr(
   181  				"localhost:0",
   182  				getTLSConfig(),
   183  				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   184  			)
   185  			Expect(err).ToNot(HaveOccurred())
   186  			defer server.Close()
   187  
   188  			serverConnClosed := make(chan struct{})
   189  			go func() {
   190  				defer GinkgoRecover()
   191  				conn, err := server.Accept(context.Background())
   192  				Expect(err).ToNot(HaveOccurred())
   193  				conn.AcceptStream(context.Background()) // blocks until the connection is closed
   194  				close(serverConnClosed)
   195  			}()
   196  
   197  			counter, tr := newPacketTracer()
   198  			conn, err := quic.DialAddr(
   199  				context.Background(),
   200  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   201  				getTLSClientConfig(),
   202  				getQuicConfig(&quic.Config{
   203  					MaxIdleTimeout:          idleTimeout,
   204  					Tracer:                  newTracer(tr),
   205  					DisablePathMTUDiscovery: true,
   206  				}),
   207  			)
   208  			Expect(err).ToNot(HaveOccurred())
   209  			done := make(chan struct{})
   210  			go func() {
   211  				defer GinkgoRecover()
   212  				_, err := conn.AcceptStream(context.Background())
   213  				checkTimeoutError(err)
   214  				close(done)
   215  			}()
   216  			Eventually(done, 2*idleTimeout).Should(BeClosed())
   217  			var lastAckElicitingPacketSentAt time.Time
   218  			for _, p := range counter.getSentShortHeaderPackets() {
   219  				var hasAckElicitingFrame bool
   220  				for _, f := range p.frames {
   221  					if _, ok := f.(*logging.AckFrame); ok {
   222  						continue
   223  					}
   224  					hasAckElicitingFrame = true
   225  					break
   226  				}
   227  				if hasAckElicitingFrame {
   228  					lastAckElicitingPacketSentAt = p.time
   229  				}
   230  			}
   231  			rcvdPackets := counter.getRcvdShortHeaderPackets()
   232  			lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
   233  			// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
   234  			// This is ok since we're dealing with a lossless connection here,
   235  			// and we'd expect to receive an ACK for additional other ack-eliciting packet sent.
   236  			Expect(time.Since(utils.MaxTime(lastAckElicitingPacketSentAt, lastPacketRcvdAt))).To(And(
   237  				BeNumerically(">=", idleTimeout),
   238  				BeNumerically("<", idleTimeout*6/5),
   239  			))
   240  			Consistently(serverConnClosed).ShouldNot(BeClosed())
   241  
   242  			// make the go routine return
   243  			Expect(server.Close()).To(Succeed())
   244  			Eventually(serverConnClosed).Should(BeClosed())
   245  		})
   246  
   247  		It("times out after sending a packet", func() {
   248  			server, err := quic.ListenAddr(
   249  				"localhost:0",
   250  				getTLSConfig(),
   251  				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   252  			)
   253  			Expect(err).ToNot(HaveOccurred())
   254  			defer server.Close()
   255  
   256  			var drop atomic.Bool
   257  			proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   258  				RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   259  				DropPacket: func(dir quicproxy.Direction, _ []byte) bool {
   260  					if dir == quicproxy.DirectionOutgoing {
   261  						return drop.Load()
   262  					}
   263  					return false
   264  				},
   265  			})
   266  			Expect(err).ToNot(HaveOccurred())
   267  			defer proxy.Close()
   268  
   269  			serverConnClosed := make(chan struct{})
   270  			go func() {
   271  				defer GinkgoRecover()
   272  				conn, err := server.Accept(context.Background())
   273  				Expect(err).ToNot(HaveOccurred())
   274  				<-conn.Context().Done() // block until the connection is closed
   275  				close(serverConnClosed)
   276  			}()
   277  
   278  			conn, err := quic.DialAddr(
   279  				context.Background(),
   280  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   281  				getTLSClientConfig(),
   282  				getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
   283  			)
   284  			Expect(err).ToNot(HaveOccurred())
   285  
   286  			// wait half the idle timeout, then send a packet
   287  			time.Sleep(idleTimeout / 2)
   288  			drop.Store(true)
   289  			str, err := conn.OpenUniStream()
   290  			Expect(err).ToNot(HaveOccurred())
   291  			_, err = str.Write([]byte("foobar"))
   292  			Expect(err).ToNot(HaveOccurred())
   293  
   294  			// now make sure that the idle timeout is based on this packet
   295  			startTime := time.Now()
   296  			done := make(chan struct{})
   297  			go func() {
   298  				defer GinkgoRecover()
   299  				_, err := conn.AcceptStream(context.Background())
   300  				checkTimeoutError(err)
   301  				close(done)
   302  			}()
   303  			Eventually(done, 2*idleTimeout).Should(BeClosed())
   304  			dur := time.Since(startTime)
   305  			Expect(dur).To(And(
   306  				BeNumerically(">=", idleTimeout),
   307  				BeNumerically("<", idleTimeout*12/10),
   308  			))
   309  			Consistently(serverConnClosed).ShouldNot(BeClosed())
   310  
   311  			// make the go routine return
   312  			Expect(server.Close()).To(Succeed())
   313  			Eventually(serverConnClosed).Should(BeClosed())
   314  		})
   315  	})
   316  
   317  	It("does not time out if keepalive is set", func() {
   318  		const idleTimeout = 500 * time.Millisecond
   319  
   320  		server, err := quic.ListenAddr(
   321  			"localhost:0",
   322  			getTLSConfig(),
   323  			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   324  		)
   325  		Expect(err).ToNot(HaveOccurred())
   326  		defer server.Close()
   327  
   328  		serverConnClosed := make(chan struct{})
   329  		go func() {
   330  			defer GinkgoRecover()
   331  			conn, err := server.Accept(context.Background())
   332  			Expect(err).ToNot(HaveOccurred())
   333  			conn.AcceptStream(context.Background()) // blocks until the connection is closed
   334  			close(serverConnClosed)
   335  		}()
   336  
   337  		var drop atomic.Bool
   338  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   339  			RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   340  			DropPacket: func(quicproxy.Direction, []byte) bool {
   341  				return drop.Load()
   342  			},
   343  		})
   344  		Expect(err).ToNot(HaveOccurred())
   345  		defer proxy.Close()
   346  
   347  		conn, err := quic.DialAddr(
   348  			context.Background(),
   349  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   350  			getTLSClientConfig(),
   351  			getQuicConfig(&quic.Config{
   352  				MaxIdleTimeout:          idleTimeout,
   353  				KeepAlivePeriod:         idleTimeout / 2,
   354  				DisablePathMTUDiscovery: true,
   355  			}),
   356  		)
   357  		Expect(err).ToNot(HaveOccurred())
   358  
   359  		// wait longer than the idle timeout
   360  		time.Sleep(3 * idleTimeout)
   361  		str, err := conn.OpenUniStream()
   362  		Expect(err).ToNot(HaveOccurred())
   363  		_, err = str.Write([]byte("foobar"))
   364  		Expect(err).ToNot(HaveOccurred())
   365  		Consistently(serverConnClosed).ShouldNot(BeClosed())
   366  
   367  		// idle timeout will still kick in if pings are dropped
   368  		drop.Store(true)
   369  		time.Sleep(2 * idleTimeout)
   370  		_, err = str.Write([]byte("foobar"))
   371  		checkTimeoutError(err)
   372  
   373  		Expect(server.Close()).To(Succeed())
   374  		Eventually(serverConnClosed).Should(BeClosed())
   375  	})
   376  
   377  	Context("faulty packet conns", func() {
   378  		const handshakeTimeout = time.Second / 2
   379  
   380  		runServer := func(ln *quic.Listener) error {
   381  			conn, err := ln.Accept(context.Background())
   382  			if err != nil {
   383  				return err
   384  			}
   385  			str, err := conn.OpenUniStream()
   386  			if err != nil {
   387  				return err
   388  			}
   389  			defer str.Close()
   390  			_, err = str.Write(PRData)
   391  			return err
   392  		}
   393  
   394  		runClient := func(conn quic.Connection) error {
   395  			str, err := conn.AcceptUniStream(context.Background())
   396  			if err != nil {
   397  				return err
   398  			}
   399  			data, err := io.ReadAll(str)
   400  			if err != nil {
   401  				return err
   402  			}
   403  			Expect(data).To(Equal(PRData))
   404  			return conn.CloseWithError(0, "done")
   405  		}
   406  
   407  		It("deals with an erroring packet conn, on the server side", func() {
   408  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   409  			Expect(err).ToNot(HaveOccurred())
   410  			conn, err := net.ListenUDP("udp", addr)
   411  			Expect(err).ToNot(HaveOccurred())
   412  			maxPackets := mrand.Int31n(25)
   413  			fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
   414  			ln, err := quic.Listen(
   415  				&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
   416  				getTLSConfig(),
   417  				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   418  			)
   419  			Expect(err).ToNot(HaveOccurred())
   420  
   421  			serverErrChan := make(chan error, 1)
   422  			go func() {
   423  				defer GinkgoRecover()
   424  				serverErrChan <- runServer(ln)
   425  			}()
   426  
   427  			clientErrChan := make(chan error, 1)
   428  			go func() {
   429  				defer GinkgoRecover()
   430  				conn, err := quic.DialAddr(
   431  					context.Background(),
   432  					fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   433  					getTLSClientConfig(),
   434  					getQuicConfig(&quic.Config{
   435  						HandshakeIdleTimeout:    handshakeTimeout,
   436  						MaxIdleTimeout:          handshakeTimeout,
   437  						DisablePathMTUDiscovery: true,
   438  					}),
   439  				)
   440  				if err != nil {
   441  					clientErrChan <- err
   442  					return
   443  				}
   444  				clientErrChan <- runClient(conn)
   445  			}()
   446  
   447  			var clientErr error
   448  			Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
   449  			Expect(clientErr).To(HaveOccurred())
   450  			nErr, ok := clientErr.(net.Error)
   451  			Expect(ok).To(BeTrue())
   452  			Expect(nErr.Timeout()).To(BeTrue())
   453  
   454  			select {
   455  			case serverErr := <-serverErrChan:
   456  				Expect(serverErr).To(HaveOccurred())
   457  				Expect(serverErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
   458  				defer ln.Close()
   459  			default:
   460  				Expect(ln.Close()).To(Succeed())
   461  				Eventually(serverErrChan).Should(Receive())
   462  			}
   463  		})
   464  
   465  		It("deals with an erroring packet conn, on the client side", func() {
   466  			ln, err := quic.ListenAddr(
   467  				"localhost:0",
   468  				getTLSConfig(),
   469  				getQuicConfig(&quic.Config{
   470  					HandshakeIdleTimeout:    handshakeTimeout,
   471  					MaxIdleTimeout:          handshakeTimeout,
   472  					KeepAlivePeriod:         handshakeTimeout / 2,
   473  					DisablePathMTUDiscovery: true,
   474  				}),
   475  			)
   476  			Expect(err).ToNot(HaveOccurred())
   477  			defer ln.Close()
   478  
   479  			serverErrChan := make(chan error, 1)
   480  			go func() {
   481  				defer GinkgoRecover()
   482  				serverErrChan <- runServer(ln)
   483  			}()
   484  
   485  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   486  			Expect(err).ToNot(HaveOccurred())
   487  			conn, err := net.ListenUDP("udp", addr)
   488  			Expect(err).ToNot(HaveOccurred())
   489  			maxPackets := mrand.Int31n(25)
   490  			fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
   491  			clientErrChan := make(chan error, 1)
   492  			go func() {
   493  				defer GinkgoRecover()
   494  				conn, err := quic.Dial(
   495  					context.Background(),
   496  					&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
   497  					ln.Addr(),
   498  					getTLSClientConfig(),
   499  					getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
   500  				)
   501  				if err != nil {
   502  					clientErrChan <- err
   503  					return
   504  				}
   505  				clientErrChan <- runClient(conn)
   506  			}()
   507  
   508  			var clientErr error
   509  			Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
   510  			Expect(clientErr).To(HaveOccurred())
   511  			Expect(clientErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
   512  			Eventually(areHandshakesRunning, 5*handshakeTimeout).Should(BeFalse())
   513  			select {
   514  			case serverErr := <-serverErrChan: // The handshake completed on the server side.
   515  				Expect(serverErr).To(HaveOccurred())
   516  				nErr, ok := serverErr.(net.Error)
   517  				Expect(ok).To(BeTrue())
   518  				Expect(nErr.Timeout()).To(BeTrue())
   519  			default: // The handshake didn't complete
   520  				Expect(ln.Close()).To(Succeed())
   521  				Eventually(serverErrChan).Should(Receive())
   522  			}
   523  		})
   524  	})
   525  })