github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/fuzzing/handshake/fuzz.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "math" 13 mrand "math/rand" 14 "time" 15 16 "github.com/mikelsr/quic-go/fuzzing/internal/helper" 17 "github.com/mikelsr/quic-go/internal/handshake" 18 "github.com/mikelsr/quic-go/internal/protocol" 19 "github.com/mikelsr/quic-go/internal/utils" 20 "github.com/mikelsr/quic-go/internal/wire" 21 ) 22 23 var ( 24 cert, clientCert *tls.Certificate 25 certPool, clientCertPool *x509.CertPool 26 sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} 27 ) 28 29 func init() { 30 priv, err := rsa.GenerateKey(rand.Reader, 1024) 31 if err != nil { 32 log.Fatal(err) 33 } 34 cert, certPool, err = helper.GenerateCertificate(priv) 35 if err != nil { 36 log.Fatal(err) 37 } 38 39 privClient, err := rsa.GenerateKey(rand.Reader, 1024) 40 if err != nil { 41 log.Fatal(err) 42 } 43 clientCert, clientCertPool, err = helper.GenerateCertificate(privClient) 44 if err != nil { 45 log.Fatal(err) 46 } 47 } 48 49 type messageType uint8 50 51 // TLS handshake message types. 52 const ( 53 typeClientHello messageType = 1 54 typeServerHello messageType = 2 55 typeNewSessionTicket messageType = 4 56 typeEncryptedExtensions messageType = 8 57 typeCertificate messageType = 11 58 typeCertificateRequest messageType = 13 59 typeCertificateVerify messageType = 15 60 typeFinished messageType = 20 61 ) 62 63 func (m messageType) String() string { 64 switch m { 65 case typeClientHello: 66 return "ClientHello" 67 case typeServerHello: 68 return "ServerHello" 69 case typeNewSessionTicket: 70 return "NewSessionTicket" 71 case typeEncryptedExtensions: 72 return "EncryptedExtensions" 73 case typeCertificate: 74 return "Certificate" 75 case typeCertificateRequest: 76 return "CertificateRequest" 77 case typeCertificateVerify: 78 return "CertificateVerify" 79 case typeFinished: 80 return "Finished" 81 default: 82 return fmt.Sprintf("unknown message type: %d", m) 83 } 84 } 85 86 func appendSuites(suites []uint16, rand uint8) []uint16 { 87 const ( 88 s1 = tls.TLS_AES_128_GCM_SHA256 89 s2 = tls.TLS_AES_256_GCM_SHA384 90 s3 = tls.TLS_CHACHA20_POLY1305_SHA256 91 ) 92 switch rand % 4 { 93 default: 94 return suites 95 case 1: 96 return append(suites, s1) 97 case 2: 98 return append(suites, s2) 99 case 3: 100 return append(suites, s3) 101 } 102 } 103 104 // consumes 2 bits 105 func getSuites(rand uint8) []uint16 { 106 suites := make([]uint16, 0, 3) 107 for i := 1; i <= 3; i++ { 108 suites = appendSuites(suites, rand>>i%4) 109 } 110 return suites 111 } 112 113 // consumes 3 bits 114 func getClientAuth(rand uint8) tls.ClientAuthType { 115 switch rand { 116 default: 117 return tls.NoClientCert 118 case 0: 119 return tls.RequestClientCert 120 case 1: 121 return tls.RequireAnyClientCert 122 case 2: 123 return tls.VerifyClientCertIfGiven 124 case 3: 125 return tls.RequireAndVerifyClientCert 126 } 127 } 128 129 type chunk struct { 130 data []byte 131 encLevel protocol.EncryptionLevel 132 } 133 134 type stream struct { 135 chunkChan chan<- chunk 136 encLevel protocol.EncryptionLevel 137 } 138 139 func (s *stream) Write(b []byte) (int, error) { 140 data := append([]byte{}, b...) 141 select { 142 case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: 143 default: 144 panic("chunkChan too small") 145 } 146 return len(b), nil 147 } 148 149 func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) { 150 chunkChan := make(chan chunk, 10) 151 initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial} 152 handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake} 153 return chunkChan, initialStream, handshakeStream 154 } 155 156 type handshakeRunner interface { 157 OnReceivedParams(*wire.TransportParameters) 158 OnHandshakeComplete() 159 OnReceivedReadKeys() 160 DropKeys(protocol.EncryptionLevel) 161 } 162 163 type runner struct { 164 handshakeComplete chan<- struct{} 165 } 166 167 var _ handshakeRunner = &runner{} 168 169 func newRunner(handshakeComplete chan<- struct{}) *runner { 170 return &runner{handshakeComplete: handshakeComplete} 171 } 172 173 func (r *runner) OnReceivedParams(*wire.TransportParameters) {} 174 func (r *runner) OnReceivedReadKeys() {} 175 func (r *runner) OnHandshakeComplete() { 176 close(r.handshakeComplete) 177 } 178 func (r *runner) DropKeys(protocol.EncryptionLevel) {} 179 180 const ( 181 alpn = "fuzzing" 182 alpnWrong = "wrong" 183 ) 184 185 func toEncryptionLevel(n uint8) protocol.EncryptionLevel { 186 switch n % 3 { 187 default: 188 return protocol.EncryptionInitial 189 case 1: 190 return protocol.EncryptionHandshake 191 case 2: 192 return protocol.Encryption1RTT 193 } 194 } 195 196 func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) protocol.EncryptionLevel { 197 //nolint:exhaustive 198 switch encLevel { 199 case protocol.EncryptionInitial: 200 return protocol.EncryptionInitial 201 case protocol.EncryptionHandshake: 202 // Handshake opener not available. We can't possibly read a Handshake handshake message. 203 if opener, err := cs.GetHandshakeOpener(); err != nil || opener == nil { 204 return protocol.EncryptionInitial 205 } 206 return protocol.EncryptionHandshake 207 case protocol.Encryption1RTT: 208 // 1-RTT opener not available. We can't possibly read a post-handshake message. 209 if opener, err := cs.Get1RTTOpener(); err != nil || opener == nil { 210 return maxEncLevel(cs, protocol.EncryptionHandshake) 211 } 212 return protocol.Encryption1RTT 213 default: 214 panic("unexpected encryption level") 215 } 216 } 217 218 func getTransportParameters(seed uint8) *wire.TransportParameters { 219 const maxVarInt = math.MaxUint64 / 4 220 r := mrand.New(mrand.NewSource(int64(seed))) 221 return &wire.TransportParameters{ 222 InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)), 223 InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)), 224 InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)), 225 InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)), 226 } 227 } 228 229 // PrefixLen is the number of bytes used for configuration 230 const ( 231 PrefixLen = 12 232 confLen = 5 233 ) 234 235 // Fuzz fuzzes the TLS 1.3 handshake used by QUIC. 236 // 237 //go:generate go run ./cmd/corpus.go 238 func Fuzz(data []byte) int { 239 if len(data) < PrefixLen { 240 return -1 241 } 242 dataLen := len(data) 243 var runConfig1, runConfig2 [confLen]byte 244 copy(runConfig1[:], data) 245 data = data[confLen:] 246 messageConfig1 := data[0] 247 data = data[1:] 248 copy(runConfig2[:], data) 249 data = data[confLen:] 250 messageConfig2 := data[0] 251 data = data[1:] 252 if dataLen != len(data)+PrefixLen { 253 panic("incorrect configuration") 254 } 255 256 clientConf := &tls.Config{ 257 MinVersion: tls.VersionTLS13, 258 ServerName: "localhost", 259 NextProtos: []string{alpn}, 260 RootCAs: certPool, 261 } 262 useSessionTicketCache := helper.NthBit(runConfig1[0], 2) 263 if useSessionTicketCache { 264 clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5) 265 } 266 267 if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 { 268 return val 269 } 270 return runHandshake(runConfig2, messageConfig2, clientConf, data) 271 } 272 273 func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int { 274 serverConf := &tls.Config{ 275 MinVersion: tls.VersionTLS13, 276 Certificates: []tls.Certificate{*cert}, 277 NextProtos: []string{alpn}, 278 SessionTicketKey: sessionTicketKey, 279 } 280 281 enable0RTTClient := helper.NthBit(runConfig[0], 0) 282 enable0RTTServer := helper.NthBit(runConfig[0], 1) 283 sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) 284 sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) 285 sendSessionTicket := helper.NthBit(runConfig[0], 5) 286 clientConf.CipherSuites = getSuites(runConfig[0] >> 6) 287 serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) 288 serverConf.CipherSuites = getSuites(runConfig[1] >> 6) 289 serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) 290 if helper.NthBit(runConfig[2], 0) { 291 clientConf.RootCAs = x509.NewCertPool() 292 } 293 if helper.NthBit(runConfig[2], 1) { 294 serverConf.ClientCAs = clientCertPool 295 } else { 296 serverConf.ClientCAs = x509.NewCertPool() 297 } 298 if helper.NthBit(runConfig[2], 2) { 299 serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 300 if helper.NthBit(runConfig[2], 3) { 301 return nil, errors.New("getting client config failed") 302 } 303 if helper.NthBit(runConfig[2], 4) { 304 return nil, nil 305 } 306 return serverConf, nil 307 } 308 } 309 if helper.NthBit(runConfig[2], 5) { 310 serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { 311 if helper.NthBit(runConfig[2], 6) { 312 return nil, errors.New("getting certificate failed") 313 } 314 if helper.NthBit(runConfig[2], 7) { 315 return nil, nil 316 } 317 return clientCert, nil // this certificate will be invalid 318 } 319 } 320 if helper.NthBit(runConfig[3], 0) { 321 serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 322 if helper.NthBit(runConfig[3], 1) { 323 return errors.New("certificate verification failed") 324 } 325 return nil 326 } 327 } 328 if helper.NthBit(runConfig[3], 2) { 329 clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 330 if helper.NthBit(runConfig[3], 3) { 331 return errors.New("certificate verification failed") 332 } 333 return nil 334 } 335 } 336 if helper.NthBit(runConfig[3], 4) { 337 serverConf.NextProtos = []string{alpnWrong} 338 } 339 if helper.NthBit(runConfig[3], 5) { 340 serverConf.NextProtos = []string{alpnWrong, alpn} 341 } 342 if helper.NthBit(runConfig[3], 6) { 343 serverConf.KeyLogWriter = io.Discard 344 } 345 if helper.NthBit(runConfig[3], 7) { 346 clientConf.KeyLogWriter = io.Discard 347 } 348 clientTP := getTransportParameters(runConfig[4] & 0x3) 349 if helper.NthBit(runConfig[4], 3) { 350 clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 351 } 352 serverTP := getTransportParameters(runConfig[4] & 0b00011000) 353 if helper.NthBit(runConfig[4], 3) { 354 serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 355 } 356 357 messageToReplace := messageConfig % 32 358 messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6) 359 360 cChunkChan, cInitialStream, cHandshakeStream := initStreams() 361 var client, server handshake.CryptoSetup 362 clientHandshakeCompleted := make(chan struct{}) 363 client, _ = handshake.NewCryptoSetupClient( 364 cInitialStream, 365 cHandshakeStream, 366 nil, 367 protocol.ConnectionID{}, 368 clientTP, 369 newRunner(clientHandshakeCompleted), 370 clientConf, 371 enable0RTTClient, 372 utils.NewRTTStats(), 373 nil, 374 utils.DefaultLogger.WithPrefix("client"), 375 protocol.Version1, 376 ) 377 378 serverHandshakeCompleted := make(chan struct{}) 379 sChunkChan, sInitialStream, sHandshakeStream := initStreams() 380 server = handshake.NewCryptoSetupServer( 381 sInitialStream, 382 sHandshakeStream, 383 nil, 384 protocol.ConnectionID{}, 385 serverTP, 386 newRunner(serverHandshakeCompleted), 387 serverConf, 388 enable0RTTServer, 389 utils.NewRTTStats(), 390 nil, 391 utils.DefaultLogger.WithPrefix("server"), 392 protocol.Version1, 393 ) 394 395 if len(data) == 0 { 396 return -1 397 } 398 399 if err := client.StartHandshake(); err != nil { 400 log.Fatal(err) 401 } 402 403 if err := server.StartHandshake(); err != nil { 404 log.Fatal(err) 405 } 406 407 done := make(chan struct{}) 408 go func() { 409 <-serverHandshakeCompleted 410 <-clientHandshakeCompleted 411 close(done) 412 }() 413 414 messageLoop: 415 for { 416 select { 417 case c := <-cChunkChan: 418 b := c.data 419 encLevel := c.encLevel 420 if len(b) > 0 && b[0] == messageToReplace { 421 fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0])) 422 b = data 423 encLevel = maxEncLevel(server, messageToReplaceEncLevel) 424 } 425 if err := server.HandleMessage(b, encLevel); err != nil { 426 break messageLoop 427 } 428 case c := <-sChunkChan: 429 b := c.data 430 encLevel := c.encLevel 431 if len(b) > 0 && b[0] == messageToReplace { 432 fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0])) 433 b = data 434 encLevel = maxEncLevel(client, messageToReplaceEncLevel) 435 } 436 if err := client.HandleMessage(b, encLevel); err != nil { 437 break messageLoop 438 } 439 case <-done: // test done 440 break messageLoop 441 } 442 } 443 444 <-done 445 _ = client.ConnectionState() 446 _ = server.ConnectionState() 447 448 sealer, err := client.Get1RTTSealer() 449 if err != nil { 450 panic("expected to get a 1-RTT sealer") 451 } 452 opener, err := server.Get1RTTOpener() 453 if err != nil { 454 panic("expected to get a 1-RTT opener") 455 } 456 const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." 457 encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar")) 458 decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar")) 459 if err != nil { 460 panic(fmt.Sprintf("Decrypting message failed: %s", err.Error())) 461 } 462 if string(decrypted) != msg { 463 panic("wrong message") 464 } 465 466 if sendSessionTicket && !serverConf.SessionTicketsDisabled { 467 ticket, err := server.GetSessionTicket() 468 if err != nil { 469 panic(err) 470 } 471 if ticket == nil { 472 panic("empty ticket") 473 } 474 client.HandleMessage(ticket, protocol.Encryption1RTT) 475 } 476 if sendPostHandshakeMessageToClient { 477 client.HandleMessage(data, messageToReplaceEncLevel) 478 } 479 if sendPostHandshakeMessageToServer { 480 server.HandleMessage(data, messageToReplaceEncLevel) 481 } 482 483 return 1 484 }