github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/integrationtests/self/mitm_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"net"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"golang.org/x/exp/rand"
    14  
    15  	"github.com/danielpfeifer02/quic-go-prio-packs"
    16  	quicproxy "github.com/danielpfeifer02/quic-go-prio-packs/integrationtests/tools/proxy"
    17  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    18  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    19  	"github.com/danielpfeifer02/quic-go-prio-packs/testutils"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  )
    24  
    25  var _ = Describe("MITM test", func() {
    26  	const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
    27  
    28  	var (
    29  		clientUDPConn                    net.PacketConn
    30  		serverTransport, clientTransport *quic.Transport
    31  		serverConn                       quic.Connection
    32  		serverConfig                     *quic.Config
    33  	)
    34  
    35  	startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) {
    36  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    37  		Expect(err).ToNot(HaveOccurred())
    38  		c, err := net.ListenUDP("udp", addr)
    39  		Expect(err).ToNot(HaveOccurred())
    40  		serverTransport = &quic.Transport{
    41  			Conn:               c,
    42  			ConnectionIDLength: connIDLen,
    43  		}
    44  		if forceAddressValidation {
    45  			serverTransport.MaxUnvalidatedHandshakes = -1
    46  		}
    47  		ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
    48  		Expect(err).ToNot(HaveOccurred())
    49  		done := make(chan struct{})
    50  		go func() {
    51  			defer GinkgoRecover()
    52  			defer close(done)
    53  			var err error
    54  			serverConn, err = ln.Accept(context.Background())
    55  			if err != nil {
    56  				return
    57  			}
    58  			str, err := serverConn.OpenUniStream()
    59  			Expect(err).ToNot(HaveOccurred())
    60  			_, err = str.Write(PRData)
    61  			Expect(err).ToNot(HaveOccurred())
    62  			Expect(str.Close()).To(Succeed())
    63  		}()
    64  		serverPort := ln.Addr().(*net.UDPAddr).Port
    65  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    66  			RemoteAddr:  fmt.Sprintf("localhost:%d", serverPort),
    67  			DelayPacket: delayCb,
    68  			DropPacket:  dropCb,
    69  		})
    70  		Expect(err).ToNot(HaveOccurred())
    71  		return proxy.LocalPort(), func() {
    72  			proxy.Close()
    73  			ln.Close()
    74  			serverTransport.Close()
    75  			<-done
    76  		}
    77  	}
    78  
    79  	BeforeEach(func() {
    80  		serverConfig = getQuicConfig(nil)
    81  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    82  		Expect(err).ToNot(HaveOccurred())
    83  		clientUDPConn, err = net.ListenUDP("udp", addr)
    84  		Expect(err).ToNot(HaveOccurred())
    85  		clientTransport = &quic.Transport{
    86  			Conn:               clientUDPConn,
    87  			ConnectionIDLength: connIDLen,
    88  		}
    89  	})
    90  
    91  	Context("unsuccessful attacks", func() {
    92  		AfterEach(func() {
    93  			Eventually(serverConn.Context().Done()).Should(BeClosed())
    94  			// Test shutdown is tricky due to the proxy. Just wait for a bit.
    95  			time.Sleep(50 * time.Millisecond)
    96  			Expect(clientUDPConn.Close()).To(Succeed())
    97  			Expect(clientTransport.Close()).To(Succeed())
    98  		})
    99  
   100  		Context("injecting invalid packets", func() {
   101  			const rtt = 20 * time.Millisecond
   102  
   103  			sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
   104  				defer GinkgoRecover()
   105  				const numPackets = 10
   106  				ticker := time.NewTicker(rtt / numPackets)
   107  				defer ticker.Stop()
   108  
   109  				if wire.IsLongHeaderPacket(raw[0]) {
   110  					hdr, _, _, err := wire.ParsePacket(raw)
   111  					Expect(err).ToNot(HaveOccurred())
   112  					replyHdr := &wire.ExtendedHeader{
   113  						Header: wire.Header{
   114  							DestConnectionID: hdr.DestConnectionID,
   115  							SrcConnectionID:  hdr.SrcConnectionID,
   116  							Type:             hdr.Type,
   117  							Version:          hdr.Version,
   118  						},
   119  						PacketNumber:    protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)),
   120  						PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1),
   121  					}
   122  
   123  					for i := 0; i < numPackets; i++ {
   124  						payloadLen := rand.Int31n(100)
   125  						replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1))
   126  						b, err := replyHdr.Append(nil, hdr.Version)
   127  						Expect(err).ToNot(HaveOccurred())
   128  						r := make([]byte, payloadLen)
   129  						rand.Read(r)
   130  						b = append(b, r...)
   131  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   132  							return
   133  						}
   134  						<-ticker.C
   135  					}
   136  				} else {
   137  					connID, err := wire.ParseConnectionID(raw, connIDLen)
   138  					Expect(err).ToNot(HaveOccurred())
   139  					_, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen)
   140  					if err != nil { // normally, ParseShortHeader is called after decrypting the header
   141  						Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
   142  					}
   143  					for i := 0; i < numPackets; i++ {
   144  						b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2)))
   145  						Expect(err).ToNot(HaveOccurred())
   146  						payloadLen := rand.Int31n(100)
   147  						r := make([]byte, payloadLen)
   148  						rand.Read(r)
   149  						b = append(b, r...)
   150  						if _, err := conn.WriteTo(b, remoteAddr); err != nil {
   151  							return
   152  						}
   153  						<-ticker.C
   154  					}
   155  				}
   156  			}
   157  
   158  			runTest := func(delayCb quicproxy.DelayCallback) {
   159  				proxyPort, closeFn := startServerAndProxy(delayCb, nil, false)
   160  				defer closeFn()
   161  				raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   162  				Expect(err).ToNot(HaveOccurred())
   163  				conn, err := clientTransport.Dial(
   164  					context.Background(),
   165  					raddr,
   166  					getTLSClientConfig(),
   167  					getQuicConfig(nil),
   168  				)
   169  				Expect(err).ToNot(HaveOccurred())
   170  				str, err := conn.AcceptUniStream(context.Background())
   171  				Expect(err).ToNot(HaveOccurred())
   172  				data, err := io.ReadAll(str)
   173  				Expect(err).ToNot(HaveOccurred())
   174  				Expect(data).To(Equal(PRData))
   175  				Expect(conn.CloseWithError(0, "")).To(Succeed())
   176  			}
   177  
   178  			It("downloads a message when the packets are injected towards the server", func() {
   179  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   180  					if dir == quicproxy.DirectionIncoming {
   181  						defer GinkgoRecover()
   182  						go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
   183  					}
   184  					return rtt / 2
   185  				}
   186  				runTest(delayCb)
   187  			})
   188  
   189  			It("downloads a message when the packets are injected towards the client", func() {
   190  				delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   191  					if dir == quicproxy.DirectionOutgoing {
   192  						defer GinkgoRecover()
   193  						go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
   194  					}
   195  					return rtt / 2
   196  				}
   197  				runTest(delayCb)
   198  			})
   199  		})
   200  
   201  		runTest := func(dropCb quicproxy.DropCallback) {
   202  			proxyPort, closeFn := startServerAndProxy(nil, dropCb, false)
   203  			defer closeFn()
   204  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   205  			Expect(err).ToNot(HaveOccurred())
   206  			conn, err := clientTransport.Dial(
   207  				context.Background(),
   208  				raddr,
   209  				getTLSClientConfig(),
   210  				getQuicConfig(nil),
   211  			)
   212  			Expect(err).ToNot(HaveOccurred())
   213  			str, err := conn.AcceptUniStream(context.Background())
   214  			Expect(err).ToNot(HaveOccurred())
   215  			data, err := io.ReadAll(str)
   216  			Expect(err).ToNot(HaveOccurred())
   217  			Expect(data).To(Equal(PRData))
   218  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   219  		}
   220  
   221  		Context("duplicating packets", func() {
   222  			It("downloads a message when packets are duplicated towards the server", func() {
   223  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   224  					defer GinkgoRecover()
   225  					if dir == quicproxy.DirectionIncoming {
   226  						_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
   227  						Expect(err).ToNot(HaveOccurred())
   228  					}
   229  					return false
   230  				}
   231  				runTest(dropCb)
   232  			})
   233  
   234  			It("downloads a message when packets are duplicated towards the client", func() {
   235  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   236  					defer GinkgoRecover()
   237  					if dir == quicproxy.DirectionOutgoing {
   238  						_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
   239  						Expect(err).ToNot(HaveOccurred())
   240  					}
   241  					return false
   242  				}
   243  				runTest(dropCb)
   244  			})
   245  		})
   246  
   247  		Context("corrupting packets", func() {
   248  			const idleTimeout = time.Second
   249  
   250  			var numCorrupted, numPackets atomic.Int32
   251  
   252  			BeforeEach(func() {
   253  				numCorrupted.Store(0)
   254  				numPackets.Store(0)
   255  				serverConfig.MaxIdleTimeout = idleTimeout
   256  			})
   257  
   258  			AfterEach(func() {
   259  				num := numCorrupted.Load()
   260  				fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, numPackets.Load())
   261  				Expect(num).To(BeNumerically(">=", 1))
   262  				// If the packet containing the CONNECTION_CLOSE is corrupted,
   263  				// we have to wait for the connection to time out.
   264  				Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed())
   265  			})
   266  
   267  			It("downloads a message when packet are corrupted towards the server", func() {
   268  				const interval = 4 // corrupt every 4th packet (stochastically)
   269  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   270  					defer GinkgoRecover()
   271  					if dir == quicproxy.DirectionIncoming {
   272  						numPackets.Add(1)
   273  						if rand.Intn(interval) == 0 {
   274  							pos := rand.Intn(len(raw))
   275  							raw[pos] = byte(rand.Intn(256))
   276  							_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
   277  							Expect(err).ToNot(HaveOccurred())
   278  							numCorrupted.Add(1)
   279  							return true
   280  						}
   281  					}
   282  					return false
   283  				}
   284  				runTest(dropCb)
   285  			})
   286  
   287  			It("downloads a message when packet are corrupted towards the client", func() {
   288  				const interval = 10 // corrupt every 10th packet (stochastically)
   289  				dropCb := func(dir quicproxy.Direction, raw []byte) bool {
   290  					defer GinkgoRecover()
   291  					if dir == quicproxy.DirectionOutgoing {
   292  						numPackets.Add(1)
   293  						if rand.Intn(interval) == 0 {
   294  							pos := rand.Intn(len(raw))
   295  							raw[pos] = byte(rand.Intn(256))
   296  							_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
   297  							Expect(err).ToNot(HaveOccurred())
   298  							numCorrupted.Add(1)
   299  							return true
   300  						}
   301  					}
   302  					return false
   303  				}
   304  				runTest(dropCb)
   305  			})
   306  		})
   307  	})
   308  
   309  	Context("successful injection attacks", func() {
   310  		// These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake
   311  		// finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply
   312  		// that arrives before the real reply can tear down the connection in multiple ways.
   313  
   314  		const rtt = 20 * time.Millisecond
   315  
   316  		runTest := func(proxyPort int) (closeFn func(), err error) {
   317  			raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
   318  			Expect(err).ToNot(HaveOccurred())
   319  			_, err = clientTransport.Dial(
   320  				context.Background(),
   321  				raddr,
   322  				getTLSClientConfig(),
   323  				getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}),
   324  			)
   325  			return func() { clientTransport.Close() }, err
   326  		}
   327  
   328  		// fails immediately because client connection closes when it can't find compatible version
   329  		It("fails when a forged version negotiation packet is sent to client", func() {
   330  			done := make(chan struct{})
   331  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   332  				if dir == quicproxy.DirectionIncoming {
   333  					defer GinkgoRecover()
   334  
   335  					hdr, _, _, err := wire.ParsePacket(raw)
   336  					Expect(err).ToNot(HaveOccurred())
   337  
   338  					if hdr.Type != protocol.PacketTypeInitial {
   339  						return 0
   340  					}
   341  
   342  					// Create fake version negotiation packet with no supported versions
   343  					versions := []protocol.Version{}
   344  					packet := wire.ComposeVersionNegotiation(
   345  						protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
   346  						protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
   347  						versions,
   348  					)
   349  
   350  					// Send the packet
   351  					_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
   352  					Expect(err).ToNot(HaveOccurred())
   353  					close(done)
   354  				}
   355  				return rtt / 2
   356  			}
   357  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
   358  			defer serverCloseFn()
   359  			closeFn, err := runTest(proxyPort)
   360  			defer closeFn()
   361  			Expect(err).To(HaveOccurred())
   362  			vnErr := &quic.VersionNegotiationError{}
   363  			Expect(errors.As(err, &vnErr)).To(BeTrue())
   364  			Eventually(done).Should(BeClosed())
   365  		})
   366  
   367  		// times out, because client doesn't accept subsequent real retry packets from server
   368  		// as it has already accepted a retry.
   369  		// TODO: determine behavior when server does not send Retry packets
   370  		It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() {
   371  			var initialPacketIntercepted bool
   372  			done := make(chan struct{})
   373  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   374  				if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
   375  					defer GinkgoRecover()
   376  					defer close(done)
   377  
   378  					hdr, _, _, err := wire.ParsePacket(raw)
   379  					Expect(err).ToNot(HaveOccurred())
   380  
   381  					if hdr.Type != protocol.PacketTypeInitial {
   382  						return 0
   383  					}
   384  
   385  					initialPacketIntercepted = true
   386  					fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
   387  					retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
   388  
   389  					_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
   390  					Expect(err).ToNot(HaveOccurred())
   391  				}
   392  				return rtt / 2
   393  			}
   394  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true)
   395  			defer serverCloseFn()
   396  			closeFn, err := runTest(proxyPort)
   397  			defer closeFn()
   398  			Expect(err).To(HaveOccurred())
   399  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   400  			Eventually(done).Should(BeClosed())
   401  		})
   402  
   403  		// times out, because client doesn't accept real retry packets from server because
   404  		// it has already accepted an initial.
   405  		// TODO: determine behavior when server does not send Retry packets
   406  		It("fails when a forged initial packet is sent to client", func() {
   407  			done := make(chan struct{})
   408  			var injected bool
   409  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   410  				if dir == quicproxy.DirectionIncoming {
   411  					defer GinkgoRecover()
   412  
   413  					hdr, _, _, err := wire.ParsePacket(raw)
   414  					Expect(err).ToNot(HaveOccurred())
   415  					if hdr.Type != protocol.PacketTypeInitial || injected {
   416  						return 0
   417  					}
   418  					defer close(done)
   419  					injected = true
   420  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, protocol.PerspectiveServer, hdr.Version)
   421  					_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
   422  					Expect(err).ToNot(HaveOccurred())
   423  				}
   424  				return rtt
   425  			}
   426  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
   427  			defer serverCloseFn()
   428  			closeFn, err := runTest(proxyPort)
   429  			defer closeFn()
   430  			Expect(err).To(HaveOccurred())
   431  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   432  			Eventually(done).Should(BeClosed())
   433  		})
   434  
   435  		// client connection closes immediately on receiving ack for unsent packet
   436  		It("fails when a forged initial packet with ack for unsent packet is sent to client", func() {
   437  			done := make(chan struct{})
   438  			var injected bool
   439  			delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
   440  				if dir == quicproxy.DirectionIncoming {
   441  					defer GinkgoRecover()
   442  
   443  					hdr, _, _, err := wire.ParsePacket(raw)
   444  					Expect(err).ToNot(HaveOccurred())
   445  					if hdr.Type != protocol.PacketTypeInitial || injected {
   446  						return 0
   447  					}
   448  					defer close(done)
   449  					injected = true
   450  					// Fake Initial with ACK for packet 2 (unsent)
   451  					ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
   452  					initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version)
   453  					_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
   454  					Expect(err).ToNot(HaveOccurred())
   455  				}
   456  				return rtt
   457  			}
   458  			proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
   459  			defer serverCloseFn()
   460  			closeFn, err := runTest(proxyPort)
   461  			defer closeFn()
   462  			Expect(err).To(HaveOccurred())
   463  			var transportErr *quic.TransportError
   464  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   465  			Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation))
   466  			Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet"))
   467  			Eventually(done).Should(BeClosed())
   468  		})
   469  	})
   470  })