github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/fuzzing/handshake/fuzz.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "math" 13 mrand "math/rand" 14 "net" 15 "time" 16 17 "github.com/danielpfeifer02/quic-go-prio-packs/fuzzing/internal/helper" 18 "github.com/danielpfeifer02/quic-go-prio-packs/internal/handshake" 19 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 20 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qtls" 21 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 22 "github.com/danielpfeifer02/quic-go-prio-packs/internal/wire" 23 ) 24 25 var ( 26 cert, clientCert *tls.Certificate 27 certPool, clientCertPool *x509.CertPool 28 sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} 29 ) 30 31 func init() { 32 priv, err := rsa.GenerateKey(rand.Reader, 1024) 33 if err != nil { 34 log.Fatal(err) 35 } 36 cert, certPool, err = helper.GenerateCertificate(priv) 37 if err != nil { 38 log.Fatal(err) 39 } 40 41 privClient, err := rsa.GenerateKey(rand.Reader, 1024) 42 if err != nil { 43 log.Fatal(err) 44 } 45 clientCert, clientCertPool, err = helper.GenerateCertificate(privClient) 46 if err != nil { 47 log.Fatal(err) 48 } 49 } 50 51 type messageType uint8 52 53 // TLS handshake message types. 54 const ( 55 typeClientHello messageType = 1 56 typeServerHello messageType = 2 57 typeNewSessionTicket messageType = 4 58 typeEncryptedExtensions messageType = 8 59 typeCertificate messageType = 11 60 typeCertificateRequest messageType = 13 61 typeCertificateVerify messageType = 15 62 typeFinished messageType = 20 63 ) 64 65 func (m messageType) String() string { 66 switch m { 67 case typeClientHello: 68 return "ClientHello" 69 case typeServerHello: 70 return "ServerHello" 71 case typeNewSessionTicket: 72 return "NewSessionTicket" 73 case typeEncryptedExtensions: 74 return "EncryptedExtensions" 75 case typeCertificate: 76 return "Certificate" 77 case typeCertificateRequest: 78 return "CertificateRequest" 79 case typeCertificateVerify: 80 return "CertificateVerify" 81 case typeFinished: 82 return "Finished" 83 default: 84 return fmt.Sprintf("unknown message type: %d", m) 85 } 86 } 87 88 // consumes 3 bits 89 func getClientAuth(rand uint8) tls.ClientAuthType { 90 switch rand { 91 default: 92 return tls.NoClientCert 93 case 0: 94 return tls.RequestClientCert 95 case 1: 96 return tls.RequireAnyClientCert 97 case 2: 98 return tls.VerifyClientCertIfGiven 99 case 3: 100 return tls.RequireAndVerifyClientCert 101 } 102 } 103 104 const ( 105 alpn = "fuzzing" 106 alpnWrong = "wrong" 107 ) 108 109 func toEncryptionLevel(n uint8) protocol.EncryptionLevel { 110 switch n % 3 { 111 default: 112 return protocol.EncryptionInitial 113 case 1: 114 return protocol.EncryptionHandshake 115 case 2: 116 return protocol.Encryption1RTT 117 } 118 } 119 120 func getTransportParameters(seed uint8) *wire.TransportParameters { 121 const maxVarInt = math.MaxUint64 / 4 122 r := mrand.New(mrand.NewSource(int64(seed))) 123 return &wire.TransportParameters{ 124 ActiveConnectionIDLimit: 2, 125 InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)), 126 InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)), 127 InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)), 128 InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)), 129 } 130 } 131 132 // PrefixLen is the number of bytes used for configuration 133 const ( 134 PrefixLen = 12 135 confLen = 5 136 ) 137 138 // Fuzz fuzzes the TLS 1.3 handshake used by QUIC. 139 // 140 //go:generate go run ./cmd/corpus.go 141 func Fuzz(data []byte) int { 142 if len(data) < PrefixLen { 143 return -1 144 } 145 dataLen := len(data) 146 var runConfig1, runConfig2 [confLen]byte 147 copy(runConfig1[:], data) 148 data = data[confLen:] 149 messageConfig1 := data[0] 150 data = data[1:] 151 copy(runConfig2[:], data) 152 data = data[confLen:] 153 messageConfig2 := data[0] 154 data = data[1:] 155 if dataLen != len(data)+PrefixLen { 156 panic("incorrect configuration") 157 } 158 159 clientConf := &tls.Config{ 160 MinVersion: tls.VersionTLS13, 161 ServerName: "localhost", 162 NextProtos: []string{alpn}, 163 RootCAs: certPool, 164 } 165 useSessionTicketCache := helper.NthBit(runConfig1[0], 2) 166 if useSessionTicketCache { 167 clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5) 168 } 169 170 if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 { 171 return val 172 } 173 return runHandshake(runConfig2, messageConfig2, clientConf, data) 174 } 175 176 func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int { 177 serverConf := &tls.Config{ 178 MinVersion: tls.VersionTLS13, 179 Certificates: []tls.Certificate{*cert}, 180 NextProtos: []string{alpn}, 181 SessionTicketKey: sessionTicketKey, 182 } 183 184 // This sets the cipher suite for both client and server. 185 // The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server. 186 resetCipherSuite := func() {} 187 switch (runConfig[0] >> 6) % 4 { 188 case 0: 189 resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256) 190 case 1: 191 resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384) 192 case 3: 193 resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) 194 default: 195 } 196 defer resetCipherSuite() 197 198 enable0RTTClient := helper.NthBit(runConfig[0], 0) 199 enable0RTTServer := helper.NthBit(runConfig[0], 1) 200 sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) 201 sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) 202 sendSessionTicket := helper.NthBit(runConfig[0], 5) 203 serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) 204 serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) 205 if helper.NthBit(runConfig[2], 0) { 206 clientConf.RootCAs = x509.NewCertPool() 207 } 208 if helper.NthBit(runConfig[2], 1) { 209 serverConf.ClientCAs = clientCertPool 210 } else { 211 serverConf.ClientCAs = x509.NewCertPool() 212 } 213 if helper.NthBit(runConfig[2], 2) { 214 serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 215 if helper.NthBit(runConfig[2], 3) { 216 return nil, errors.New("getting client config failed") 217 } 218 if helper.NthBit(runConfig[2], 4) { 219 return nil, nil 220 } 221 return serverConf, nil 222 } 223 } 224 if helper.NthBit(runConfig[2], 5) { 225 serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { 226 if helper.NthBit(runConfig[2], 6) { 227 return nil, errors.New("getting certificate failed") 228 } 229 if helper.NthBit(runConfig[2], 7) { 230 return nil, nil 231 } 232 return clientCert, nil // this certificate will be invalid 233 } 234 } 235 if helper.NthBit(runConfig[3], 0) { 236 serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 237 if helper.NthBit(runConfig[3], 1) { 238 return errors.New("certificate verification failed") 239 } 240 return nil 241 } 242 } 243 if helper.NthBit(runConfig[3], 2) { 244 clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 245 if helper.NthBit(runConfig[3], 3) { 246 return errors.New("certificate verification failed") 247 } 248 return nil 249 } 250 } 251 if helper.NthBit(runConfig[3], 4) { 252 serverConf.NextProtos = []string{alpnWrong} 253 } 254 if helper.NthBit(runConfig[3], 5) { 255 serverConf.NextProtos = []string{alpnWrong, alpn} 256 } 257 if helper.NthBit(runConfig[3], 6) { 258 serverConf.KeyLogWriter = io.Discard 259 } 260 if helper.NthBit(runConfig[3], 7) { 261 clientConf.KeyLogWriter = io.Discard 262 } 263 clientTP := getTransportParameters(runConfig[4] & 0x3) 264 if helper.NthBit(runConfig[4], 3) { 265 clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 266 } 267 serverTP := getTransportParameters(runConfig[4] & 0b00011000) 268 if helper.NthBit(runConfig[4], 3) { 269 serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 270 } 271 272 messageToReplace := messageConfig % 32 273 messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6) 274 275 if len(data) == 0 { 276 return -1 277 } 278 279 client := handshake.NewCryptoSetupClient( 280 protocol.ConnectionID{}, 281 clientTP, 282 clientConf, 283 enable0RTTClient, 284 utils.NewRTTStats(), 285 nil, 286 utils.DefaultLogger.WithPrefix("client"), 287 protocol.Version1, 288 ) 289 if err := client.StartHandshake(); err != nil { 290 log.Fatal(err) 291 } 292 defer client.Close() 293 294 server := handshake.NewCryptoSetupServer( 295 protocol.ConnectionID{}, 296 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 297 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 298 serverTP, 299 serverConf, 300 enable0RTTServer, 301 utils.NewRTTStats(), 302 nil, 303 utils.DefaultLogger.WithPrefix("server"), 304 protocol.Version1, 305 ) 306 if err := server.StartHandshake(); err != nil { 307 log.Fatal(err) 308 } 309 defer server.Close() 310 311 var clientHandshakeComplete, serverHandshakeComplete bool 312 for { 313 var processedEvent bool 314 clientLoop: 315 for { 316 ev := client.NextEvent() 317 //nolint:exhaustive // only need to process a few events 318 switch ev.Kind { 319 case handshake.EventNoEvent: 320 if !processedEvent && !clientHandshakeComplete { // handshake stuck 321 return 1 322 } 323 break clientLoop 324 case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: 325 msg := ev.Data 326 encLevel := protocol.EncryptionInitial 327 if ev.Kind == handshake.EventWriteHandshakeData { 328 encLevel = protocol.EncryptionHandshake 329 } 330 if msg[0] == messageToReplace { 331 fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) 332 msg = data 333 encLevel = messageToReplaceEncLevel 334 } 335 if err := server.HandleMessage(msg, encLevel); err != nil { 336 return 1 337 } 338 case handshake.EventHandshakeComplete: 339 clientHandshakeComplete = true 340 } 341 processedEvent = true 342 } 343 344 processedEvent = false 345 serverLoop: 346 for { 347 ev := server.NextEvent() 348 //nolint:exhaustive // only need to process a few events 349 switch ev.Kind { 350 case handshake.EventNoEvent: 351 if !processedEvent && !serverHandshakeComplete { // handshake stuck 352 return 1 353 } 354 break serverLoop 355 case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: 356 encLevel := protocol.EncryptionInitial 357 if ev.Kind == handshake.EventWriteHandshakeData { 358 encLevel = protocol.EncryptionHandshake 359 } 360 msg := ev.Data 361 if msg[0] == messageToReplace { 362 fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) 363 msg = data 364 encLevel = messageToReplaceEncLevel 365 } 366 if err := client.HandleMessage(msg, encLevel); err != nil { 367 return 1 368 } 369 case handshake.EventHandshakeComplete: 370 serverHandshakeComplete = true 371 } 372 processedEvent = true 373 } 374 375 if serverHandshakeComplete && clientHandshakeComplete { 376 break 377 } 378 } 379 380 _ = client.ConnectionState() 381 _ = server.ConnectionState() 382 383 sealer, err := client.Get1RTTSealer() 384 if err != nil { 385 panic("expected to get a 1-RTT sealer") 386 } 387 opener, err := server.Get1RTTOpener() 388 if err != nil { 389 panic("expected to get a 1-RTT opener") 390 } 391 const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." 392 encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar")) 393 decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar")) 394 if err != nil { 395 panic(fmt.Sprintf("Decrypting message failed: %s", err.Error())) 396 } 397 if string(decrypted) != msg { 398 panic("wrong message") 399 } 400 401 if sendSessionTicket && !serverConf.SessionTicketsDisabled { 402 ticket, err := server.GetSessionTicket() 403 if err != nil { 404 panic(err) 405 } 406 if ticket == nil { 407 panic("empty ticket") 408 } 409 client.HandleMessage(ticket, protocol.Encryption1RTT) 410 } 411 412 if sendPostHandshakeMessageToClient { 413 fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel) 414 client.HandleMessage(data, messageToReplaceEncLevel) 415 } 416 if sendPostHandshakeMessageToServer { 417 fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel) 418 server.HandleMessage(data, messageToReplaceEncLevel) 419 } 420 421 return 1 422 }