github.com/tumi8/quic-go@v0.37.4-tum/fuzzing/handshake/fuzz.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"math"
    13  	mrand "math/rand"
    14  	"net"
    15  	"time"
    16  
    17  	"github.com/tumi8/quic-go/fuzzing/internal/helper"
    18  	"github.com/tumi8/quic-go/noninternal/handshake"
    19  	"github.com/tumi8/quic-go/noninternal/protocol"
    20  	"github.com/tumi8/quic-go/noninternal/utils"
    21  	"github.com/tumi8/quic-go/noninternal/wire"
    22  )
    23  
    24  var (
    25  	cert, clientCert         *tls.Certificate
    26  	certPool, clientCertPool *x509.CertPool
    27  	sessionTicketKey         = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
    28  )
    29  
    30  func init() {
    31  	priv, err := rsa.GenerateKey(rand.Reader, 1024)
    32  	if err != nil {
    33  		log.Fatal(err)
    34  	}
    35  	cert, certPool, err = helper.GenerateCertificate(priv)
    36  	if err != nil {
    37  		log.Fatal(err)
    38  	}
    39  
    40  	privClient, err := rsa.GenerateKey(rand.Reader, 1024)
    41  	if err != nil {
    42  		log.Fatal(err)
    43  	}
    44  	clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
    45  	if err != nil {
    46  		log.Fatal(err)
    47  	}
    48  }
    49  
    50  type messageType uint8
    51  
    52  // TLS handshake message types.
    53  const (
    54  	typeClientHello         messageType = 1
    55  	typeServerHello         messageType = 2
    56  	typeNewSessionTicket    messageType = 4
    57  	typeEncryptedExtensions messageType = 8
    58  	typeCertificate         messageType = 11
    59  	typeCertificateRequest  messageType = 13
    60  	typeCertificateVerify   messageType = 15
    61  	typeFinished            messageType = 20
    62  )
    63  
    64  func (m messageType) String() string {
    65  	switch m {
    66  	case typeClientHello:
    67  		return "ClientHello"
    68  	case typeServerHello:
    69  		return "ServerHello"
    70  	case typeNewSessionTicket:
    71  		return "NewSessionTicket"
    72  	case typeEncryptedExtensions:
    73  		return "EncryptedExtensions"
    74  	case typeCertificate:
    75  		return "Certificate"
    76  	case typeCertificateRequest:
    77  		return "CertificateRequest"
    78  	case typeCertificateVerify:
    79  		return "CertificateVerify"
    80  	case typeFinished:
    81  		return "Finished"
    82  	default:
    83  		return fmt.Sprintf("unknown message type: %d", m)
    84  	}
    85  }
    86  
    87  func appendSuites(suites []uint16, rand uint8) []uint16 {
    88  	const (
    89  		s1 = tls.TLS_AES_128_GCM_SHA256
    90  		s2 = tls.TLS_AES_256_GCM_SHA384
    91  		s3 = tls.TLS_CHACHA20_POLY1305_SHA256
    92  	)
    93  	switch rand % 4 {
    94  	default:
    95  		return suites
    96  	case 1:
    97  		return append(suites, s1)
    98  	case 2:
    99  		return append(suites, s2)
   100  	case 3:
   101  		return append(suites, s3)
   102  	}
   103  }
   104  
   105  // consumes 2 bits
   106  func getSuites(rand uint8) []uint16 {
   107  	suites := make([]uint16, 0, 3)
   108  	for i := 1; i <= 3; i++ {
   109  		suites = appendSuites(suites, rand>>i%4)
   110  	}
   111  	return suites
   112  }
   113  
   114  // consumes 3 bits
   115  func getClientAuth(rand uint8) tls.ClientAuthType {
   116  	switch rand {
   117  	default:
   118  		return tls.NoClientCert
   119  	case 0:
   120  		return tls.RequestClientCert
   121  	case 1:
   122  		return tls.RequireAnyClientCert
   123  	case 2:
   124  		return tls.VerifyClientCertIfGiven
   125  	case 3:
   126  		return tls.RequireAndVerifyClientCert
   127  	}
   128  }
   129  
   130  const (
   131  	alpn      = "fuzzing"
   132  	alpnWrong = "wrong"
   133  )
   134  
   135  func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
   136  	switch n % 3 {
   137  	default:
   138  		return protocol.EncryptionInitial
   139  	case 1:
   140  		return protocol.EncryptionHandshake
   141  	case 2:
   142  		return protocol.Encryption1RTT
   143  	}
   144  }
   145  
   146  func getTransportParameters(seed uint8) *wire.TransportParameters {
   147  	const maxVarInt = math.MaxUint64 / 4
   148  	r := mrand.New(mrand.NewSource(int64(seed)))
   149  	return &wire.TransportParameters{
   150  		InitialMaxData:                 protocol.ByteCount(r.Int63n(maxVarInt)),
   151  		InitialMaxStreamDataBidiLocal:  protocol.ByteCount(r.Int63n(maxVarInt)),
   152  		InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)),
   153  		InitialMaxStreamDataUni:        protocol.ByteCount(r.Int63n(maxVarInt)),
   154  	}
   155  }
   156  
   157  // PrefixLen is the number of bytes used for configuration
   158  const (
   159  	PrefixLen = 12
   160  	confLen   = 5
   161  )
   162  
   163  // Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
   164  //
   165  //go:generate go run ./cmd/corpus.go
   166  func Fuzz(data []byte) int {
   167  	if len(data) < PrefixLen {
   168  		return -1
   169  	}
   170  	dataLen := len(data)
   171  	var runConfig1, runConfig2 [confLen]byte
   172  	copy(runConfig1[:], data)
   173  	data = data[confLen:]
   174  	messageConfig1 := data[0]
   175  	data = data[1:]
   176  	copy(runConfig2[:], data)
   177  	data = data[confLen:]
   178  	messageConfig2 := data[0]
   179  	data = data[1:]
   180  	if dataLen != len(data)+PrefixLen {
   181  		panic("incorrect configuration")
   182  	}
   183  
   184  	clientConf := &tls.Config{
   185  		MinVersion: tls.VersionTLS13,
   186  		ServerName: "localhost",
   187  		NextProtos: []string{alpn},
   188  		RootCAs:    certPool,
   189  	}
   190  	useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
   191  	if useSessionTicketCache {
   192  		clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
   193  	}
   194  
   195  	if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 {
   196  		return val
   197  	}
   198  	return runHandshake(runConfig2, messageConfig2, clientConf, data)
   199  }
   200  
   201  func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
   202  	serverConf := &tls.Config{
   203  		MinVersion:       tls.VersionTLS13,
   204  		Certificates:     []tls.Certificate{*cert},
   205  		NextProtos:       []string{alpn},
   206  		SessionTicketKey: sessionTicketKey,
   207  	}
   208  
   209  	enable0RTTClient := helper.NthBit(runConfig[0], 0)
   210  	enable0RTTServer := helper.NthBit(runConfig[0], 1)
   211  	sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
   212  	sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
   213  	sendSessionTicket := helper.NthBit(runConfig[0], 5)
   214  	clientConf.CipherSuites = getSuites(runConfig[0] >> 6)
   215  	serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
   216  	serverConf.CipherSuites = getSuites(runConfig[1] >> 6)
   217  	serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
   218  	if helper.NthBit(runConfig[2], 0) {
   219  		clientConf.RootCAs = x509.NewCertPool()
   220  	}
   221  	if helper.NthBit(runConfig[2], 1) {
   222  		serverConf.ClientCAs = clientCertPool
   223  	} else {
   224  		serverConf.ClientCAs = x509.NewCertPool()
   225  	}
   226  	if helper.NthBit(runConfig[2], 2) {
   227  		serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   228  			if helper.NthBit(runConfig[2], 3) {
   229  				return nil, errors.New("getting client config failed")
   230  			}
   231  			if helper.NthBit(runConfig[2], 4) {
   232  				return nil, nil
   233  			}
   234  			return serverConf, nil
   235  		}
   236  	}
   237  	if helper.NthBit(runConfig[2], 5) {
   238  		serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
   239  			if helper.NthBit(runConfig[2], 6) {
   240  				return nil, errors.New("getting certificate failed")
   241  			}
   242  			if helper.NthBit(runConfig[2], 7) {
   243  				return nil, nil
   244  			}
   245  			return clientCert, nil // this certificate will be invalid
   246  		}
   247  	}
   248  	if helper.NthBit(runConfig[3], 0) {
   249  		serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   250  			if helper.NthBit(runConfig[3], 1) {
   251  				return errors.New("certificate verification failed")
   252  			}
   253  			return nil
   254  		}
   255  	}
   256  	if helper.NthBit(runConfig[3], 2) {
   257  		clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   258  			if helper.NthBit(runConfig[3], 3) {
   259  				return errors.New("certificate verification failed")
   260  			}
   261  			return nil
   262  		}
   263  	}
   264  	if helper.NthBit(runConfig[3], 4) {
   265  		serverConf.NextProtos = []string{alpnWrong}
   266  	}
   267  	if helper.NthBit(runConfig[3], 5) {
   268  		serverConf.NextProtos = []string{alpnWrong, alpn}
   269  	}
   270  	if helper.NthBit(runConfig[3], 6) {
   271  		serverConf.KeyLogWriter = io.Discard
   272  	}
   273  	if helper.NthBit(runConfig[3], 7) {
   274  		clientConf.KeyLogWriter = io.Discard
   275  	}
   276  	clientTP := getTransportParameters(runConfig[4] & 0x3)
   277  	if helper.NthBit(runConfig[4], 3) {
   278  		clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
   279  	}
   280  	serverTP := getTransportParameters(runConfig[4] & 0b00011000)
   281  	if helper.NthBit(runConfig[4], 3) {
   282  		serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
   283  	}
   284  
   285  	messageToReplace := messageConfig % 32
   286  	messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
   287  
   288  	if len(data) == 0 {
   289  		return -1
   290  	}
   291  
   292  	client := handshake.NewCryptoSetupClient(
   293  		protocol.ConnectionID{},
   294  		clientTP,
   295  		clientConf,
   296  		enable0RTTClient,
   297  		utils.NewRTTStats(),
   298  		nil,
   299  		utils.DefaultLogger.WithPrefix("client"),
   300  		protocol.Version1,
   301  	)
   302  	if err := client.StartHandshake(); err != nil {
   303  		log.Fatal(err)
   304  	}
   305  
   306  	server := handshake.NewCryptoSetupServer(
   307  		protocol.ConnectionID{},
   308  		&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   309  		&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   310  		serverTP,
   311  		serverConf,
   312  		enable0RTTServer,
   313  		utils.NewRTTStats(),
   314  		nil,
   315  		utils.DefaultLogger.WithPrefix("server"),
   316  		protocol.Version1,
   317  	)
   318  	if err := server.StartHandshake(); err != nil {
   319  		log.Fatal(err)
   320  	}
   321  
   322  	var clientHandshakeComplete, serverHandshakeComplete bool
   323  	for {
   324  	clientLoop:
   325  		for {
   326  			var processedEvent bool
   327  			ev := client.NextEvent()
   328  			//nolint:exhaustive // only need to process a few events
   329  			switch ev.Kind {
   330  			case handshake.EventNoEvent:
   331  				if !processedEvent && !clientHandshakeComplete { // handshake stuck
   332  					return 1
   333  				}
   334  				break clientLoop
   335  			case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
   336  				msg := ev.Data
   337  				if msg[0] == messageToReplace {
   338  					fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
   339  					msg = data
   340  				}
   341  				if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
   342  					return 1
   343  				}
   344  			case handshake.EventHandshakeComplete:
   345  				clientHandshakeComplete = true
   346  			}
   347  			processedEvent = true
   348  		}
   349  
   350  	serverLoop:
   351  		for {
   352  			var processedEvent bool
   353  			ev := server.NextEvent()
   354  			//nolint:exhaustive // only need to process a few events
   355  			switch ev.Kind {
   356  			case handshake.EventNoEvent:
   357  				if !processedEvent && !serverHandshakeComplete { // handshake stuck
   358  					return 1
   359  				}
   360  				break serverLoop
   361  			case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
   362  				msg := ev.Data
   363  				if msg[0] == messageToReplace {
   364  					fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
   365  					msg = data
   366  				}
   367  				if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
   368  					return 1
   369  				}
   370  			case handshake.EventHandshakeComplete:
   371  				serverHandshakeComplete = true
   372  			}
   373  			processedEvent = true
   374  		}
   375  
   376  		if serverHandshakeComplete && clientHandshakeComplete {
   377  			break
   378  		}
   379  	}
   380  
   381  	_ = client.ConnectionState()
   382  	_ = server.ConnectionState()
   383  
   384  	sealer, err := client.Get1RTTSealer()
   385  	if err != nil {
   386  		panic("expected to get a 1-RTT sealer")
   387  	}
   388  	opener, err := server.Get1RTTOpener()
   389  	if err != nil {
   390  		panic("expected to get a 1-RTT opener")
   391  	}
   392  	const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
   393  	encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
   394  	decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
   395  	if err != nil {
   396  		panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
   397  	}
   398  	if string(decrypted) != msg {
   399  		panic("wrong message")
   400  	}
   401  
   402  	if sendSessionTicket && !serverConf.SessionTicketsDisabled {
   403  		ticket, err := server.GetSessionTicket()
   404  		if err != nil {
   405  			panic(err)
   406  		}
   407  		if ticket == nil {
   408  			panic("empty ticket")
   409  		}
   410  		client.HandleMessage(ticket, protocol.Encryption1RTT)
   411  	}
   412  	if sendPostHandshakeMessageToClient {
   413  		client.HandleMessage(data, messageToReplaceEncLevel)
   414  	}
   415  	if sendPostHandshakeMessageToServer {
   416  		server.HandleMessage(data, messageToReplaceEncLevel)
   417  	}
   418  
   419  	return 1
   420  }