github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/fuzzing/handshake/fuzz.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"math"
    13  	mrand "math/rand"
    14  	"net"
    15  	"time"
    16  
    17  	"github.com/daeuniverse/quic-go/fuzzing/internal/helper"
    18  	"github.com/daeuniverse/quic-go/internal/handshake"
    19  	"github.com/daeuniverse/quic-go/internal/protocol"
    20  	"github.com/daeuniverse/quic-go/internal/qtls"
    21  	"github.com/daeuniverse/quic-go/internal/utils"
    22  	"github.com/daeuniverse/quic-go/internal/wire"
    23  )
    24  
    25  var (
    26  	cert, clientCert         *tls.Certificate
    27  	certPool, clientCertPool *x509.CertPool
    28  	sessionTicketKey         = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
    29  )
    30  
    31  func init() {
    32  	priv, err := rsa.GenerateKey(rand.Reader, 1024)
    33  	if err != nil {
    34  		log.Fatal(err)
    35  	}
    36  	cert, certPool, err = helper.GenerateCertificate(priv)
    37  	if err != nil {
    38  		log.Fatal(err)
    39  	}
    40  
    41  	privClient, err := rsa.GenerateKey(rand.Reader, 1024)
    42  	if err != nil {
    43  		log.Fatal(err)
    44  	}
    45  	clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
    46  	if err != nil {
    47  		log.Fatal(err)
    48  	}
    49  }
    50  
    51  type messageType uint8
    52  
    53  // TLS handshake message types.
    54  const (
    55  	typeClientHello         messageType = 1
    56  	typeServerHello         messageType = 2
    57  	typeNewSessionTicket    messageType = 4
    58  	typeEncryptedExtensions messageType = 8
    59  	typeCertificate         messageType = 11
    60  	typeCertificateRequest  messageType = 13
    61  	typeCertificateVerify   messageType = 15
    62  	typeFinished            messageType = 20
    63  )
    64  
    65  func (m messageType) String() string {
    66  	switch m {
    67  	case typeClientHello:
    68  		return "ClientHello"
    69  	case typeServerHello:
    70  		return "ServerHello"
    71  	case typeNewSessionTicket:
    72  		return "NewSessionTicket"
    73  	case typeEncryptedExtensions:
    74  		return "EncryptedExtensions"
    75  	case typeCertificate:
    76  		return "Certificate"
    77  	case typeCertificateRequest:
    78  		return "CertificateRequest"
    79  	case typeCertificateVerify:
    80  		return "CertificateVerify"
    81  	case typeFinished:
    82  		return "Finished"
    83  	default:
    84  		return fmt.Sprintf("unknown message type: %d", m)
    85  	}
    86  }
    87  
    88  // consumes 3 bits
    89  func getClientAuth(rand uint8) tls.ClientAuthType {
    90  	switch rand {
    91  	default:
    92  		return tls.NoClientCert
    93  	case 0:
    94  		return tls.RequestClientCert
    95  	case 1:
    96  		return tls.RequireAnyClientCert
    97  	case 2:
    98  		return tls.VerifyClientCertIfGiven
    99  	case 3:
   100  		return tls.RequireAndVerifyClientCert
   101  	}
   102  }
   103  
   104  const (
   105  	alpn      = "fuzzing"
   106  	alpnWrong = "wrong"
   107  )
   108  
   109  func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
   110  	switch n % 3 {
   111  	default:
   112  		return protocol.EncryptionInitial
   113  	case 1:
   114  		return protocol.EncryptionHandshake
   115  	case 2:
   116  		return protocol.Encryption1RTT
   117  	}
   118  }
   119  
   120  func getTransportParameters(seed uint8) *wire.TransportParameters {
   121  	const maxVarInt = math.MaxUint64 / 4
   122  	r := mrand.New(mrand.NewSource(int64(seed)))
   123  	return &wire.TransportParameters{
   124  		ActiveConnectionIDLimit:        2,
   125  		InitialMaxData:                 protocol.ByteCount(r.Int63n(maxVarInt)),
   126  		InitialMaxStreamDataBidiLocal:  protocol.ByteCount(r.Int63n(maxVarInt)),
   127  		InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)),
   128  		InitialMaxStreamDataUni:        protocol.ByteCount(r.Int63n(maxVarInt)),
   129  	}
   130  }
   131  
   132  // PrefixLen is the number of bytes used for configuration
   133  const (
   134  	PrefixLen = 12
   135  	confLen   = 5
   136  )
   137  
   138  // Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
   139  //
   140  //go:generate go run ./cmd/corpus.go
   141  func Fuzz(data []byte) int {
   142  	if len(data) < PrefixLen {
   143  		return -1
   144  	}
   145  	dataLen := len(data)
   146  	var runConfig1, runConfig2 [confLen]byte
   147  	copy(runConfig1[:], data)
   148  	data = data[confLen:]
   149  	messageConfig1 := data[0]
   150  	data = data[1:]
   151  	copy(runConfig2[:], data)
   152  	data = data[confLen:]
   153  	messageConfig2 := data[0]
   154  	data = data[1:]
   155  	if dataLen != len(data)+PrefixLen {
   156  		panic("incorrect configuration")
   157  	}
   158  
   159  	clientConf := &tls.Config{
   160  		MinVersion: tls.VersionTLS13,
   161  		ServerName: "localhost",
   162  		NextProtos: []string{alpn},
   163  		RootCAs:    certPool,
   164  	}
   165  	useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
   166  	if useSessionTicketCache {
   167  		clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
   168  	}
   169  
   170  	if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 {
   171  		return val
   172  	}
   173  	return runHandshake(runConfig2, messageConfig2, clientConf, data)
   174  }
   175  
   176  func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
   177  	serverConf := &tls.Config{
   178  		MinVersion:       tls.VersionTLS13,
   179  		Certificates:     []tls.Certificate{*cert},
   180  		NextProtos:       []string{alpn},
   181  		SessionTicketKey: sessionTicketKey,
   182  	}
   183  
   184  	// This sets the cipher suite for both client and server.
   185  	// The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server.
   186  	resetCipherSuite := func() {}
   187  	switch (runConfig[0] >> 6) % 4 {
   188  	case 0:
   189  		resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256)
   190  	case 1:
   191  		resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384)
   192  	case 3:
   193  		resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
   194  	default:
   195  	}
   196  	defer resetCipherSuite()
   197  
   198  	enable0RTTClient := helper.NthBit(runConfig[0], 0)
   199  	enable0RTTServer := helper.NthBit(runConfig[0], 1)
   200  	sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
   201  	sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
   202  	sendSessionTicket := helper.NthBit(runConfig[0], 5)
   203  	serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
   204  	serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
   205  	if helper.NthBit(runConfig[2], 0) {
   206  		clientConf.RootCAs = x509.NewCertPool()
   207  	}
   208  	if helper.NthBit(runConfig[2], 1) {
   209  		serverConf.ClientCAs = clientCertPool
   210  	} else {
   211  		serverConf.ClientCAs = x509.NewCertPool()
   212  	}
   213  	if helper.NthBit(runConfig[2], 2) {
   214  		serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   215  			if helper.NthBit(runConfig[2], 3) {
   216  				return nil, errors.New("getting client config failed")
   217  			}
   218  			if helper.NthBit(runConfig[2], 4) {
   219  				return nil, nil
   220  			}
   221  			return serverConf, nil
   222  		}
   223  	}
   224  	if helper.NthBit(runConfig[2], 5) {
   225  		serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
   226  			if helper.NthBit(runConfig[2], 6) {
   227  				return nil, errors.New("getting certificate failed")
   228  			}
   229  			if helper.NthBit(runConfig[2], 7) {
   230  				return nil, nil
   231  			}
   232  			return clientCert, nil // this certificate will be invalid
   233  		}
   234  	}
   235  	if helper.NthBit(runConfig[3], 0) {
   236  		serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   237  			if helper.NthBit(runConfig[3], 1) {
   238  				return errors.New("certificate verification failed")
   239  			}
   240  			return nil
   241  		}
   242  	}
   243  	if helper.NthBit(runConfig[3], 2) {
   244  		clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   245  			if helper.NthBit(runConfig[3], 3) {
   246  				return errors.New("certificate verification failed")
   247  			}
   248  			return nil
   249  		}
   250  	}
   251  	if helper.NthBit(runConfig[3], 4) {
   252  		serverConf.NextProtos = []string{alpnWrong}
   253  	}
   254  	if helper.NthBit(runConfig[3], 5) {
   255  		serverConf.NextProtos = []string{alpnWrong, alpn}
   256  	}
   257  	if helper.NthBit(runConfig[3], 6) {
   258  		serverConf.KeyLogWriter = io.Discard
   259  	}
   260  	if helper.NthBit(runConfig[3], 7) {
   261  		clientConf.KeyLogWriter = io.Discard
   262  	}
   263  	clientTP := getTransportParameters(runConfig[4] & 0x3)
   264  	if helper.NthBit(runConfig[4], 3) {
   265  		clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
   266  	}
   267  	serverTP := getTransportParameters(runConfig[4] & 0b00011000)
   268  	if helper.NthBit(runConfig[4], 3) {
   269  		serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
   270  	}
   271  
   272  	messageToReplace := messageConfig % 32
   273  	messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
   274  
   275  	if len(data) == 0 {
   276  		return -1
   277  	}
   278  
   279  	client := handshake.NewCryptoSetupClient(
   280  		protocol.ConnectionID{},
   281  		clientTP,
   282  		clientConf,
   283  		enable0RTTClient,
   284  		utils.NewRTTStats(),
   285  		nil,
   286  		utils.DefaultLogger.WithPrefix("client"),
   287  		protocol.Version1,
   288  	)
   289  	if err := client.StartHandshake(); err != nil {
   290  		log.Fatal(err)
   291  	}
   292  	defer client.Close()
   293  
   294  	server := handshake.NewCryptoSetupServer(
   295  		protocol.ConnectionID{},
   296  		&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   297  		&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   298  		serverTP,
   299  		serverConf,
   300  		enable0RTTServer,
   301  		utils.NewRTTStats(),
   302  		nil,
   303  		utils.DefaultLogger.WithPrefix("server"),
   304  		protocol.Version1,
   305  	)
   306  	if err := server.StartHandshake(); err != nil {
   307  		log.Fatal(err)
   308  	}
   309  	defer server.Close()
   310  
   311  	var clientHandshakeComplete, serverHandshakeComplete bool
   312  	for {
   313  		var processedEvent bool
   314  	clientLoop:
   315  		for {
   316  			ev := client.NextEvent()
   317  			//nolint:exhaustive // only need to process a few events
   318  			switch ev.Kind {
   319  			case handshake.EventNoEvent:
   320  				if !processedEvent && !clientHandshakeComplete { // handshake stuck
   321  					return 1
   322  				}
   323  				break clientLoop
   324  			case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
   325  				msg := ev.Data
   326  				encLevel := protocol.EncryptionInitial
   327  				if ev.Kind == handshake.EventWriteHandshakeData {
   328  					encLevel = protocol.EncryptionHandshake
   329  				}
   330  				if msg[0] == messageToReplace {
   331  					fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
   332  					msg = data
   333  					encLevel = messageToReplaceEncLevel
   334  				}
   335  				if err := server.HandleMessage(msg, encLevel); err != nil {
   336  					return 1
   337  				}
   338  			case handshake.EventHandshakeComplete:
   339  				clientHandshakeComplete = true
   340  			}
   341  			processedEvent = true
   342  		}
   343  
   344  		processedEvent = false
   345  	serverLoop:
   346  		for {
   347  			ev := server.NextEvent()
   348  			//nolint:exhaustive // only need to process a few events
   349  			switch ev.Kind {
   350  			case handshake.EventNoEvent:
   351  				if !processedEvent && !serverHandshakeComplete { // handshake stuck
   352  					return 1
   353  				}
   354  				break serverLoop
   355  			case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
   356  				encLevel := protocol.EncryptionInitial
   357  				if ev.Kind == handshake.EventWriteHandshakeData {
   358  					encLevel = protocol.EncryptionHandshake
   359  				}
   360  				msg := ev.Data
   361  				if msg[0] == messageToReplace {
   362  					fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
   363  					msg = data
   364  					encLevel = messageToReplaceEncLevel
   365  				}
   366  				if err := client.HandleMessage(msg, encLevel); err != nil {
   367  					return 1
   368  				}
   369  			case handshake.EventHandshakeComplete:
   370  				serverHandshakeComplete = true
   371  			}
   372  			processedEvent = true
   373  		}
   374  
   375  		if serverHandshakeComplete && clientHandshakeComplete {
   376  			break
   377  		}
   378  	}
   379  
   380  	_ = client.ConnectionState()
   381  	_ = server.ConnectionState()
   382  
   383  	sealer, err := client.Get1RTTSealer()
   384  	if err != nil {
   385  		panic("expected to get a 1-RTT sealer")
   386  	}
   387  	opener, err := server.Get1RTTOpener()
   388  	if err != nil {
   389  		panic("expected to get a 1-RTT opener")
   390  	}
   391  	const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
   392  	encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
   393  	decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
   394  	if err != nil {
   395  		panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
   396  	}
   397  	if string(decrypted) != msg {
   398  		panic("wrong message")
   399  	}
   400  
   401  	if sendSessionTicket && !serverConf.SessionTicketsDisabled {
   402  		ticket, err := server.GetSessionTicket()
   403  		if err != nil {
   404  			panic(err)
   405  		}
   406  		if ticket == nil {
   407  			panic("empty ticket")
   408  		}
   409  		client.HandleMessage(ticket, protocol.Encryption1RTT)
   410  	}
   411  
   412  	if sendPostHandshakeMessageToClient {
   413  		fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel)
   414  		client.HandleMessage(data, messageToReplaceEncLevel)
   415  	}
   416  	if sendPostHandshakeMessageToServer {
   417  		fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel)
   418  		server.HandleMessage(data, messageToReplaceEncLevel)
   419  	}
   420  
   421  	return 1
   422  }