github.com/pion/dtls/v2@v2.2.12/handshaker_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "context" 9 "crypto/tls" 10 "errors" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/pion/dtls/v2/pkg/crypto/selfsign" 16 "github.com/pion/dtls/v2/pkg/crypto/signaturehash" 17 "github.com/pion/dtls/v2/pkg/protocol/alert" 18 "github.com/pion/dtls/v2/pkg/protocol/handshake" 19 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 20 "github.com/pion/logging" 21 "github.com/pion/transport/v2/test" 22 ) 23 24 const nonZeroRetransmitInterval = 100 * time.Millisecond 25 26 // Test that writes to the key log are in the correct format and only applies 27 // when a key log writer is given. 28 func TestWriteKeyLog(t *testing.T) { 29 var buf bytes.Buffer 30 cfg := handshakeConfig{ 31 keyLogWriter: &buf, 32 } 33 cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) 34 35 // Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret> 36 // https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format 37 want := "LABEL aabbcc ddeeff\n" 38 if buf.String() != want { 39 t.Fatalf("Got %s want %s", buf.String(), want) 40 } 41 42 // no key log writer = no writes 43 cfg = handshakeConfig{} 44 cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) 45 } 46 47 func TestHandshaker(t *testing.T) { 48 // Check for leaking routines 49 report := test.CheckRoutines(t) 50 defer report() 51 52 loggerFactory := logging.NewDefaultLoggerFactory() 53 logger := loggerFactory.NewLogger("dtls") 54 55 cipherSuites, err := parseCipherSuites(nil, nil, true, false) 56 if err != nil { 57 t.Fatal(err) 58 } 59 clientCert, err := selfsign.GenerateSelfSigned() 60 if err != nil { 61 t.Fatal(err) 62 } 63 64 genFilters := map[string]func() (TestEndpoint, TestEndpoint, func(t *testing.T)){ 65 "PassThrough": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { 66 return TestEndpoint{}, TestEndpoint{}, nil 67 }, 68 69 "HelloVerifyRequestLost": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { 70 var ( 71 cntHelloVerifyRequest = 0 72 cntClientHelloNoCookie = 0 73 ) 74 const helloVerifyDrop = 5 75 76 clientEndpoint := TestEndpoint{ 77 Filter: func(p *packet) bool { 78 h, ok := p.record.Content.(*handshake.Handshake) 79 if !ok { 80 return true 81 } 82 if hmch, ok := h.Message.(*handshake.MessageClientHello); ok { 83 if len(hmch.Cookie) == 0 { 84 cntClientHelloNoCookie++ 85 } 86 } 87 return true 88 }, 89 } 90 91 serverEndpoint := TestEndpoint{ 92 Filter: func(p *packet) bool { 93 h, ok := p.record.Content.(*handshake.Handshake) 94 if !ok { 95 return true 96 } 97 if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok { 98 cntHelloVerifyRequest++ 99 return cntHelloVerifyRequest > helloVerifyDrop 100 } 101 return true 102 }, 103 } 104 105 report := func(t *testing.T) { 106 if cntHelloVerifyRequest != helloVerifyDrop+1 { 107 t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) 108 } 109 if cntClientHelloNoCookie != cntHelloVerifyRequest { 110 t.Errorf( 111 "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", 112 cntHelloVerifyRequest, cntClientHelloNoCookie, 113 ) 114 } 115 } 116 117 return clientEndpoint, serverEndpoint, report 118 }, 119 120 "NoLatencyTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { 121 var ( 122 cntClientFinished = 0 123 cntServerFinished = 0 124 ) 125 126 clientEndpoint := TestEndpoint{ 127 Filter: func(p *packet) bool { 128 h, ok := p.record.Content.(*handshake.Handshake) 129 if !ok { 130 return true 131 } 132 if _, ok := h.Message.(*handshake.MessageFinished); ok { 133 cntClientFinished++ 134 } 135 return true 136 }, 137 } 138 139 serverEndpoint := TestEndpoint{ 140 Filter: func(p *packet) bool { 141 h, ok := p.record.Content.(*handshake.Handshake) 142 if !ok { 143 return true 144 } 145 if _, ok := h.Message.(*handshake.MessageFinished); ok { 146 cntServerFinished++ 147 } 148 return true 149 }, 150 } 151 152 report := func(t *testing.T) { 153 if cntClientFinished != 1 { 154 t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished) 155 } 156 if cntServerFinished != 1 { 157 t.Errorf("Number of server finished is wrong, expected: %d times, got: %d times", 1, cntServerFinished) 158 } 159 } 160 161 return clientEndpoint, serverEndpoint, report 162 }, 163 164 "SlowServerTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { 165 var ( 166 cntClientFinished = 0 167 isClientFinished = false 168 cntClientFinishedLastRetransmit = 0 169 cntServerFinished = 0 170 isServerFinished = false 171 cntServerFinishedLastRetransmit = 0 172 ) 173 174 clientEndpoint := TestEndpoint{ 175 Filter: func(p *packet) bool { 176 h, ok := p.record.Content.(*handshake.Handshake) 177 if !ok { 178 return true 179 } 180 if _, ok := h.Message.(*handshake.MessageFinished); ok { 181 if isClientFinished { 182 cntClientFinishedLastRetransmit++ 183 } else { 184 cntClientFinished++ 185 } 186 } 187 return true 188 }, 189 Delay: 0, 190 OnFinished: func() { 191 isClientFinished = true 192 }, 193 FinishWait: 2000 * time.Millisecond, 194 } 195 196 serverEndpoint := TestEndpoint{ 197 Filter: func(p *packet) bool { 198 h, ok := p.record.Content.(*handshake.Handshake) 199 if !ok { 200 return true 201 } 202 if _, ok := h.Message.(*handshake.MessageFinished); ok { 203 if isServerFinished { 204 cntServerFinishedLastRetransmit++ 205 } else { 206 cntServerFinished++ 207 } 208 } 209 return true 210 }, 211 Delay: 1000 * time.Millisecond, 212 OnFinished: func() { 213 isServerFinished = true 214 }, 215 FinishWait: 2000 * time.Millisecond, 216 } 217 218 report := func(t *testing.T) { 219 // with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client 220 // using a range of 9 - 11 for checking 221 if cntClientFinished < 8 || cntClientFinished > 11 { 222 t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished) 223 } 224 if !isClientFinished { 225 t.Errorf("Client is not finished") 226 } 227 // there should be no `Finished` last retransmit from client 228 if cntClientFinishedLastRetransmit != 0 { 229 t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit) 230 } 231 if cntServerFinished < 1 { 232 t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished) 233 } 234 if !isServerFinished { 235 t.Errorf("Server is not finished") 236 } 237 // there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`. 238 if cntServerFinishedLastRetransmit < 1 { 239 t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit) 240 } 241 } 242 243 return clientEndpoint, serverEndpoint, report 244 }, 245 } 246 247 for name, filters := range genFilters { 248 clientEndpoint, serverEndpoint, report := filters() 249 t.Run(name, func(t *testing.T) { 250 ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) 251 defer cancel() 252 253 if report != nil { 254 defer report(t) 255 } 256 257 ca, cb := flightTestPipe(ctx, clientEndpoint, serverEndpoint) 258 ca.state.isClient = true 259 260 var wg sync.WaitGroup 261 wg.Add(2) 262 263 ctxCliFinished, cancelCli := context.WithCancel(ctx) 264 ctxSrvFinished, cancelSrv := context.WithCancel(ctx) 265 go func() { 266 defer wg.Done() 267 cfg := &handshakeConfig{ 268 localCipherSuites: cipherSuites, 269 localCertificates: []tls.Certificate{clientCert}, 270 ellipticCurves: defaultCurves, 271 localSignatureSchemes: signaturehash.Algorithms(), 272 insecureSkipVerify: true, 273 log: logger, 274 onFlightState: func(f flightVal, s handshakeState) { 275 if s == handshakeFinished { 276 if clientEndpoint.OnFinished != nil { 277 clientEndpoint.OnFinished() 278 } 279 time.AfterFunc(clientEndpoint.FinishWait, func() { 280 cancelCli() 281 }) 282 } 283 }, 284 retransmitInterval: nonZeroRetransmitInterval, 285 } 286 287 fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) 288 err := fsm.Run(ctx, ca, handshakePreparing) 289 switch { 290 case errors.Is(err, context.Canceled): 291 case errors.Is(err, context.DeadlineExceeded): 292 t.Error("Timeout") 293 default: 294 t.Error(err) 295 } 296 }() 297 298 go func() { 299 defer wg.Done() 300 cfg := &handshakeConfig{ 301 localCipherSuites: cipherSuites, 302 localCertificates: []tls.Certificate{clientCert}, 303 ellipticCurves: defaultCurves, 304 localSignatureSchemes: signaturehash.Algorithms(), 305 insecureSkipVerify: true, 306 log: logger, 307 onFlightState: func(f flightVal, s handshakeState) { 308 if s == handshakeFinished { 309 if serverEndpoint.OnFinished != nil { 310 serverEndpoint.OnFinished() 311 } 312 time.AfterFunc(serverEndpoint.FinishWait, func() { 313 cancelSrv() 314 }) 315 } 316 }, 317 retransmitInterval: nonZeroRetransmitInterval, 318 } 319 320 fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) 321 err := fsm.Run(ctx, cb, handshakePreparing) 322 switch { 323 case errors.Is(err, context.Canceled): 324 case errors.Is(err, context.DeadlineExceeded): 325 t.Error("Timeout") 326 default: 327 t.Error(err) 328 } 329 }() 330 331 <-ctxCliFinished.Done() 332 <-ctxSrvFinished.Done() 333 334 cancel() 335 wg.Wait() 336 }) 337 } 338 } 339 340 type packetFilter func(p *packet) bool 341 342 type TestEndpoint struct { 343 Filter packetFilter 344 Delay time.Duration 345 OnFinished func() 346 FinishWait time.Duration 347 } 348 349 func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) { 350 ca := newHandshakeCache() 351 cb := newHandshakeCache() 352 chA := make(chan chan struct{}) 353 chB := make(chan chan struct{}) 354 return &flightTestConn{ 355 handshakeCache: ca, 356 otherEndCache: cb, 357 recv: chA, 358 otherEndRecv: chB, 359 done: ctx.Done(), 360 filter: clientEndpoint.Filter, 361 delay: clientEndpoint.Delay, 362 }, &flightTestConn{ 363 handshakeCache: cb, 364 otherEndCache: ca, 365 recv: chB, 366 otherEndRecv: chA, 367 done: ctx.Done(), 368 filter: serverEndpoint.Filter, 369 delay: serverEndpoint.Delay, 370 } 371 } 372 373 type flightTestConn struct { 374 state State 375 handshakeCache *handshakeCache 376 recv chan chan struct{} 377 done <-chan struct{} 378 epoch uint16 379 380 filter packetFilter 381 382 delay time.Duration 383 384 otherEndCache *handshakeCache 385 otherEndRecv chan chan struct{} 386 } 387 388 func (c *flightTestConn) recvHandshake() <-chan chan struct{} { 389 return c.recv 390 } 391 392 func (c *flightTestConn) setLocalEpoch(epoch uint16) { 393 c.epoch = epoch 394 } 395 396 func (c *flightTestConn) notify(context.Context, alert.Level, alert.Description) error { 397 return nil 398 } 399 400 func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error { 401 time.Sleep(c.delay) 402 for _, p := range pkts { 403 if c.filter != nil && !c.filter(p) { 404 continue 405 } 406 if h, ok := p.record.Content.(*handshake.Handshake); ok { 407 handshakeRaw, err := p.record.Marshal() 408 if err != nil { 409 return err 410 } 411 412 c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) 413 414 content, err := h.Message.Marshal() 415 if err != nil { 416 return err 417 } 418 h.Header.Length = uint32(len(content)) 419 h.Header.FragmentLength = uint32(len(content)) 420 hdr, err := h.Header.Marshal() 421 if err != nil { 422 return err 423 } 424 c.otherEndCache.push( 425 append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) 426 } 427 } 428 go func() { 429 select { 430 case c.otherEndRecv <- make(chan struct{}): 431 case <-c.done: 432 } 433 }() 434 435 // Avoid deadlock on JS/WASM environment due to context switch problem. 436 time.Sleep(10 * time.Millisecond) 437 438 return nil 439 } 440 441 func (c *flightTestConn) handleQueuedPackets(context.Context) error { 442 return nil 443 } 444 445 func (c *flightTestConn) sessionKey() []byte { 446 return nil 447 }