github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/mitm_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"net"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"golang.org/x/exp/rand"
    14  
    15  	"github.com/MerlinKodo/quic-go"
    16  	quicproxy "github.com/MerlinKodo/quic-go/integrationtests/tools/proxy"
    17  	"github.com/MerlinKodo/quic-go/internal/protocol"
    18  	"github.com/MerlinKodo/quic-go/internal/testutils"
    19  	"github.com/MerlinKodo/quic-go/internal/wire"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  )
    24  
    25  var _ = Describe("MITM test", func() {
    26  	const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
    27  
    28  	var (
    29  		clientUDPConn                    net.PacketConn
    30  		serverTransport, clientTransport *quic.Transport
    31  		serverConn                       quic.Connection
    32  		serverConfig                     *quic.Config
    33  	)
    34  
    35  	startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) {
    36  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    37  		Expect(err).ToNot(HaveOccurred())
    38  		c, err := net.ListenUDP("udp", addr)
    39  		Expect(err).ToNot(HaveOccurred())
    40  		serverTransport = &quic.Transport{
    41  			Conn:               c,
    42  			ConnectionIDLength: connIDLen,
    43  		}
    44  		ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
    45  		Expect(err).ToNot(HaveOccurred())
    46  		done := make(chan struct{})
    47  		go func() {
    48  			defer GinkgoRecover()
    49  			defer close(done)
    50  			var err error
    51  			serverConn, err = ln.Accept(context.Background())
    52  			if err != nil {
    53  				return
    54  			}
    55  			str, err := serverConn.OpenUniStream()
    56  			Expect(err).ToNot(HaveOccurred())
    57  			_, err = str.Write(PRData)
    58  			Expect(err).ToNot(HaveOccurred())
    59  			Expect(str.Close()).To(Succeed())
    60  		}()
    61  		serverPort := ln.Addr().(*net.UDPAddr).Port
    62  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    63  			RemoteAddr:  fmt.Sprintf("localhost:%d", serverPort),
    64  			DelayPacket: delayCb,
    65  			DropPacket:  dropCb,
    66  		})
    67  		Expect(err).ToNot(HaveOccurred())
    68  		return proxy.LocalPort(), func() {
    69  			proxy.Close()
    70  			ln.Close()
    71  			serverTransport.Close()
    72  			<-done
    73  		}
    74  	}
    75  
    76  	BeforeEach(func() {
    77  		serverConfig = getQuicConfig(nil)
    78  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    79  		Expect(err).ToNot(HaveOccurred())
    80  		clientUDPConn, err = net.ListenUDP("udp", addr)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		clientTransport = &quic.Transport{
    83  			Conn:               clientUDPConn,
    84  			ConnectionIDLength: connIDLen,
    85  		}
    86  	})
    87  
    88  	Context("unsuccessful attacks", func() {
    89  		AfterEach(func() {
    90  			Eventually(serverConn.Context().Done()).Should(BeClosed())
    91  			// Test shutdown is tricky due to the proxy. Just wait for a bit.
    92  			time.Sleep(50 * time.Millisecond)
    93  			Expect(clientUDPConn.Close()).To(Succeed())
    94  			Expect(clientTransport.Close()).To(Succeed())
    95  		})
    96  
    97  		Context("injecting invalid packets", func() {
    98  			const rtt = 20 * time.Millisecond
    99  
   100  			sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
   101  				defer GinkgoRecover()
   102  				const numPackets = 10
   103  				ticker := time.NewTicker(rtt / numPackets)
   104  				defer ticker.Stop()
   105  
   106  				if wire.IsLongHeaderPacket(raw[0]) {
   107  					hdr, _, _, err := wire.ParsePacket(raw)
   108  					Expect(err).ToNot(HaveOccurred())
   109  					replyHdr := &wire.ExtendedHeader{
   110  						Header: wire.Header{
   111  							DestConnectionID: hdr.DestConnectionID,
   112  							SrcConnectionID:  hdr.SrcConnectionID,
   113  							Type:             hdr.Type,
   114  							Version:          hdr.Version,
   115  						},
   116  						PacketNumber:    protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)),
   117  						PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1),
   118  					}
   119  
   120  					for i := 0; i < numPackets; i++ {
   121  						payloadLen := rand.Int31n(100)
   122  						replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1))
   123  						b, err := replyHdr.Append(nil, hdr.Version)
   124  						Expect(err).ToNot(HaveOccurred())
   125  						r := make([]byte, payloadLen)
   126  						rand.Read(r)
   127  						b = append(b, r...)
   128  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   129  							return
   130  						}
   131  						<-ticker.C
   132  					}
   133  				} else {
   134  					connID, err := wire.ParseConnectionID(raw, connIDLen)
   135  					Expect(err).ToNot(HaveOccurred())
   136  					_, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen)
   137  					if err != nil { // normally, ParseShortHeader is called after decrypting the header
   138  						Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
   139  					}
   140  					for i := 0; i < numPackets; i++ {
   141  						b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2)))
   142  						Expect(err).ToNot(HaveOccurred())
   143  						payloadLen := rand.Int31n(100)
   144  						r := make([]byte, payloadLen)
   145  						rand.Read(r)
   146  						b = append(b, r...)
   147  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   148  							return
   149  						}
   150  						<-ticker.C
   151  					}
   152  				}
   153  			}
   154  
   155  			runTest := func(delayCb quicproxy.DelayCallback) {
   156  				proxyPort, closeFn := startServerAndProxy(delayCb, nil)
   157  				defer closeFn()
   158  				raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   159  				Expect(err).ToNot(HaveOccurred())
   160  				conn, err := clientTransport.Dial(
   161  					context.Background(),
   162  					raddr,
   163  					getTLSClientConfig(),
   164  					getQuicConfig(nil),
   165  				)
   166  				Expect(err).ToNot(HaveOccurred())
   167  				str, err := conn.AcceptUniStream(context.Background())
   168  				Expect(err).ToNot(HaveOccurred())
   169  				data, err := io.ReadAll(str)
   170  				Expect(err).ToNot(HaveOccurred())
   171  				Expect(data).To(Equal(PRData))
   172  				Expect(conn.CloseWithError(0, "")).To(Succeed())
   173  			}
   174  
   175  			It("downloads a message when the packets are injected towards the server", func() {
   176  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   177  					if dir == quicproxy.DirectionIncoming {
   178  						defer GinkgoRecover()
   179  						go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
   180  					}
   181  					return rtt / 2
   182  				}
   183  				runTest(delayCb)
   184  			})
   185  
   186  			It("downloads a message when the packets are injected towards the client", func() {
   187  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   188  					if dir == quicproxy.DirectionOutgoing {
   189  						defer GinkgoRecover()
   190  						go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
   191  					}
   192  					return rtt / 2
   193  				}
   194  				runTest(delayCb)
   195  			})
   196  		})
   197  
   198  		runTest := func(dropCb quicproxy.DropCallback) {
   199  			proxyPort, closeFn := startServerAndProxy(nil, dropCb)
   200  			defer closeFn()
   201  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   202  			Expect(err).ToNot(HaveOccurred())
   203  			conn, err := clientTransport.Dial(
   204  				context.Background(),
   205  				raddr,
   206  				getTLSClientConfig(),
   207  				getQuicConfig(nil),
   208  			)
   209  			Expect(err).ToNot(HaveOccurred())
   210  			str, err := conn.AcceptUniStream(context.Background())
   211  			Expect(err).ToNot(HaveOccurred())
   212  			data, err := io.ReadAll(str)
   213  			Expect(err).ToNot(HaveOccurred())
   214  			Expect(data).To(Equal(PRData))
   215  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   216  		}
   217  
   218  		Context("duplicating packets", func() {
   219  			It("downloads a message when packets are duplicated towards the server", func() {
   220  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   221  					defer GinkgoRecover()
   222  					if dir == quicproxy.DirectionIncoming {
   223  						_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
   224  						Expect(err).ToNot(HaveOccurred())
   225  					}
   226  					return false
   227  				}
   228  				runTest(dropCb)
   229  			})
   230  
   231  			It("downloads a message when packets are duplicated towards the client", func() {
   232  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   233  					defer GinkgoRecover()
   234  					if dir == quicproxy.DirectionOutgoing {
   235  						_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
   236  						Expect(err).ToNot(HaveOccurred())
   237  					}
   238  					return false
   239  				}
   240  				runTest(dropCb)
   241  			})
   242  		})
   243  
   244  		Context("corrupting packets", func() {
   245  			const idleTimeout = time.Second
   246  
   247  			var numCorrupted, numPackets int32
   248  
   249  			BeforeEach(func() {
   250  				numCorrupted = 0
   251  				numPackets = 0
   252  				serverConfig.MaxIdleTimeout = idleTimeout
   253  			})
   254  
   255  			AfterEach(func() {
   256  				num := atomic.LoadInt32(&numCorrupted)
   257  				fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets))
   258  				Expect(num).To(BeNumerically(">=", 1))
   259  				// If the packet containing the CONNECTION_CLOSE is corrupted,
   260  				// we have to wait for the connection to time out.
   261  				Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed())
   262  			})
   263  
   264  			It("downloads a message when packet are corrupted towards the server", func() {
   265  				const interval = 4 // corrupt every 4th packet (stochastically)
   266  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   267  					defer GinkgoRecover()
   268  					if dir == quicproxy.DirectionIncoming {
   269  						atomic.AddInt32(&numPackets, 1)
   270  						if rand.Intn(interval) == 0 {
   271  							pos := rand.Intn(len(raw))
   272  							raw[pos] = byte(rand.Intn(256))
   273  							_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
   274  							Expect(err).ToNot(HaveOccurred())
   275  							atomic.AddInt32(&numCorrupted, 1)
   276  							return true
   277  						}
   278  					}
   279  					return false
   280  				}
   281  				runTest(dropCb)
   282  			})
   283  
   284  			It("downloads a message when packet are corrupted towards the client", func() {
   285  				const interval = 10 // corrupt every 10th packet (stochastically)
   286  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   287  					defer GinkgoRecover()
   288  					if dir == quicproxy.DirectionOutgoing {
   289  						atomic.AddInt32(&numPackets, 1)
   290  						if rand.Intn(interval) == 0 {
   291  							pos := rand.Intn(len(raw))
   292  							raw[pos] = byte(rand.Intn(256))
   293  							_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
   294  							Expect(err).ToNot(HaveOccurred())
   295  							atomic.AddInt32(&numCorrupted, 1)
   296  							return true
   297  						}
   298  					}
   299  					return false
   300  				}
   301  				runTest(dropCb)
   302  			})
   303  		})
   304  	})
   305  
   306  	Context("successful injection attacks", func() {
   307  		// These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake
   308  		// finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply
   309  		// that arrives before the real reply can tear down the connection in multiple ways.
   310  
   311  		const rtt = 20 * time.Millisecond
   312  
   313  		runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
   314  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
   315  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   316  			Expect(err).ToNot(HaveOccurred())
   317  			_, err = clientTransport.Dial(
   318  				context.Background(),
   319  				raddr,
   320  				getTLSClientConfig(),
   321  				getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
   322  			)
   323  			return func() { clientTransport.Close(); serverCloseFn() }, err
   324  		}
   325  
   326  		// fails immediately because client connection closes when it can't find compatible version
   327  		It("fails when a forged version negotiation packet is sent to client", func() {
   328  			done := make(chan struct{})
   329  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   330  				if dir == quicproxy.DirectionIncoming {
   331  					defer GinkgoRecover()
   332  
   333  					hdr, _, _, err := wire.ParsePacket(raw)
   334  					Expect(err).ToNot(HaveOccurred())
   335  
   336  					if hdr.Type != protocol.PacketTypeInitial {
   337  						return 0
   338  					}
   339  
   340  					// Create fake version negotiation packet with no supported versions
   341  					versions := []protocol.VersionNumber{}
   342  					packet := wire.ComposeVersionNegotiation(
   343  						protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
   344  						protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
   345  						versions,
   346  					)
   347  
   348  					// Send the packet
   349  					_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
   350  					Expect(err).ToNot(HaveOccurred())
   351  					close(done)
   352  				}
   353  				return rtt / 2
   354  			}
   355  			closeFn, err := runTest(delayCb)
   356  			defer closeFn()
   357  			Expect(err).To(HaveOccurred())
   358  			vnErr := &quic.VersionNegotiationError{}
   359  			Expect(errors.As(err, &vnErr)).To(BeTrue())
   360  			Eventually(done).Should(BeClosed())
   361  		})
   362  
   363  		// times out, because client doesn't accept subsequent real retry packets from server
   364  		// as it has already accepted a retry.
   365  		// TODO: determine behavior when server does not send Retry packets
   366  		It("fails when a forged Retry packet with modified srcConnID is sent to client", func() {
   367  			serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
   368  			var initialPacketIntercepted bool
   369  			done := make(chan struct{})
   370  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   371  				if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
   372  					defer GinkgoRecover()
   373  					defer close(done)
   374  
   375  					hdr, _, _, err := wire.ParsePacket(raw)
   376  					Expect(err).ToNot(HaveOccurred())
   377  
   378  					if hdr.Type != protocol.PacketTypeInitial {
   379  						return 0
   380  					}
   381  
   382  					initialPacketIntercepted = true
   383  					fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
   384  					retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
   385  
   386  					_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
   387  					Expect(err).ToNot(HaveOccurred())
   388  				}
   389  				return rtt / 2
   390  			}
   391  			closeFn, err := runTest(delayCb)
   392  			defer closeFn()
   393  			Expect(err).To(HaveOccurred())
   394  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   395  			Eventually(done).Should(BeClosed())
   396  		})
   397  
   398  		// times out, because client doesn't accept real retry packets from server because
   399  		// it has already accepted an initial.
   400  		// TODO: determine behavior when server does not send Retry packets
   401  		It("fails when a forged initial packet is sent to client", func() {
   402  			done := make(chan struct{})
   403  			var injected bool
   404  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   405  				if dir == quicproxy.DirectionIncoming {
   406  					defer GinkgoRecover()
   407  
   408  					hdr, _, _, err := wire.ParsePacket(raw)
   409  					Expect(err).ToNot(HaveOccurred())
   410  					if hdr.Type != protocol.PacketTypeInitial || injected {
   411  						return 0
   412  					}
   413  					defer close(done)
   414  					injected = true
   415  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil)
   416  					_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
   417  					Expect(err).ToNot(HaveOccurred())
   418  				}
   419  				return rtt
   420  			}
   421  			closeFn, err := runTest(delayCb)
   422  			defer closeFn()
   423  			Expect(err).To(HaveOccurred())
   424  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   425  			Eventually(done).Should(BeClosed())
   426  		})
   427  
   428  		// client connection closes immediately on receiving ack for unsent packet
   429  		It("fails when a forged initial packet with ack for unsent packet is sent to client", func() {
   430  			done := make(chan struct{})
   431  			var injected bool
   432  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   433  				if dir == quicproxy.DirectionIncoming {
   434  					defer GinkgoRecover()
   435  
   436  					hdr, _, _, err := wire.ParsePacket(raw)
   437  					Expect(err).ToNot(HaveOccurred())
   438  					if hdr.Type != protocol.PacketTypeInitial || injected {
   439  						return 0
   440  					}
   441  					defer close(done)
   442  					injected = true
   443  					// Fake Initial with ACK for packet 2 (unsent)
   444  					ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
   445  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack})
   446  					_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
   447  					Expect(err).ToNot(HaveOccurred())
   448  				}
   449  				return rtt
   450  			}
   451  			closeFn, err := runTest(delayCb)
   452  			defer closeFn()
   453  			Expect(err).To(HaveOccurred())
   454  			var transportErr *quic.TransportError
   455  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   456  			Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation))
   457  			Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet"))
   458  			Eventually(done).Should(BeClosed())
   459  		})
   460  	})
   461  })