github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/self/mitm_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"net"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"golang.org/x/exp/rand"
    14  
    15  	"github.com/daeuniverse/quic-go"
    16  	quicproxy "github.com/daeuniverse/quic-go/integrationtests/tools/proxy"
    17  	"github.com/daeuniverse/quic-go/internal/protocol"
    18  	"github.com/daeuniverse/quic-go/internal/wire"
    19  	"github.com/daeuniverse/quic-go/testutils"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  )
    24  
    25  var _ = Describe("MITM test", func() {
    26  	const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
    27  
    28  	var (
    29  		clientUDPConn                    net.PacketConn
    30  		serverTransport, clientTransport *quic.Transport
    31  		serverConn                       quic.Connection
    32  		serverConfig                     *quic.Config
    33  	)
    34  
    35  	startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) {
    36  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    37  		Expect(err).ToNot(HaveOccurred())
    38  		c, err := net.ListenUDP("udp", addr)
    39  		Expect(err).ToNot(HaveOccurred())
    40  		serverTransport = &quic.Transport{
    41  			Conn:               c,
    42  			ConnectionIDLength: connIDLen,
    43  		}
    44  		addTracer(serverTransport)
    45  		if forceAddressValidation {
    46  			serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
    47  		}
    48  		ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
    49  		Expect(err).ToNot(HaveOccurred())
    50  		done := make(chan struct{})
    51  		go func() {
    52  			defer GinkgoRecover()
    53  			defer close(done)
    54  			var err error
    55  			serverConn, err = ln.Accept(context.Background())
    56  			if err != nil {
    57  				return
    58  			}
    59  			str, err := serverConn.OpenUniStream()
    60  			Expect(err).ToNot(HaveOccurred())
    61  			_, err = str.Write(PRData)
    62  			Expect(err).ToNot(HaveOccurred())
    63  			Expect(str.Close()).To(Succeed())
    64  		}()
    65  		serverPort := ln.Addr().(*net.UDPAddr).Port
    66  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    67  			RemoteAddr:  fmt.Sprintf("localhost:%d", serverPort),
    68  			DelayPacket: delayCb,
    69  			DropPacket:  dropCb,
    70  		})
    71  		Expect(err).ToNot(HaveOccurred())
    72  		return proxy.LocalPort(), func() {
    73  			proxy.Close()
    74  			ln.Close()
    75  			serverTransport.Close()
    76  			<-done
    77  		}
    78  	}
    79  
    80  	BeforeEach(func() {
    81  		serverConfig = getQuicConfig(nil)
    82  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    83  		Expect(err).ToNot(HaveOccurred())
    84  		clientUDPConn, err = net.ListenUDP("udp", addr)
    85  		Expect(err).ToNot(HaveOccurred())
    86  		clientTransport = &quic.Transport{
    87  			Conn:               clientUDPConn,
    88  			ConnectionIDLength: connIDLen,
    89  		}
    90  		addTracer(clientTransport)
    91  	})
    92  
    93  	Context("unsuccessful attacks", func() {
    94  		AfterEach(func() {
    95  			Eventually(serverConn.Context().Done()).Should(BeClosed())
    96  			// Test shutdown is tricky due to the proxy. Just wait for a bit.
    97  			time.Sleep(50 * time.Millisecond)
    98  			Expect(clientUDPConn.Close()).To(Succeed())
    99  			Expect(clientTransport.Close()).To(Succeed())
   100  		})
   101  
   102  		Context("injecting invalid packets", func() {
   103  			const rtt = 20 * time.Millisecond
   104  
   105  			sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
   106  				defer GinkgoRecover()
   107  				const numPackets = 10
   108  				ticker := time.NewTicker(rtt / numPackets)
   109  				defer ticker.Stop()
   110  
   111  				if wire.IsLongHeaderPacket(raw[0]) {
   112  					hdr, _, _, err := wire.ParsePacket(raw)
   113  					Expect(err).ToNot(HaveOccurred())
   114  					replyHdr := &wire.ExtendedHeader{
   115  						Header: wire.Header{
   116  							DestConnectionID: hdr.DestConnectionID,
   117  							SrcConnectionID:  hdr.SrcConnectionID,
   118  							Type:             hdr.Type,
   119  							Version:          hdr.Version,
   120  						},
   121  						PacketNumber:    protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)),
   122  						PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1),
   123  					}
   124  
   125  					for i := 0; i < numPackets; i++ {
   126  						payloadLen := rand.Int31n(100)
   127  						replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1))
   128  						b, err := replyHdr.Append(nil, hdr.Version)
   129  						Expect(err).ToNot(HaveOccurred())
   130  						r := make([]byte, payloadLen)
   131  						rand.Read(r)
   132  						b = append(b, r...)
   133  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   134  							return
   135  						}
   136  						<-ticker.C
   137  					}
   138  				} else {
   139  					connID, err := wire.ParseConnectionID(raw, connIDLen)
   140  					Expect(err).ToNot(HaveOccurred())
   141  					_, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen)
   142  					if err != nil { // normally, ParseShortHeader is called after decrypting the header
   143  						Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
   144  					}
   145  					for i := 0; i < numPackets; i++ {
   146  						b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2)))
   147  						Expect(err).ToNot(HaveOccurred())
   148  						payloadLen := rand.Int31n(100)
   149  						r := make([]byte, payloadLen)
   150  						rand.Read(r)
   151  						b = append(b, r...)
   152  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   153  							return
   154  						}
   155  						<-ticker.C
   156  					}
   157  				}
   158  			}
   159  
   160  			runTest := func(delayCb quicproxy.DelayCallback) {
   161  				proxyPort, closeFn := startServerAndProxy(delayCb, nil, false)
   162  				defer closeFn()
   163  				raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   164  				Expect(err).ToNot(HaveOccurred())
   165  				conn, err := clientTransport.Dial(
   166  					context.Background(),
   167  					raddr,
   168  					getTLSClientConfig(),
   169  					getQuicConfig(nil),
   170  				)
   171  				Expect(err).ToNot(HaveOccurred())
   172  				str, err := conn.AcceptUniStream(context.Background())
   173  				Expect(err).ToNot(HaveOccurred())
   174  				data, err := io.ReadAll(str)
   175  				Expect(err).ToNot(HaveOccurred())
   176  				Expect(data).To(Equal(PRData))
   177  				Expect(conn.CloseWithError(0, "")).To(Succeed())
   178  			}
   179  
   180  			It("downloads a message when the packets are injected towards the server", func() {
   181  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   182  					if dir == quicproxy.DirectionIncoming {
   183  						defer GinkgoRecover()
   184  						go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
   185  					}
   186  					return rtt / 2
   187  				}
   188  				runTest(delayCb)
   189  			})
   190  
   191  			It("downloads a message when the packets are injected towards the client", func() {
   192  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   193  					if dir == quicproxy.DirectionOutgoing {
   194  						defer GinkgoRecover()
   195  						go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
   196  					}
   197  					return rtt / 2
   198  				}
   199  				runTest(delayCb)
   200  			})
   201  		})
   202  
   203  		runTest := func(dropCb quicproxy.DropCallback) {
   204  			proxyPort, closeFn := startServerAndProxy(nil, dropCb, false)
   205  			defer closeFn()
   206  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   207  			Expect(err).ToNot(HaveOccurred())
   208  			conn, err := clientTransport.Dial(
   209  				context.Background(),
   210  				raddr,
   211  				getTLSClientConfig(),
   212  				getQuicConfig(nil),
   213  			)
   214  			Expect(err).ToNot(HaveOccurred())
   215  			str, err := conn.AcceptUniStream(context.Background())
   216  			Expect(err).ToNot(HaveOccurred())
   217  			data, err := io.ReadAll(str)
   218  			Expect(err).ToNot(HaveOccurred())
   219  			Expect(data).To(Equal(PRData))
   220  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   221  		}
   222  
   223  		Context("duplicating packets", func() {
   224  			It("downloads a message when packets are duplicated towards the server", func() {
   225  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   226  					defer GinkgoRecover()
   227  					if dir == quicproxy.DirectionIncoming {
   228  						_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
   229  						Expect(err).ToNot(HaveOccurred())
   230  					}
   231  					return false
   232  				}
   233  				runTest(dropCb)
   234  			})
   235  
   236  			It("downloads a message when packets are duplicated towards the client", func() {
   237  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   238  					defer GinkgoRecover()
   239  					if dir == quicproxy.DirectionOutgoing {
   240  						_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
   241  						Expect(err).ToNot(HaveOccurred())
   242  					}
   243  					return false
   244  				}
   245  				runTest(dropCb)
   246  			})
   247  		})
   248  
   249  		Context("corrupting packets", func() {
   250  			const idleTimeout = time.Second
   251  
   252  			var numCorrupted, numPackets atomic.Int32
   253  
   254  			BeforeEach(func() {
   255  				numCorrupted.Store(0)
   256  				numPackets.Store(0)
   257  				serverConfig.MaxIdleTimeout = idleTimeout
   258  			})
   259  
   260  			AfterEach(func() {
   261  				num := numCorrupted.Load()
   262  				fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, numPackets.Load())
   263  				Expect(num).To(BeNumerically(">=", 1))
   264  				// If the packet containing the CONNECTION_CLOSE is corrupted,
   265  				// we have to wait for the connection to time out.
   266  				Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed())
   267  			})
   268  
   269  			It("downloads a message when packet are corrupted towards the server", func() {
   270  				const interval = 4 // corrupt every 4th packet (stochastically)
   271  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   272  					defer GinkgoRecover()
   273  					if dir == quicproxy.DirectionIncoming {
   274  						numPackets.Add(1)
   275  						if rand.Intn(interval) == 0 {
   276  							pos := rand.Intn(len(raw))
   277  							raw[pos] = byte(rand.Intn(256))
   278  							_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
   279  							Expect(err).ToNot(HaveOccurred())
   280  							numCorrupted.Add(1)
   281  							return true
   282  						}
   283  					}
   284  					return false
   285  				}
   286  				runTest(dropCb)
   287  			})
   288  
   289  			It("downloads a message when packet are corrupted towards the client", func() {
   290  				const interval = 10 // corrupt every 10th packet (stochastically)
   291  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   292  					defer GinkgoRecover()
   293  					if dir == quicproxy.DirectionOutgoing {
   294  						numPackets.Add(1)
   295  						if rand.Intn(interval) == 0 {
   296  							pos := rand.Intn(len(raw))
   297  							raw[pos] = byte(rand.Intn(256))
   298  							_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
   299  							Expect(err).ToNot(HaveOccurred())
   300  							numCorrupted.Add(1)
   301  							return true
   302  						}
   303  					}
   304  					return false
   305  				}
   306  				runTest(dropCb)
   307  			})
   308  		})
   309  	})
   310  
   311  	Context("successful injection attacks", func() {
   312  		// These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake
   313  		// finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply
   314  		// that arrives before the real reply can tear down the connection in multiple ways.
   315  
   316  		const rtt = 20 * time.Millisecond
   317  
   318  		runTest := func(proxyPort int) (closeFn func(), err error) {
   319  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   320  			Expect(err).ToNot(HaveOccurred())
   321  			_, err = clientTransport.Dial(
   322  				context.Background(),
   323  				raddr,
   324  				getTLSClientConfig(),
   325  				getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}),
   326  			)
   327  			return func() { clientTransport.Close() }, err
   328  		}
   329  
   330  		// fails immediately because client connection closes when it can't find compatible version
   331  		It("fails when a forged version negotiation packet is sent to client", func() {
   332  			done := make(chan struct{})
   333  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   334  				if dir == quicproxy.DirectionIncoming {
   335  					defer GinkgoRecover()
   336  
   337  					hdr, _, _, err := wire.ParsePacket(raw)
   338  					Expect(err).ToNot(HaveOccurred())
   339  
   340  					if hdr.Type != protocol.PacketTypeInitial {
   341  						return 0
   342  					}
   343  
   344  					// Create fake version negotiation packet with no supported versions
   345  					versions := []protocol.Version{}
   346  					packet := wire.ComposeVersionNegotiation(
   347  						protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
   348  						protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
   349  						versions,
   350  					)
   351  
   352  					// Send the packet
   353  					_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
   354  					Expect(err).ToNot(HaveOccurred())
   355  					close(done)
   356  				}
   357  				return rtt / 2
   358  			}
   359  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
   360  			defer serverCloseFn()
   361  			closeFn, err := runTest(proxyPort)
   362  			defer closeFn()
   363  			Expect(err).To(HaveOccurred())
   364  			vnErr := &quic.VersionNegotiationError{}
   365  			Expect(errors.As(err, &vnErr)).To(BeTrue())
   366  			Eventually(done).Should(BeClosed())
   367  		})
   368  
   369  		// times out, because client doesn't accept subsequent real retry packets from server
   370  		// as it has already accepted a retry.
   371  		// TODO: determine behavior when server does not send Retry packets
   372  		It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() {
   373  			var initialPacketIntercepted bool
   374  			done := make(chan struct{})
   375  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   376  				if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
   377  					defer GinkgoRecover()
   378  					defer close(done)
   379  
   380  					hdr, _, _, err := wire.ParsePacket(raw)
   381  					Expect(err).ToNot(HaveOccurred())
   382  
   383  					if hdr.Type != protocol.PacketTypeInitial {
   384  						return 0
   385  					}
   386  
   387  					initialPacketIntercepted = true
   388  					fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
   389  					retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
   390  
   391  					_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
   392  					Expect(err).ToNot(HaveOccurred())
   393  				}
   394  				return rtt / 2
   395  			}
   396  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true)
   397  			defer serverCloseFn()
   398  			closeFn, err := runTest(proxyPort)
   399  			defer closeFn()
   400  			Expect(err).To(HaveOccurred())
   401  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   402  			Eventually(done).Should(BeClosed())
   403  		})
   404  
   405  		// times out, because client doesn't accept real retry packets from server because
   406  		// it has already accepted an initial.
   407  		// TODO: determine behavior when server does not send Retry packets
   408  		It("fails when a forged initial packet is sent to client", func() {
   409  			done := make(chan struct{})
   410  			var injected bool
   411  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   412  				if dir == quicproxy.DirectionIncoming {
   413  					defer GinkgoRecover()
   414  
   415  					hdr, _, _, err := wire.ParsePacket(raw)
   416  					Expect(err).ToNot(HaveOccurred())
   417  					if hdr.Type != protocol.PacketTypeInitial || injected {
   418  						return 0
   419  					}
   420  					defer close(done)
   421  					injected = true
   422  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, protocol.PerspectiveServer, hdr.Version)
   423  					_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
   424  					Expect(err).ToNot(HaveOccurred())
   425  				}
   426  				return rtt
   427  			}
   428  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
   429  			defer serverCloseFn()
   430  			closeFn, err := runTest(proxyPort)
   431  			defer closeFn()
   432  			Expect(err).To(HaveOccurred())
   433  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   434  			Eventually(done).Should(BeClosed())
   435  		})
   436  
   437  		// client connection closes immediately on receiving ack for unsent packet
   438  		It("fails when a forged initial packet with ack for unsent packet is sent to client", func() {
   439  			done := make(chan struct{})
   440  			var injected bool
   441  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   442  				if dir == quicproxy.DirectionIncoming {
   443  					defer GinkgoRecover()
   444  
   445  					hdr, _, _, err := wire.ParsePacket(raw)
   446  					Expect(err).ToNot(HaveOccurred())
   447  					if hdr.Type != protocol.PacketTypeInitial || injected {
   448  						return 0
   449  					}
   450  					defer close(done)
   451  					injected = true
   452  					// Fake Initial with ACK for packet 2 (unsent)
   453  					ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
   454  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version)
   455  					_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
   456  					Expect(err).ToNot(HaveOccurred())
   457  				}
   458  				return rtt
   459  			}
   460  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
   461  			defer serverCloseFn()
   462  			closeFn, err := runTest(proxyPort)
   463  			defer closeFn()
   464  			Expect(err).To(HaveOccurred())
   465  			var transportErr *quic.TransportError
   466  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   467  			Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation))
   468  			Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet"))
   469  			Eventually(done).Should(BeClosed())
   470  		})
   471  	})
   472  })