github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/tools/proxy/proxy_test.go (about)

     1  package quicproxy
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net"
     7  	"runtime"
     8  	"runtime/pprof"
     9  	"strconv"
    10  	"strings"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/apernet/quic-go/internal/protocol"
    15  	"github.com/apernet/quic-go/internal/wire"
    16  
    17  	. "github.com/onsi/ginkgo/v2"
    18  	. "github.com/onsi/gomega"
    19  )
    20  
    21  type packetData []byte
    22  
    23  func isProxyRunning() bool {
    24  	var b bytes.Buffer
    25  	pprof.Lookup("goroutine").WriteTo(&b, 1)
    26  	return strings.Contains(b.String(), "proxy.(*QuicProxy).runIncomingConnection") ||
    27  		strings.Contains(b.String(), "proxy.(*QuicProxy).runOutgoingConnection")
    28  }
    29  
    30  var _ = Describe("QUIC Proxy", func() {
    31  	makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
    32  		hdr := wire.ExtendedHeader{
    33  			Header: wire.Header{
    34  				Type:             protocol.PacketTypeInitial,
    35  				Version:          protocol.Version1,
    36  				Length:           4 + protocol.ByteCount(len(payload)),
    37  				DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}),
    38  				SrcConnectionID:  protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}),
    39  			},
    40  			PacketNumber:    p,
    41  			PacketNumberLen: protocol.PacketNumberLen4,
    42  		}
    43  		b, err := hdr.Append(nil, protocol.Version1)
    44  		Expect(err).ToNot(HaveOccurred())
    45  		b = append(b, payload...)
    46  		return b
    47  	}
    48  
    49  	readPacketNumber := func(b []byte) protocol.PacketNumber {
    50  		hdr, data, _, err := wire.ParsePacket(b)
    51  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    52  		Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
    53  		extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1)
    54  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    55  		return extHdr.PacketNumber
    56  	}
    57  
    58  	AfterEach(func() {
    59  		Eventually(isProxyRunning).Should(BeFalse())
    60  	})
    61  
    62  	Context("Proxy setup and teardown", func() {
    63  		It("sets up the UDPProxy", func() {
    64  			proxy, err := NewQuicProxy("localhost:0", nil)
    65  			Expect(err).ToNot(HaveOccurred())
    66  			Expect(proxy.clientDict).To(HaveLen(0))
    67  
    68  			// check that the proxy port is in use
    69  			addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort()))
    70  			Expect(err).ToNot(HaveOccurred())
    71  			_, err = net.ListenUDP("udp", addr)
    72  			if runtime.GOOS == "windows" {
    73  				Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: Only one usage of each socket address (protocol/network address/port) is normally permitted.", proxy.LocalPort())))
    74  			} else {
    75  				Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort())))
    76  			}
    77  			Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test
    78  		})
    79  
    80  		It("stops the UDPProxy", func() {
    81  			isProxyRunning := func() bool {
    82  				var b bytes.Buffer
    83  				pprof.Lookup("goroutine").WriteTo(&b, 1)
    84  				return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
    85  			}
    86  
    87  			proxy, err := NewQuicProxy("localhost:0", nil)
    88  			Expect(err).ToNot(HaveOccurred())
    89  			port := proxy.LocalPort()
    90  			Eventually(isProxyRunning).Should(BeTrue())
    91  			err = proxy.Close()
    92  			Expect(err).ToNot(HaveOccurred())
    93  
    94  			// check that the proxy port is not in use anymore
    95  			addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
    96  			Expect(err).ToNot(HaveOccurred())
    97  			// sometimes it takes a while for the OS to free the port
    98  			Eventually(func() error {
    99  				ln, err := net.ListenUDP("udp", addr)
   100  				if err != nil {
   101  					return err
   102  				}
   103  				ln.Close()
   104  				return nil
   105  			}).ShouldNot(HaveOccurred())
   106  			Eventually(isProxyRunning).Should(BeFalse())
   107  		})
   108  
   109  		It("stops listening for proxied connections", func() {
   110  			serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
   111  			Expect(err).ToNot(HaveOccurred())
   112  			serverConn, err := net.ListenUDP("udp", serverAddr)
   113  			Expect(err).ToNot(HaveOccurred())
   114  			defer serverConn.Close()
   115  
   116  			proxy, err := NewQuicProxy("localhost:0", &Opts{RemoteAddr: serverConn.LocalAddr().String()})
   117  			Expect(err).ToNot(HaveOccurred())
   118  			Expect(isProxyRunning()).To(BeFalse())
   119  
   120  			// check that the proxy port is not in use anymore
   121  			conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
   122  			Expect(err).ToNot(HaveOccurred())
   123  			_, err = conn.Write(makePacket(1, []byte("foobar")))
   124  			Expect(err).ToNot(HaveOccurred())
   125  			Eventually(isProxyRunning).Should(BeTrue())
   126  			Expect(proxy.Close()).To(Succeed())
   127  			Eventually(isProxyRunning).Should(BeFalse())
   128  		})
   129  
   130  		It("has the correct LocalAddr and LocalPort", func() {
   131  			proxy, err := NewQuicProxy("localhost:0", nil)
   132  			Expect(err).ToNot(HaveOccurred())
   133  
   134  			Expect(proxy.LocalAddr().String()).To(Equal("127.0.0.1:" + strconv.Itoa(proxy.LocalPort())))
   135  			Expect(proxy.LocalPort()).ToNot(BeZero())
   136  
   137  			Expect(proxy.Close()).To(Succeed())
   138  		})
   139  	})
   140  
   141  	Context("Proxy tests", func() {
   142  		var (
   143  			serverConn            *net.UDPConn
   144  			serverNumPacketsSent  atomic.Int32
   145  			serverReceivedPackets chan packetData
   146  			clientConn            *net.UDPConn
   147  			proxy                 *QuicProxy
   148  			stoppedReading        chan struct{}
   149  		)
   150  
   151  		startProxy := func(opts *Opts) {
   152  			var err error
   153  			proxy, err = NewQuicProxy("localhost:0", opts)
   154  			Expect(err).ToNot(HaveOccurred())
   155  			clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
   156  			Expect(err).ToNot(HaveOccurred())
   157  		}
   158  
   159  		BeforeEach(func() {
   160  			stoppedReading = make(chan struct{})
   161  			serverReceivedPackets = make(chan packetData, 100)
   162  			serverNumPacketsSent.Store(0)
   163  
   164  			// set up a dump UDP server
   165  			// in production this would be a QUIC server
   166  			raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
   167  			Expect(err).ToNot(HaveOccurred())
   168  			serverConn, err = net.ListenUDP("udp", raddr)
   169  			Expect(err).ToNot(HaveOccurred())
   170  
   171  			go func() {
   172  				defer GinkgoRecover()
   173  				defer close(stoppedReading)
   174  				for {
   175  					buf := make([]byte, protocol.MaxPacketBufferSize)
   176  					// the ReadFromUDP will error as soon as the UDP conn is closed
   177  					n, addr, err2 := serverConn.ReadFromUDP(buf)
   178  					if err2 != nil {
   179  						return
   180  					}
   181  					data := buf[0:n]
   182  					serverReceivedPackets <- packetData(data)
   183  					// echo the packet
   184  					serverNumPacketsSent.Add(1)
   185  					serverConn.WriteToUDP(data, addr)
   186  				}
   187  			}()
   188  		})
   189  
   190  		AfterEach(func() {
   191  			Expect(proxy.Close()).To(Succeed())
   192  			Expect(serverConn.Close()).To(Succeed())
   193  			Expect(clientConn.Close()).To(Succeed())
   194  			Eventually(stoppedReading).Should(BeClosed())
   195  		})
   196  
   197  		Context("no packet drop", func() {
   198  			It("relays packets from the client to the server", func() {
   199  				startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
   200  				// send the first packet
   201  				_, err := clientConn.Write(makePacket(1, []byte("foobar")))
   202  				Expect(err).ToNot(HaveOccurred())
   203  
   204  				// send the second packet
   205  				_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
   206  				Expect(err).ToNot(HaveOccurred())
   207  
   208  				Eventually(serverReceivedPackets).Should(HaveLen(2))
   209  				Expect(string(<-serverReceivedPackets)).To(ContainSubstring("foobar"))
   210  				Expect(string(<-serverReceivedPackets)).To(ContainSubstring("decafbad"))
   211  			})
   212  
   213  			It("relays packets from the server to the client", func() {
   214  				startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
   215  				// send the first packet
   216  				_, err := clientConn.Write(makePacket(1, []byte("foobar")))
   217  				Expect(err).ToNot(HaveOccurred())
   218  
   219  				// send the second packet
   220  				_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
   221  				Expect(err).ToNot(HaveOccurred())
   222  
   223  				clientReceivedPackets := make(chan packetData, 2)
   224  				// receive the packets echoed by the server on client side
   225  				go func() {
   226  					for {
   227  						buf := make([]byte, protocol.MaxPacketBufferSize)
   228  						// the ReadFromUDP will error as soon as the UDP conn is closed
   229  						n, _, err2 := clientConn.ReadFromUDP(buf)
   230  						if err2 != nil {
   231  							return
   232  						}
   233  						data := buf[0:n]
   234  						clientReceivedPackets <- packetData(data)
   235  					}
   236  				}()
   237  
   238  				Eventually(serverReceivedPackets).Should(HaveLen(2))
   239  				Expect(serverNumPacketsSent.Load()).To(BeEquivalentTo(2))
   240  				Eventually(clientReceivedPackets).Should(HaveLen(2))
   241  				Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar"))
   242  				Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad"))
   243  			})
   244  		})
   245  
   246  		Context("Drop Callbacks", func() {
   247  			It("drops incoming packets", func() {
   248  				var counter atomic.Int32
   249  				opts := &Opts{
   250  					RemoteAddr: serverConn.LocalAddr().String(),
   251  					DropPacket: func(d Direction, _ []byte) bool {
   252  						if d != DirectionIncoming {
   253  							return false
   254  						}
   255  						return counter.Add(1)%2 == 1
   256  					},
   257  				}
   258  				startProxy(opts)
   259  
   260  				for i := 1; i <= 6; i++ {
   261  					_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
   262  					Expect(err).ToNot(HaveOccurred())
   263  				}
   264  				Eventually(serverReceivedPackets).Should(HaveLen(3))
   265  				Consistently(serverReceivedPackets).Should(HaveLen(3))
   266  			})
   267  
   268  			It("drops outgoing packets", func() {
   269  				const numPackets = 6
   270  				var counter atomic.Int32
   271  				opts := &Opts{
   272  					RemoteAddr: serverConn.LocalAddr().String(),
   273  					DropPacket: func(d Direction, _ []byte) bool {
   274  						if d != DirectionOutgoing {
   275  							return false
   276  						}
   277  						return counter.Add(1)%2 == 1
   278  					},
   279  				}
   280  				startProxy(opts)
   281  
   282  				clientReceivedPackets := make(chan packetData, numPackets)
   283  				// receive the packets echoed by the server on client side
   284  				go func() {
   285  					for {
   286  						buf := make([]byte, protocol.MaxPacketBufferSize)
   287  						// the ReadFromUDP will error as soon as the UDP conn is closed
   288  						n, _, err2 := clientConn.ReadFromUDP(buf)
   289  						if err2 != nil {
   290  							return
   291  						}
   292  						data := buf[0:n]
   293  						clientReceivedPackets <- packetData(data)
   294  					}
   295  				}()
   296  
   297  				for i := 1; i <= numPackets; i++ {
   298  					_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
   299  					Expect(err).ToNot(HaveOccurred())
   300  				}
   301  
   302  				Eventually(clientReceivedPackets).Should(HaveLen(numPackets / 2))
   303  				Consistently(clientReceivedPackets).Should(HaveLen(numPackets / 2))
   304  			})
   305  		})
   306  
   307  		Context("Delay Callback", func() {
   308  			const delay = 200 * time.Millisecond
   309  			expectDelay := func(startTime time.Time, numRTTs int) {
   310  				expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * delay)
   311  				Expect(time.Now()).To(SatisfyAll(
   312  					BeTemporally(">=", expectedReceiveTime),
   313  					BeTemporally("<", expectedReceiveTime.Add(delay/2)),
   314  				))
   315  			}
   316  
   317  			It("delays incoming packets", func() {
   318  				var counter atomic.Int32
   319  				opts := &Opts{
   320  					RemoteAddr: serverConn.LocalAddr().String(),
   321  					// delay packet 1 by 200 ms
   322  					// delay packet 2 by 400 ms
   323  					// ...
   324  					DelayPacket: func(d Direction, _ []byte) time.Duration {
   325  						if d == DirectionOutgoing {
   326  							return 0
   327  						}
   328  						p := counter.Add(1)
   329  						return time.Duration(p) * delay
   330  					},
   331  				}
   332  				startProxy(opts)
   333  
   334  				// send 3 packets
   335  				start := time.Now()
   336  				for i := 1; i <= 3; i++ {
   337  					_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
   338  					Expect(err).ToNot(HaveOccurred())
   339  				}
   340  				Eventually(serverReceivedPackets).Should(HaveLen(1))
   341  				expectDelay(start, 1)
   342  				Eventually(serverReceivedPackets).Should(HaveLen(2))
   343  				expectDelay(start, 2)
   344  				Eventually(serverReceivedPackets).Should(HaveLen(3))
   345  				expectDelay(start, 3)
   346  				Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(1)))
   347  				Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(2)))
   348  				Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(3)))
   349  			})
   350  
   351  			It("handles reordered packets", func() {
   352  				var counter atomic.Int32
   353  				opts := &Opts{
   354  					RemoteAddr: serverConn.LocalAddr().String(),
   355  					// delay packet 1 by 600 ms
   356  					// delay packet 2 by 400 ms
   357  					// delay packet 3 by 200 ms
   358  					DelayPacket: func(d Direction, _ []byte) time.Duration {
   359  						if d == DirectionOutgoing {
   360  							return 0
   361  						}
   362  						p := counter.Add(1)
   363  						return 600*time.Millisecond - time.Duration(p-1)*delay
   364  					},
   365  				}
   366  				startProxy(opts)
   367  
   368  				// send 3 packets
   369  				start := time.Now()
   370  				for i := 1; i <= 3; i++ {
   371  					_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
   372  					Expect(err).ToNot(HaveOccurred())
   373  				}
   374  				Eventually(serverReceivedPackets).Should(HaveLen(1))
   375  				expectDelay(start, 1)
   376  				Eventually(serverReceivedPackets).Should(HaveLen(2))
   377  				expectDelay(start, 2)
   378  				Eventually(serverReceivedPackets).Should(HaveLen(3))
   379  				expectDelay(start, 3)
   380  				Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(3)))
   381  				Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(2)))
   382  				Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(1)))
   383  			})
   384  
   385  			It("doesn't reorder packets when a constant delay is used", func() {
   386  				opts := &Opts{
   387  					RemoteAddr: serverConn.LocalAddr().String(),
   388  					DelayPacket: func(d Direction, _ []byte) time.Duration {
   389  						if d == DirectionOutgoing {
   390  							return 0
   391  						}
   392  						return 100 * time.Millisecond
   393  					},
   394  				}
   395  				startProxy(opts)
   396  
   397  				// send 100 packets
   398  				for i := 0; i < 100; i++ {
   399  					_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
   400  					Expect(err).ToNot(HaveOccurred())
   401  				}
   402  				Eventually(serverReceivedPackets).Should(HaveLen(100))
   403  				for i := 0; i < 100; i++ {
   404  					Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(i)))
   405  				}
   406  			})
   407  
   408  			It("delays outgoing packets", func() {
   409  				const numPackets = 3
   410  				var counter atomic.Int32
   411  				opts := &Opts{
   412  					RemoteAddr: serverConn.LocalAddr().String(),
   413  					// delay packet 1 by 200 ms
   414  					// delay packet 2 by 400 ms
   415  					// ...
   416  					DelayPacket: func(d Direction, _ []byte) time.Duration {
   417  						if d == DirectionIncoming {
   418  							return 0
   419  						}
   420  						p := counter.Add(1)
   421  						return time.Duration(p) * delay
   422  					},
   423  				}
   424  				startProxy(opts)
   425  
   426  				clientReceivedPackets := make(chan packetData, numPackets)
   427  				// receive the packets echoed by the server on client side
   428  				go func() {
   429  					for {
   430  						buf := make([]byte, protocol.MaxPacketBufferSize)
   431  						// the ReadFromUDP will error as soon as the UDP conn is closed
   432  						n, _, err2 := clientConn.ReadFromUDP(buf)
   433  						if err2 != nil {
   434  							return
   435  						}
   436  						data := buf[0:n]
   437  						clientReceivedPackets <- packetData(data)
   438  					}
   439  				}()
   440  
   441  				start := time.Now()
   442  				for i := 1; i <= numPackets; i++ {
   443  					_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
   444  					Expect(err).ToNot(HaveOccurred())
   445  				}
   446  				// the packets should have arrived immediately at the server
   447  				Eventually(serverReceivedPackets).Should(HaveLen(3))
   448  				expectDelay(start, 0)
   449  				Eventually(clientReceivedPackets).Should(HaveLen(1))
   450  				expectDelay(start, 1)
   451  				Eventually(clientReceivedPackets).Should(HaveLen(2))
   452  				expectDelay(start, 2)
   453  				Eventually(clientReceivedPackets).Should(HaveLen(3))
   454  				expectDelay(start, 3)
   455  				Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(1)))
   456  				Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(2)))
   457  				Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(3)))
   458  			})
   459  		})
   460  	})
   461  })