github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/integrationtests/self/mitm_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	mrand "math/rand"
    10  	"net"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/mikelsr/quic-go"
    15  	quicproxy "github.com/mikelsr/quic-go/integrationtests/tools/proxy"
    16  	"github.com/mikelsr/quic-go/internal/protocol"
    17  	"github.com/mikelsr/quic-go/internal/testutils"
    18  	"github.com/mikelsr/quic-go/internal/wire"
    19  
    20  	. "github.com/onsi/ginkgo/v2"
    21  	. "github.com/onsi/gomega"
    22  )
    23  
    24  var _ = Describe("MITM test", func() {
    25  	const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
    26  
    27  	var (
    28  		serverUDPConn, clientUDPConn net.PacketConn
    29  		serverConn                   quic.Connection
    30  		serverConfig                 *quic.Config
    31  	)
    32  
    33  	startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) {
    34  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    35  		Expect(err).ToNot(HaveOccurred())
    36  		c, err := net.ListenUDP("udp", addr)
    37  		Expect(err).ToNot(HaveOccurred())
    38  		serverUDPConn, err = quic.OptimizeConn(c)
    39  		Expect(err).ToNot(HaveOccurred())
    40  		tr := &quic.Transport{
    41  			Conn:               serverUDPConn,
    42  			ConnectionIDLength: connIDLen,
    43  		}
    44  		ln, err := tr.Listen(getTLSConfig(), serverConfig)
    45  		Expect(err).ToNot(HaveOccurred())
    46  		done := make(chan struct{})
    47  		go func() {
    48  			defer GinkgoRecover()
    49  			defer close(done)
    50  			var err error
    51  			serverConn, err = ln.Accept(context.Background())
    52  			if err != nil {
    53  				return
    54  			}
    55  			str, err := serverConn.OpenUniStream()
    56  			Expect(err).ToNot(HaveOccurred())
    57  			_, err = str.Write(PRData)
    58  			Expect(err).ToNot(HaveOccurred())
    59  			Expect(str.Close()).To(Succeed())
    60  		}()
    61  		serverPort := ln.Addr().(*net.UDPAddr).Port
    62  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    63  			RemoteAddr:  fmt.Sprintf("localhost:%d", serverPort),
    64  			DelayPacket: delayCb,
    65  			DropPacket:  dropCb,
    66  		})
    67  		Expect(err).ToNot(HaveOccurred())
    68  		return proxy.LocalPort(), func() {
    69  			proxy.Close()
    70  			ln.Close()
    71  			serverUDPConn.Close()
    72  			<-done
    73  		}
    74  	}
    75  
    76  	BeforeEach(func() {
    77  		serverConfig = getQuicConfig(nil)
    78  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    79  		Expect(err).ToNot(HaveOccurred())
    80  		c, err := net.ListenUDP("udp", addr)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		clientUDPConn, err = quic.OptimizeConn(c)
    83  		Expect(err).ToNot(HaveOccurred())
    84  	})
    85  
    86  	Context("unsuccessful attacks", func() {
    87  		AfterEach(func() {
    88  			Eventually(serverConn.Context().Done()).Should(BeClosed())
    89  			// Test shutdown is tricky due to the proxy. Just wait for a bit.
    90  			time.Sleep(50 * time.Millisecond)
    91  			Expect(clientUDPConn.Close()).To(Succeed())
    92  		})
    93  
    94  		Context("injecting invalid packets", func() {
    95  			const rtt = 20 * time.Millisecond
    96  
    97  			sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) {
    98  				defer GinkgoRecover()
    99  				const numPackets = 10
   100  				ticker := time.NewTicker(rtt / numPackets)
   101  				defer ticker.Stop()
   102  
   103  				if wire.IsLongHeaderPacket(raw[0]) {
   104  					hdr, _, _, err := wire.ParsePacket(raw)
   105  					Expect(err).ToNot(HaveOccurred())
   106  					replyHdr := &wire.ExtendedHeader{
   107  						Header: wire.Header{
   108  							DestConnectionID: hdr.DestConnectionID,
   109  							SrcConnectionID:  hdr.SrcConnectionID,
   110  							Type:             hdr.Type,
   111  							Version:          hdr.Version,
   112  						},
   113  						PacketNumber:    protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)),
   114  						PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1),
   115  					}
   116  
   117  					for i := 0; i < numPackets; i++ {
   118  						payloadLen := mrand.Int31n(100)
   119  						replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
   120  						b, err := replyHdr.Append(nil, hdr.Version)
   121  						Expect(err).ToNot(HaveOccurred())
   122  						r := make([]byte, payloadLen)
   123  						mrand.Read(r)
   124  						b = append(b, r...)
   125  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   126  							return
   127  						}
   128  						<-ticker.C
   129  					}
   130  				} else {
   131  					connID, err := wire.ParseConnectionID(raw, connIDLen)
   132  					Expect(err).ToNot(HaveOccurred())
   133  					_, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen)
   134  					if err != nil { // normally, ParseShortHeader is called after decrypting the header
   135  						Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
   136  					}
   137  					for i := 0; i < numPackets; i++ {
   138  						b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))
   139  						Expect(err).ToNot(HaveOccurred())
   140  						payloadLen := mrand.Int31n(100)
   141  						r := make([]byte, payloadLen)
   142  						mrand.Read(r)
   143  						b = append(b, r...)
   144  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   145  							return
   146  						}
   147  						<-ticker.C
   148  					}
   149  				}
   150  			}
   151  
   152  			runTest := func(delayCb quicproxy.DelayCallback) {
   153  				proxyPort, closeFn := startServerAndProxy(delayCb, nil)
   154  				defer closeFn()
   155  				raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   156  				Expect(err).ToNot(HaveOccurred())
   157  				tr := &quic.Transport{
   158  					Conn:               clientUDPConn,
   159  					ConnectionIDLength: connIDLen,
   160  				}
   161  				conn, err := tr.Dial(
   162  					context.Background(),
   163  					raddr,
   164  					getTLSClientConfig(),
   165  					getQuicConfig(nil),
   166  				)
   167  				Expect(err).ToNot(HaveOccurred())
   168  				str, err := conn.AcceptUniStream(context.Background())
   169  				Expect(err).ToNot(HaveOccurred())
   170  				data, err := io.ReadAll(str)
   171  				Expect(err).ToNot(HaveOccurred())
   172  				Expect(data).To(Equal(PRData))
   173  				Expect(conn.CloseWithError(0, "")).To(Succeed())
   174  			}
   175  
   176  			It("downloads a message when the packets are injected towards the server", func() {
   177  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   178  					if dir == quicproxy.DirectionIncoming {
   179  						defer GinkgoRecover()
   180  						go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw)
   181  					}
   182  					return rtt / 2
   183  				}
   184  				runTest(delayCb)
   185  			})
   186  
   187  			It("downloads a message when the packets are injected towards the client", func() {
   188  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   189  					if dir == quicproxy.DirectionOutgoing {
   190  						defer GinkgoRecover()
   191  						go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw)
   192  					}
   193  					return rtt / 2
   194  				}
   195  				runTest(delayCb)
   196  			})
   197  		})
   198  
   199  		runTest := func(dropCb quicproxy.DropCallback) {
   200  			proxyPort, closeFn := startServerAndProxy(nil, dropCb)
   201  			defer closeFn()
   202  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   203  			Expect(err).ToNot(HaveOccurred())
   204  			tr := &quic.Transport{
   205  				Conn:               clientUDPConn,
   206  				ConnectionIDLength: connIDLen,
   207  			}
   208  			conn, err := tr.Dial(
   209  				context.Background(),
   210  				raddr,
   211  				getTLSClientConfig(),
   212  				getQuicConfig(nil),
   213  			)
   214  			Expect(err).ToNot(HaveOccurred())
   215  			str, err := conn.AcceptUniStream(context.Background())
   216  			Expect(err).ToNot(HaveOccurred())
   217  			data, err := io.ReadAll(str)
   218  			Expect(err).ToNot(HaveOccurred())
   219  			Expect(data).To(Equal(PRData))
   220  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   221  		}
   222  
   223  		Context("duplicating packets", func() {
   224  			It("downloads a message when packets are duplicated towards the server", func() {
   225  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   226  					defer GinkgoRecover()
   227  					if dir == quicproxy.DirectionIncoming {
   228  						_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
   229  						Expect(err).ToNot(HaveOccurred())
   230  					}
   231  					return false
   232  				}
   233  				runTest(dropCb)
   234  			})
   235  
   236  			It("downloads a message when packets are duplicated towards the client", func() {
   237  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   238  					defer GinkgoRecover()
   239  					if dir == quicproxy.DirectionOutgoing {
   240  						_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
   241  						Expect(err).ToNot(HaveOccurred())
   242  					}
   243  					return false
   244  				}
   245  				runTest(dropCb)
   246  			})
   247  		})
   248  
   249  		Context("corrupting packets", func() {
   250  			const idleTimeout = time.Second
   251  
   252  			var numCorrupted, numPackets int32
   253  
   254  			BeforeEach(func() {
   255  				numCorrupted = 0
   256  				numPackets = 0
   257  				serverConfig.MaxIdleTimeout = idleTimeout
   258  			})
   259  
   260  			AfterEach(func() {
   261  				num := atomic.LoadInt32(&numCorrupted)
   262  				fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets))
   263  				Expect(num).To(BeNumerically(">=", 1))
   264  				// If the packet containing the CONNECTION_CLOSE is corrupted,
   265  				// we have to wait for the connection to time out.
   266  				Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed())
   267  			})
   268  
   269  			It("downloads a message when packet are corrupted towards the server", func() {
   270  				const interval = 4 // corrupt every 4th packet (stochastically)
   271  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   272  					defer GinkgoRecover()
   273  					if dir == quicproxy.DirectionIncoming {
   274  						atomic.AddInt32(&numPackets, 1)
   275  						if mrand.Intn(interval) == 0 {
   276  							pos := mrand.Intn(len(raw))
   277  							raw[pos] = byte(mrand.Intn(256))
   278  							_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
   279  							Expect(err).ToNot(HaveOccurred())
   280  							atomic.AddInt32(&numCorrupted, 1)
   281  							return true
   282  						}
   283  					}
   284  					return false
   285  				}
   286  				runTest(dropCb)
   287  			})
   288  
   289  			It("downloads a message when packet are corrupted towards the client", func() {
   290  				const interval = 10 // corrupt every 10th packet (stochastically)
   291  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   292  					defer GinkgoRecover()
   293  					if dir == quicproxy.DirectionOutgoing {
   294  						atomic.AddInt32(&numPackets, 1)
   295  						if mrand.Intn(interval) == 0 {
   296  							pos := mrand.Intn(len(raw))
   297  							raw[pos] = byte(mrand.Intn(256))
   298  							_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
   299  							Expect(err).ToNot(HaveOccurred())
   300  							atomic.AddInt32(&numCorrupted, 1)
   301  							return true
   302  						}
   303  					}
   304  					return false
   305  				}
   306  				runTest(dropCb)
   307  			})
   308  		})
   309  	})
   310  
   311  	Context("successful injection attacks", func() {
   312  		// These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake
   313  		// finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply
   314  		// that arrives before the real reply can tear down the connection in multiple ways.
   315  
   316  		const rtt = 20 * time.Millisecond
   317  
   318  		runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
   319  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
   320  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   321  			Expect(err).ToNot(HaveOccurred())
   322  			tr := &quic.Transport{
   323  				Conn:               clientUDPConn,
   324  				ConnectionIDLength: connIDLen,
   325  			}
   326  			_, err = tr.Dial(
   327  				context.Background(),
   328  				raddr,
   329  				getTLSClientConfig(),
   330  				getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
   331  			)
   332  			return func() { tr.Close(); serverCloseFn() }, err
   333  		}
   334  
   335  		// fails immediately because client connection closes when it can't find compatible version
   336  		It("fails when a forged version negotiation packet is sent to client", func() {
   337  			done := make(chan struct{})
   338  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   339  				if dir == quicproxy.DirectionIncoming {
   340  					defer GinkgoRecover()
   341  
   342  					hdr, _, _, err := wire.ParsePacket(raw)
   343  					Expect(err).ToNot(HaveOccurred())
   344  
   345  					if hdr.Type != protocol.PacketTypeInitial {
   346  						return 0
   347  					}
   348  
   349  					// Create fake version negotiation packet with no supported versions
   350  					versions := []protocol.VersionNumber{}
   351  					packet := wire.ComposeVersionNegotiation(
   352  						protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
   353  						protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
   354  						versions,
   355  					)
   356  
   357  					// Send the packet
   358  					_, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr())
   359  					Expect(err).ToNot(HaveOccurred())
   360  					close(done)
   361  				}
   362  				return rtt / 2
   363  			}
   364  			closeFn, err := runTest(delayCb)
   365  			defer closeFn()
   366  			Expect(err).To(HaveOccurred())
   367  			vnErr := &quic.VersionNegotiationError{}
   368  			Expect(errors.As(err, &vnErr)).To(BeTrue())
   369  			Eventually(done).Should(BeClosed())
   370  		})
   371  
   372  		// times out, because client doesn't accept subsequent real retry packets from server
   373  		// as it has already accepted a retry.
   374  		// TODO: determine behavior when server does not send Retry packets
   375  		It("fails when a forged Retry packet with modified srcConnID is sent to client", func() {
   376  			serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
   377  			var initialPacketIntercepted bool
   378  			done := make(chan struct{})
   379  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   380  				if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
   381  					defer GinkgoRecover()
   382  					defer close(done)
   383  
   384  					hdr, _, _, err := wire.ParsePacket(raw)
   385  					Expect(err).ToNot(HaveOccurred())
   386  
   387  					if hdr.Type != protocol.PacketTypeInitial {
   388  						return 0
   389  					}
   390  
   391  					initialPacketIntercepted = true
   392  					fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
   393  					retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
   394  
   395  					_, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr())
   396  					Expect(err).ToNot(HaveOccurred())
   397  				}
   398  				return rtt / 2
   399  			}
   400  			closeFn, err := runTest(delayCb)
   401  			defer closeFn()
   402  			Expect(err).To(HaveOccurred())
   403  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   404  			Eventually(done).Should(BeClosed())
   405  		})
   406  
   407  		// times out, because client doesn't accept real retry packets from server because
   408  		// it has already accepted an initial.
   409  		// TODO: determine behavior when server does not send Retry packets
   410  		It("fails when a forged initial packet is sent to client", func() {
   411  			done := make(chan struct{})
   412  			var injected bool
   413  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   414  				if dir == quicproxy.DirectionIncoming {
   415  					defer GinkgoRecover()
   416  
   417  					hdr, _, _, err := wire.ParsePacket(raw)
   418  					Expect(err).ToNot(HaveOccurred())
   419  					if hdr.Type != protocol.PacketTypeInitial || injected {
   420  						return 0
   421  					}
   422  					defer close(done)
   423  					injected = true
   424  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil)
   425  					_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
   426  					Expect(err).ToNot(HaveOccurred())
   427  				}
   428  				return rtt
   429  			}
   430  			closeFn, err := runTest(delayCb)
   431  			defer closeFn()
   432  			Expect(err).To(HaveOccurred())
   433  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   434  			Eventually(done).Should(BeClosed())
   435  		})
   436  
   437  		// client connection closes immediately on receiving ack for unsent packet
   438  		It("fails when a forged initial packet with ack for unsent packet is sent to client", func() {
   439  			done := make(chan struct{})
   440  			var injected bool
   441  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   442  				if dir == quicproxy.DirectionIncoming {
   443  					defer GinkgoRecover()
   444  
   445  					hdr, _, _, err := wire.ParsePacket(raw)
   446  					Expect(err).ToNot(HaveOccurred())
   447  					if hdr.Type != protocol.PacketTypeInitial || injected {
   448  						return 0
   449  					}
   450  					defer close(done)
   451  					injected = true
   452  					// Fake Initial with ACK for packet 2 (unsent)
   453  					ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
   454  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack})
   455  					_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
   456  					Expect(err).ToNot(HaveOccurred())
   457  				}
   458  				return rtt
   459  			}
   460  			closeFn, err := runTest(delayCb)
   461  			defer closeFn()
   462  			Expect(err).To(HaveOccurred())
   463  			var transportErr *quic.TransportError
   464  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   465  			Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation))
   466  			Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet"))
   467  			Eventually(done).Should(BeClosed())
   468  		})
   469  	})
   470  })