github.com/tumi8/quic-go@v0.37.4-tum/fuzzing/handshake/fuzz.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "math" 13 mrand "math/rand" 14 "net" 15 "time" 16 17 "github.com/tumi8/quic-go/fuzzing/internal/helper" 18 "github.com/tumi8/quic-go/noninternal/handshake" 19 "github.com/tumi8/quic-go/noninternal/protocol" 20 "github.com/tumi8/quic-go/noninternal/utils" 21 "github.com/tumi8/quic-go/noninternal/wire" 22 ) 23 24 var ( 25 cert, clientCert *tls.Certificate 26 certPool, clientCertPool *x509.CertPool 27 sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} 28 ) 29 30 func init() { 31 priv, err := rsa.GenerateKey(rand.Reader, 1024) 32 if err != nil { 33 log.Fatal(err) 34 } 35 cert, certPool, err = helper.GenerateCertificate(priv) 36 if err != nil { 37 log.Fatal(err) 38 } 39 40 privClient, err := rsa.GenerateKey(rand.Reader, 1024) 41 if err != nil { 42 log.Fatal(err) 43 } 44 clientCert, clientCertPool, err = helper.GenerateCertificate(privClient) 45 if err != nil { 46 log.Fatal(err) 47 } 48 } 49 50 type messageType uint8 51 52 // TLS handshake message types. 53 const ( 54 typeClientHello messageType = 1 55 typeServerHello messageType = 2 56 typeNewSessionTicket messageType = 4 57 typeEncryptedExtensions messageType = 8 58 typeCertificate messageType = 11 59 typeCertificateRequest messageType = 13 60 typeCertificateVerify messageType = 15 61 typeFinished messageType = 20 62 ) 63 64 func (m messageType) String() string { 65 switch m { 66 case typeClientHello: 67 return "ClientHello" 68 case typeServerHello: 69 return "ServerHello" 70 case typeNewSessionTicket: 71 return "NewSessionTicket" 72 case typeEncryptedExtensions: 73 return "EncryptedExtensions" 74 case typeCertificate: 75 return "Certificate" 76 case typeCertificateRequest: 77 return "CertificateRequest" 78 case typeCertificateVerify: 79 return "CertificateVerify" 80 case typeFinished: 81 return "Finished" 82 default: 83 return fmt.Sprintf("unknown message type: %d", m) 84 } 85 } 86 87 func appendSuites(suites []uint16, rand uint8) []uint16 { 88 const ( 89 s1 = tls.TLS_AES_128_GCM_SHA256 90 s2 = tls.TLS_AES_256_GCM_SHA384 91 s3 = tls.TLS_CHACHA20_POLY1305_SHA256 92 ) 93 switch rand % 4 { 94 default: 95 return suites 96 case 1: 97 return append(suites, s1) 98 case 2: 99 return append(suites, s2) 100 case 3: 101 return append(suites, s3) 102 } 103 } 104 105 // consumes 2 bits 106 func getSuites(rand uint8) []uint16 { 107 suites := make([]uint16, 0, 3) 108 for i := 1; i <= 3; i++ { 109 suites = appendSuites(suites, rand>>i%4) 110 } 111 return suites 112 } 113 114 // consumes 3 bits 115 func getClientAuth(rand uint8) tls.ClientAuthType { 116 switch rand { 117 default: 118 return tls.NoClientCert 119 case 0: 120 return tls.RequestClientCert 121 case 1: 122 return tls.RequireAnyClientCert 123 case 2: 124 return tls.VerifyClientCertIfGiven 125 case 3: 126 return tls.RequireAndVerifyClientCert 127 } 128 } 129 130 const ( 131 alpn = "fuzzing" 132 alpnWrong = "wrong" 133 ) 134 135 func toEncryptionLevel(n uint8) protocol.EncryptionLevel { 136 switch n % 3 { 137 default: 138 return protocol.EncryptionInitial 139 case 1: 140 return protocol.EncryptionHandshake 141 case 2: 142 return protocol.Encryption1RTT 143 } 144 } 145 146 func getTransportParameters(seed uint8) *wire.TransportParameters { 147 const maxVarInt = math.MaxUint64 / 4 148 r := mrand.New(mrand.NewSource(int64(seed))) 149 return &wire.TransportParameters{ 150 InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)), 151 InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)), 152 InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)), 153 InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)), 154 } 155 } 156 157 // PrefixLen is the number of bytes used for configuration 158 const ( 159 PrefixLen = 12 160 confLen = 5 161 ) 162 163 // Fuzz fuzzes the TLS 1.3 handshake used by QUIC. 164 // 165 //go:generate go run ./cmd/corpus.go 166 func Fuzz(data []byte) int { 167 if len(data) < PrefixLen { 168 return -1 169 } 170 dataLen := len(data) 171 var runConfig1, runConfig2 [confLen]byte 172 copy(runConfig1[:], data) 173 data = data[confLen:] 174 messageConfig1 := data[0] 175 data = data[1:] 176 copy(runConfig2[:], data) 177 data = data[confLen:] 178 messageConfig2 := data[0] 179 data = data[1:] 180 if dataLen != len(data)+PrefixLen { 181 panic("incorrect configuration") 182 } 183 184 clientConf := &tls.Config{ 185 MinVersion: tls.VersionTLS13, 186 ServerName: "localhost", 187 NextProtos: []string{alpn}, 188 RootCAs: certPool, 189 } 190 useSessionTicketCache := helper.NthBit(runConfig1[0], 2) 191 if useSessionTicketCache { 192 clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5) 193 } 194 195 if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 { 196 return val 197 } 198 return runHandshake(runConfig2, messageConfig2, clientConf, data) 199 } 200 201 func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int { 202 serverConf := &tls.Config{ 203 MinVersion: tls.VersionTLS13, 204 Certificates: []tls.Certificate{*cert}, 205 NextProtos: []string{alpn}, 206 SessionTicketKey: sessionTicketKey, 207 } 208 209 enable0RTTClient := helper.NthBit(runConfig[0], 0) 210 enable0RTTServer := helper.NthBit(runConfig[0], 1) 211 sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) 212 sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) 213 sendSessionTicket := helper.NthBit(runConfig[0], 5) 214 clientConf.CipherSuites = getSuites(runConfig[0] >> 6) 215 serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) 216 serverConf.CipherSuites = getSuites(runConfig[1] >> 6) 217 serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) 218 if helper.NthBit(runConfig[2], 0) { 219 clientConf.RootCAs = x509.NewCertPool() 220 } 221 if helper.NthBit(runConfig[2], 1) { 222 serverConf.ClientCAs = clientCertPool 223 } else { 224 serverConf.ClientCAs = x509.NewCertPool() 225 } 226 if helper.NthBit(runConfig[2], 2) { 227 serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 228 if helper.NthBit(runConfig[2], 3) { 229 return nil, errors.New("getting client config failed") 230 } 231 if helper.NthBit(runConfig[2], 4) { 232 return nil, nil 233 } 234 return serverConf, nil 235 } 236 } 237 if helper.NthBit(runConfig[2], 5) { 238 serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { 239 if helper.NthBit(runConfig[2], 6) { 240 return nil, errors.New("getting certificate failed") 241 } 242 if helper.NthBit(runConfig[2], 7) { 243 return nil, nil 244 } 245 return clientCert, nil // this certificate will be invalid 246 } 247 } 248 if helper.NthBit(runConfig[3], 0) { 249 serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 250 if helper.NthBit(runConfig[3], 1) { 251 return errors.New("certificate verification failed") 252 } 253 return nil 254 } 255 } 256 if helper.NthBit(runConfig[3], 2) { 257 clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 258 if helper.NthBit(runConfig[3], 3) { 259 return errors.New("certificate verification failed") 260 } 261 return nil 262 } 263 } 264 if helper.NthBit(runConfig[3], 4) { 265 serverConf.NextProtos = []string{alpnWrong} 266 } 267 if helper.NthBit(runConfig[3], 5) { 268 serverConf.NextProtos = []string{alpnWrong, alpn} 269 } 270 if helper.NthBit(runConfig[3], 6) { 271 serverConf.KeyLogWriter = io.Discard 272 } 273 if helper.NthBit(runConfig[3], 7) { 274 clientConf.KeyLogWriter = io.Discard 275 } 276 clientTP := getTransportParameters(runConfig[4] & 0x3) 277 if helper.NthBit(runConfig[4], 3) { 278 clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 279 } 280 serverTP := getTransportParameters(runConfig[4] & 0b00011000) 281 if helper.NthBit(runConfig[4], 3) { 282 serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 283 } 284 285 messageToReplace := messageConfig % 32 286 messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6) 287 288 if len(data) == 0 { 289 return -1 290 } 291 292 client := handshake.NewCryptoSetupClient( 293 protocol.ConnectionID{}, 294 clientTP, 295 clientConf, 296 enable0RTTClient, 297 utils.NewRTTStats(), 298 nil, 299 utils.DefaultLogger.WithPrefix("client"), 300 protocol.Version1, 301 ) 302 if err := client.StartHandshake(); err != nil { 303 log.Fatal(err) 304 } 305 306 server := handshake.NewCryptoSetupServer( 307 protocol.ConnectionID{}, 308 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 309 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 310 serverTP, 311 serverConf, 312 enable0RTTServer, 313 utils.NewRTTStats(), 314 nil, 315 utils.DefaultLogger.WithPrefix("server"), 316 protocol.Version1, 317 ) 318 if err := server.StartHandshake(); err != nil { 319 log.Fatal(err) 320 } 321 322 var clientHandshakeComplete, serverHandshakeComplete bool 323 for { 324 clientLoop: 325 for { 326 var processedEvent bool 327 ev := client.NextEvent() 328 //nolint:exhaustive // only need to process a few events 329 switch ev.Kind { 330 case handshake.EventNoEvent: 331 if !processedEvent && !clientHandshakeComplete { // handshake stuck 332 return 1 333 } 334 break clientLoop 335 case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: 336 msg := ev.Data 337 if msg[0] == messageToReplace { 338 fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) 339 msg = data 340 } 341 if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil { 342 return 1 343 } 344 case handshake.EventHandshakeComplete: 345 clientHandshakeComplete = true 346 } 347 processedEvent = true 348 } 349 350 serverLoop: 351 for { 352 var processedEvent bool 353 ev := server.NextEvent() 354 //nolint:exhaustive // only need to process a few events 355 switch ev.Kind { 356 case handshake.EventNoEvent: 357 if !processedEvent && !serverHandshakeComplete { // handshake stuck 358 return 1 359 } 360 break serverLoop 361 case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: 362 msg := ev.Data 363 if msg[0] == messageToReplace { 364 fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) 365 msg = data 366 } 367 if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil { 368 return 1 369 } 370 case handshake.EventHandshakeComplete: 371 serverHandshakeComplete = true 372 } 373 processedEvent = true 374 } 375 376 if serverHandshakeComplete && clientHandshakeComplete { 377 break 378 } 379 } 380 381 _ = client.ConnectionState() 382 _ = server.ConnectionState() 383 384 sealer, err := client.Get1RTTSealer() 385 if err != nil { 386 panic("expected to get a 1-RTT sealer") 387 } 388 opener, err := server.Get1RTTOpener() 389 if err != nil { 390 panic("expected to get a 1-RTT opener") 391 } 392 const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." 393 encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar")) 394 decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar")) 395 if err != nil { 396 panic(fmt.Sprintf("Decrypting message failed: %s", err.Error())) 397 } 398 if string(decrypted) != msg { 399 panic("wrong message") 400 } 401 402 if sendSessionTicket && !serverConf.SessionTicketsDisabled { 403 ticket, err := server.GetSessionTicket() 404 if err != nil { 405 panic(err) 406 } 407 if ticket == nil { 408 panic("empty ticket") 409 } 410 client.HandleMessage(ticket, protocol.Encryption1RTT) 411 } 412 if sendPostHandshakeMessageToClient { 413 client.HandleMessage(data, messageToReplaceEncLevel) 414 } 415 if sendPostHandshakeMessageToServer { 416 server.HandleMessage(data, messageToReplaceEncLevel) 417 } 418 419 return 1 420 }