github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/fuzzing/handshake/fuzz.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"math"
    13  	mrand "math/rand"
    14  	"time"
    15  
    16  	"github.com/mikelsr/quic-go/fuzzing/internal/helper"
    17  	"github.com/mikelsr/quic-go/internal/handshake"
    18  	"github.com/mikelsr/quic-go/internal/protocol"
    19  	"github.com/mikelsr/quic-go/internal/utils"
    20  	"github.com/mikelsr/quic-go/internal/wire"
    21  )
    22  
    23  var (
    24  	cert, clientCert         *tls.Certificate
    25  	certPool, clientCertPool *x509.CertPool
    26  	sessionTicketKey         = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
    27  )
    28  
    29  func init() {
    30  	priv, err := rsa.GenerateKey(rand.Reader, 1024)
    31  	if err != nil {
    32  		log.Fatal(err)
    33  	}
    34  	cert, certPool, err = helper.GenerateCertificate(priv)
    35  	if err != nil {
    36  		log.Fatal(err)
    37  	}
    38  
    39  	privClient, err := rsa.GenerateKey(rand.Reader, 1024)
    40  	if err != nil {
    41  		log.Fatal(err)
    42  	}
    43  	clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
    44  	if err != nil {
    45  		log.Fatal(err)
    46  	}
    47  }
    48  
    49  type messageType uint8
    50  
    51  // TLS handshake message types.
    52  const (
    53  	typeClientHello         messageType = 1
    54  	typeServerHello         messageType = 2
    55  	typeNewSessionTicket    messageType = 4
    56  	typeEncryptedExtensions messageType = 8
    57  	typeCertificate         messageType = 11
    58  	typeCertificateRequest  messageType = 13
    59  	typeCertificateVerify   messageType = 15
    60  	typeFinished            messageType = 20
    61  )
    62  
    63  func (m messageType) String() string {
    64  	switch m {
    65  	case typeClientHello:
    66  		return "ClientHello"
    67  	case typeServerHello:
    68  		return "ServerHello"
    69  	case typeNewSessionTicket:
    70  		return "NewSessionTicket"
    71  	case typeEncryptedExtensions:
    72  		return "EncryptedExtensions"
    73  	case typeCertificate:
    74  		return "Certificate"
    75  	case typeCertificateRequest:
    76  		return "CertificateRequest"
    77  	case typeCertificateVerify:
    78  		return "CertificateVerify"
    79  	case typeFinished:
    80  		return "Finished"
    81  	default:
    82  		return fmt.Sprintf("unknown message type: %d", m)
    83  	}
    84  }
    85  
    86  func appendSuites(suites []uint16, rand uint8) []uint16 {
    87  	const (
    88  		s1 = tls.TLS_AES_128_GCM_SHA256
    89  		s2 = tls.TLS_AES_256_GCM_SHA384
    90  		s3 = tls.TLS_CHACHA20_POLY1305_SHA256
    91  	)
    92  	switch rand % 4 {
    93  	default:
    94  		return suites
    95  	case 1:
    96  		return append(suites, s1)
    97  	case 2:
    98  		return append(suites, s2)
    99  	case 3:
   100  		return append(suites, s3)
   101  	}
   102  }
   103  
   104  // consumes 2 bits
   105  func getSuites(rand uint8) []uint16 {
   106  	suites := make([]uint16, 0, 3)
   107  	for i := 1; i <= 3; i++ {
   108  		suites = appendSuites(suites, rand>>i%4)
   109  	}
   110  	return suites
   111  }
   112  
   113  // consumes 3 bits
   114  func getClientAuth(rand uint8) tls.ClientAuthType {
   115  	switch rand {
   116  	default:
   117  		return tls.NoClientCert
   118  	case 0:
   119  		return tls.RequestClientCert
   120  	case 1:
   121  		return tls.RequireAnyClientCert
   122  	case 2:
   123  		return tls.VerifyClientCertIfGiven
   124  	case 3:
   125  		return tls.RequireAndVerifyClientCert
   126  	}
   127  }
   128  
   129  type chunk struct {
   130  	data     []byte
   131  	encLevel protocol.EncryptionLevel
   132  }
   133  
   134  type stream struct {
   135  	chunkChan chan<- chunk
   136  	encLevel  protocol.EncryptionLevel
   137  }
   138  
   139  func (s *stream) Write(b []byte) (int, error) {
   140  	data := append([]byte{}, b...)
   141  	select {
   142  	case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
   143  	default:
   144  		panic("chunkChan too small")
   145  	}
   146  	return len(b), nil
   147  }
   148  
   149  func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
   150  	chunkChan := make(chan chunk, 10)
   151  	initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
   152  	handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
   153  	return chunkChan, initialStream, handshakeStream
   154  }
   155  
   156  type handshakeRunner interface {
   157  	OnReceivedParams(*wire.TransportParameters)
   158  	OnHandshakeComplete()
   159  	OnReceivedReadKeys()
   160  	DropKeys(protocol.EncryptionLevel)
   161  }
   162  
   163  type runner struct {
   164  	handshakeComplete chan<- struct{}
   165  }
   166  
   167  var _ handshakeRunner = &runner{}
   168  
   169  func newRunner(handshakeComplete chan<- struct{}) *runner {
   170  	return &runner{handshakeComplete: handshakeComplete}
   171  }
   172  
   173  func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
   174  func (r *runner) OnReceivedReadKeys()                        {}
   175  func (r *runner) OnHandshakeComplete() {
   176  	close(r.handshakeComplete)
   177  }
   178  func (r *runner) DropKeys(protocol.EncryptionLevel) {}
   179  
   180  const (
   181  	alpn      = "fuzzing"
   182  	alpnWrong = "wrong"
   183  )
   184  
   185  func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
   186  	switch n % 3 {
   187  	default:
   188  		return protocol.EncryptionInitial
   189  	case 1:
   190  		return protocol.EncryptionHandshake
   191  	case 2:
   192  		return protocol.Encryption1RTT
   193  	}
   194  }
   195  
   196  func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) protocol.EncryptionLevel {
   197  	//nolint:exhaustive
   198  	switch encLevel {
   199  	case protocol.EncryptionInitial:
   200  		return protocol.EncryptionInitial
   201  	case protocol.EncryptionHandshake:
   202  		// Handshake opener not available. We can't possibly read a Handshake handshake message.
   203  		if opener, err := cs.GetHandshakeOpener(); err != nil || opener == nil {
   204  			return protocol.EncryptionInitial
   205  		}
   206  		return protocol.EncryptionHandshake
   207  	case protocol.Encryption1RTT:
   208  		// 1-RTT opener not available. We can't possibly read a post-handshake message.
   209  		if opener, err := cs.Get1RTTOpener(); err != nil || opener == nil {
   210  			return maxEncLevel(cs, protocol.EncryptionHandshake)
   211  		}
   212  		return protocol.Encryption1RTT
   213  	default:
   214  		panic("unexpected encryption level")
   215  	}
   216  }
   217  
   218  func getTransportParameters(seed uint8) *wire.TransportParameters {
   219  	const maxVarInt = math.MaxUint64 / 4
   220  	r := mrand.New(mrand.NewSource(int64(seed)))
   221  	return &wire.TransportParameters{
   222  		InitialMaxData:                 protocol.ByteCount(r.Int63n(maxVarInt)),
   223  		InitialMaxStreamDataBidiLocal:  protocol.ByteCount(r.Int63n(maxVarInt)),
   224  		InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)),
   225  		InitialMaxStreamDataUni:        protocol.ByteCount(r.Int63n(maxVarInt)),
   226  	}
   227  }
   228  
   229  // PrefixLen is the number of bytes used for configuration
   230  const (
   231  	PrefixLen = 12
   232  	confLen   = 5
   233  )
   234  
   235  // Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
   236  //
   237  //go:generate go run ./cmd/corpus.go
   238  func Fuzz(data []byte) int {
   239  	if len(data) < PrefixLen {
   240  		return -1
   241  	}
   242  	dataLen := len(data)
   243  	var runConfig1, runConfig2 [confLen]byte
   244  	copy(runConfig1[:], data)
   245  	data = data[confLen:]
   246  	messageConfig1 := data[0]
   247  	data = data[1:]
   248  	copy(runConfig2[:], data)
   249  	data = data[confLen:]
   250  	messageConfig2 := data[0]
   251  	data = data[1:]
   252  	if dataLen != len(data)+PrefixLen {
   253  		panic("incorrect configuration")
   254  	}
   255  
   256  	clientConf := &tls.Config{
   257  		MinVersion: tls.VersionTLS13,
   258  		ServerName: "localhost",
   259  		NextProtos: []string{alpn},
   260  		RootCAs:    certPool,
   261  	}
   262  	useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
   263  	if useSessionTicketCache {
   264  		clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
   265  	}
   266  
   267  	if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 {
   268  		return val
   269  	}
   270  	return runHandshake(runConfig2, messageConfig2, clientConf, data)
   271  }
   272  
   273  func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
   274  	serverConf := &tls.Config{
   275  		MinVersion:       tls.VersionTLS13,
   276  		Certificates:     []tls.Certificate{*cert},
   277  		NextProtos:       []string{alpn},
   278  		SessionTicketKey: sessionTicketKey,
   279  	}
   280  
   281  	enable0RTTClient := helper.NthBit(runConfig[0], 0)
   282  	enable0RTTServer := helper.NthBit(runConfig[0], 1)
   283  	sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
   284  	sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
   285  	sendSessionTicket := helper.NthBit(runConfig[0], 5)
   286  	clientConf.CipherSuites = getSuites(runConfig[0] >> 6)
   287  	serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
   288  	serverConf.CipherSuites = getSuites(runConfig[1] >> 6)
   289  	serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
   290  	if helper.NthBit(runConfig[2], 0) {
   291  		clientConf.RootCAs = x509.NewCertPool()
   292  	}
   293  	if helper.NthBit(runConfig[2], 1) {
   294  		serverConf.ClientCAs = clientCertPool
   295  	} else {
   296  		serverConf.ClientCAs = x509.NewCertPool()
   297  	}
   298  	if helper.NthBit(runConfig[2], 2) {
   299  		serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   300  			if helper.NthBit(runConfig[2], 3) {
   301  				return nil, errors.New("getting client config failed")
   302  			}
   303  			if helper.NthBit(runConfig[2], 4) {
   304  				return nil, nil
   305  			}
   306  			return serverConf, nil
   307  		}
   308  	}
   309  	if helper.NthBit(runConfig[2], 5) {
   310  		serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
   311  			if helper.NthBit(runConfig[2], 6) {
   312  				return nil, errors.New("getting certificate failed")
   313  			}
   314  			if helper.NthBit(runConfig[2], 7) {
   315  				return nil, nil
   316  			}
   317  			return clientCert, nil // this certificate will be invalid
   318  		}
   319  	}
   320  	if helper.NthBit(runConfig[3], 0) {
   321  		serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   322  			if helper.NthBit(runConfig[3], 1) {
   323  				return errors.New("certificate verification failed")
   324  			}
   325  			return nil
   326  		}
   327  	}
   328  	if helper.NthBit(runConfig[3], 2) {
   329  		clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   330  			if helper.NthBit(runConfig[3], 3) {
   331  				return errors.New("certificate verification failed")
   332  			}
   333  			return nil
   334  		}
   335  	}
   336  	if helper.NthBit(runConfig[3], 4) {
   337  		serverConf.NextProtos = []string{alpnWrong}
   338  	}
   339  	if helper.NthBit(runConfig[3], 5) {
   340  		serverConf.NextProtos = []string{alpnWrong, alpn}
   341  	}
   342  	if helper.NthBit(runConfig[3], 6) {
   343  		serverConf.KeyLogWriter = io.Discard
   344  	}
   345  	if helper.NthBit(runConfig[3], 7) {
   346  		clientConf.KeyLogWriter = io.Discard
   347  	}
   348  	clientTP := getTransportParameters(runConfig[4] & 0x3)
   349  	if helper.NthBit(runConfig[4], 3) {
   350  		clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
   351  	}
   352  	serverTP := getTransportParameters(runConfig[4] & 0b00011000)
   353  	if helper.NthBit(runConfig[4], 3) {
   354  		serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
   355  	}
   356  
   357  	messageToReplace := messageConfig % 32
   358  	messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
   359  
   360  	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
   361  	var client, server handshake.CryptoSetup
   362  	clientHandshakeCompleted := make(chan struct{})
   363  	client, _ = handshake.NewCryptoSetupClient(
   364  		cInitialStream,
   365  		cHandshakeStream,
   366  		nil,
   367  		protocol.ConnectionID{},
   368  		clientTP,
   369  		newRunner(clientHandshakeCompleted),
   370  		clientConf,
   371  		enable0RTTClient,
   372  		utils.NewRTTStats(),
   373  		nil,
   374  		utils.DefaultLogger.WithPrefix("client"),
   375  		protocol.Version1,
   376  	)
   377  
   378  	serverHandshakeCompleted := make(chan struct{})
   379  	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
   380  	server = handshake.NewCryptoSetupServer(
   381  		sInitialStream,
   382  		sHandshakeStream,
   383  		nil,
   384  		protocol.ConnectionID{},
   385  		serverTP,
   386  		newRunner(serverHandshakeCompleted),
   387  		serverConf,
   388  		enable0RTTServer,
   389  		utils.NewRTTStats(),
   390  		nil,
   391  		utils.DefaultLogger.WithPrefix("server"),
   392  		protocol.Version1,
   393  	)
   394  
   395  	if len(data) == 0 {
   396  		return -1
   397  	}
   398  
   399  	if err := client.StartHandshake(); err != nil {
   400  		log.Fatal(err)
   401  	}
   402  
   403  	if err := server.StartHandshake(); err != nil {
   404  		log.Fatal(err)
   405  	}
   406  
   407  	done := make(chan struct{})
   408  	go func() {
   409  		<-serverHandshakeCompleted
   410  		<-clientHandshakeCompleted
   411  		close(done)
   412  	}()
   413  
   414  messageLoop:
   415  	for {
   416  		select {
   417  		case c := <-cChunkChan:
   418  			b := c.data
   419  			encLevel := c.encLevel
   420  			if len(b) > 0 && b[0] == messageToReplace {
   421  				fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0]))
   422  				b = data
   423  				encLevel = maxEncLevel(server, messageToReplaceEncLevel)
   424  			}
   425  			if err := server.HandleMessage(b, encLevel); err != nil {
   426  				break messageLoop
   427  			}
   428  		case c := <-sChunkChan:
   429  			b := c.data
   430  			encLevel := c.encLevel
   431  			if len(b) > 0 && b[0] == messageToReplace {
   432  				fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0]))
   433  				b = data
   434  				encLevel = maxEncLevel(client, messageToReplaceEncLevel)
   435  			}
   436  			if err := client.HandleMessage(b, encLevel); err != nil {
   437  				break messageLoop
   438  			}
   439  		case <-done: // test done
   440  			break messageLoop
   441  		}
   442  	}
   443  
   444  	<-done
   445  	_ = client.ConnectionState()
   446  	_ = server.ConnectionState()
   447  
   448  	sealer, err := client.Get1RTTSealer()
   449  	if err != nil {
   450  		panic("expected to get a 1-RTT sealer")
   451  	}
   452  	opener, err := server.Get1RTTOpener()
   453  	if err != nil {
   454  		panic("expected to get a 1-RTT opener")
   455  	}
   456  	const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
   457  	encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
   458  	decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
   459  	if err != nil {
   460  		panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
   461  	}
   462  	if string(decrypted) != msg {
   463  		panic("wrong message")
   464  	}
   465  
   466  	if sendSessionTicket && !serverConf.SessionTicketsDisabled {
   467  		ticket, err := server.GetSessionTicket()
   468  		if err != nil {
   469  			panic(err)
   470  		}
   471  		if ticket == nil {
   472  			panic("empty ticket")
   473  		}
   474  		client.HandleMessage(ticket, protocol.Encryption1RTT)
   475  	}
   476  	if sendPostHandshakeMessageToClient {
   477  		client.HandleMessage(data, messageToReplaceEncLevel)
   478  	}
   479  	if sendPostHandshakeMessageToServer {
   480  		server.HandleMessage(data, messageToReplaceEncLevel)
   481  	}
   482  
   483  	return 1
   484  }