github.com/pion/dtls/v2@v2.2.12/conn_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "context" 9 "crypto" 10 "crypto/ecdsa" 11 cryptoElliptic "crypto/elliptic" 12 "crypto/rand" 13 "crypto/rsa" 14 "crypto/tls" 15 "crypto/x509" 16 "encoding/hex" 17 "errors" 18 "fmt" 19 "io" 20 "net" 21 "strings" 22 "sync" 23 "sync/atomic" 24 "testing" 25 "time" 26 27 "github.com/pion/dtls/v2/internal/ciphersuite" 28 "github.com/pion/dtls/v2/pkg/crypto/elliptic" 29 "github.com/pion/dtls/v2/pkg/crypto/hash" 30 "github.com/pion/dtls/v2/pkg/crypto/selfsign" 31 "github.com/pion/dtls/v2/pkg/crypto/signature" 32 "github.com/pion/dtls/v2/pkg/crypto/signaturehash" 33 "github.com/pion/dtls/v2/pkg/protocol" 34 "github.com/pion/dtls/v2/pkg/protocol/alert" 35 "github.com/pion/dtls/v2/pkg/protocol/extension" 36 "github.com/pion/dtls/v2/pkg/protocol/handshake" 37 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 38 "github.com/pion/logging" 39 "github.com/pion/transport/v2/dpipe" 40 "github.com/pion/transport/v2/test" 41 ) 42 43 var ( 44 errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity") 45 errPSKRejected = errors.New("PSK Rejected") 46 errNotExpectedChain = errors.New("not expected chain") 47 errExpecedChain = errors.New("expected chain") 48 errWrongCert = errors.New("wrong cert") 49 ) 50 51 func TestStressDuplex(t *testing.T) { 52 // Limit runtime in case of deadlocks 53 lim := test.TimeOut(time.Second * 20) 54 defer lim.Stop() 55 56 // Check for leaking routines 57 report := test.CheckRoutines(t) 58 defer report() 59 60 // Run the test 61 stressDuplex(t) 62 } 63 64 func stressDuplex(t *testing.T) { 65 ca, cb, err := pipeMemory() 66 if err != nil { 67 t.Fatal(err) 68 } 69 70 defer func() { 71 err = ca.Close() 72 if err != nil { 73 t.Fatal(err) 74 } 75 err = cb.Close() 76 if err != nil { 77 t.Fatal(err) 78 } 79 }() 80 81 opt := test.Options{ 82 MsgSize: 
2048, 83 MsgCount: 100, 84 } 85 86 err = test.StressDuplex(ca, cb, opt) 87 if err != nil { 88 t.Fatal(err) 89 } 90 } 91 92 func TestRoutineLeakOnClose(t *testing.T) { 93 // Limit runtime in case of deadlocks 94 lim := test.TimeOut(5 * time.Second) 95 defer lim.Stop() 96 97 // Check for leaking routines 98 report := test.CheckRoutines(t) 99 defer report() 100 101 ca, cb, err := pipeMemory() 102 if err != nil { 103 t.Fatal(err) 104 } 105 106 if _, err := ca.Write(make([]byte, 100)); err != nil { 107 t.Fatal(err) 108 } 109 if err := cb.Close(); err != nil { 110 t.Fatal(err) 111 } 112 if err := ca.Close(); err != nil { 113 t.Fatal(err) 114 } 115 // Packet is sent, but not read. 116 // inboundLoop routine should not be leaked. 117 } 118 119 func TestReadWriteDeadline(t *testing.T) { 120 // Limit runtime in case of deadlocks 121 lim := test.TimeOut(5 * time.Second) 122 defer lim.Stop() 123 124 // Check for leaking routines 125 report := test.CheckRoutines(t) 126 defer report() 127 128 var e net.Error 129 130 ca, cb, err := pipeMemory() 131 if err != nil { 132 t.Fatal(err) 133 } 134 135 if err := ca.SetDeadline(time.Unix(0, 1)); err != nil { 136 t.Fatal(err) 137 } 138 _, werr := ca.Write(make([]byte, 100)) 139 if errors.As(werr, &e) { 140 if !e.Timeout() { 141 t.Error("Deadline exceeded Write must return Timeout error") 142 } 143 if !e.Temporary() { //nolint:staticcheck 144 t.Error("Deadline exceeded Write must return Temporary error") 145 } 146 } else { 147 t.Error("Write must return net.Error error") 148 } 149 _, rerr := ca.Read(make([]byte, 100)) 150 if errors.As(rerr, &e) { 151 if !e.Timeout() { 152 t.Error("Deadline exceeded Read must return Timeout error") 153 } 154 if !e.Temporary() { //nolint:staticcheck 155 t.Error("Deadline exceeded Read must return Temporary error") 156 } 157 } else { 158 t.Error("Read must return net.Error error") 159 } 160 if err := ca.SetDeadline(time.Time{}); err != nil { 161 t.Error(err) 162 } 163 164 if err := ca.Close(); err != nil { 165 
t.Error(err) 166 } 167 if err := cb.Close(); err != nil { 168 t.Error(err) 169 } 170 171 if _, err := ca.Write(make([]byte, 100)); !errors.Is(err, ErrConnClosed) { 172 t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err) 173 } 174 if _, err := ca.Read(make([]byte, 100)); !errors.Is(err, io.EOF) { 175 t.Errorf("Read must return %v after close, got %v", io.EOF, err) 176 } 177 } 178 179 func TestSequenceNumberOverflow(t *testing.T) { 180 // Limit runtime in case of deadlocks 181 lim := test.TimeOut(5 * time.Second) 182 defer lim.Stop() 183 184 // Check for leaking routines 185 report := test.CheckRoutines(t) 186 defer report() 187 188 t.Run("ApplicationData", func(t *testing.T) { 189 ca, cb, err := pipeMemory() 190 if err != nil { 191 t.Fatal(err) 192 } 193 194 atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber) 195 if _, werr := ca.Write(make([]byte, 100)); werr != nil { 196 t.Errorf("Write must send message with maximum sequence number, but errord: %v", werr) 197 } 198 if _, werr := ca.Write(make([]byte, 100)); !errors.Is(werr, errSequenceNumberOverflow) { 199 t.Errorf("Write must abandonsend message with maximum sequence number, but errord: %v", werr) 200 } 201 202 if err := ca.Close(); err != nil { 203 t.Error(err) 204 } 205 if err := cb.Close(); err != nil { 206 t.Error(err) 207 } 208 }) 209 t.Run("Handshake", func(t *testing.T) { 210 ca, cb, err := pipeMemory() 211 if err != nil { 212 t.Fatal(err) 213 } 214 215 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 216 defer cancel() 217 218 atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1) 219 220 // Try to send handshake packet. 
221 if werr := ca.writePackets(ctx, []*packet{ 222 { 223 record: &recordlayer.RecordLayer{ 224 Header: recordlayer.Header{ 225 Version: protocol.Version1_2, 226 }, 227 Content: &handshake.Handshake{ 228 Message: &handshake.MessageClientHello{ 229 Version: protocol.Version1_2, 230 Cookie: make([]byte, 64), 231 CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), 232 CompressionMethods: defaultCompressionMethods(), 233 }, 234 }, 235 }, 236 }, 237 }); !errors.Is(werr, errSequenceNumberOverflow) { 238 t.Errorf("Connection must fail on handshake packet reaches maximum sequence number") 239 } 240 241 if err := ca.Close(); err != nil { 242 t.Error(err) 243 } 244 if err := cb.Close(); err != nil { 245 t.Error(err) 246 } 247 }) 248 } 249 250 func pipeMemory() (*Conn, *Conn, error) { 251 // In memory pipe 252 ca, cb := dpipe.Pipe() 253 return pipeConn(ca, cb) 254 } 255 256 func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { 257 type result struct { 258 c *Conn 259 err error 260 } 261 262 c := make(chan result) 263 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 264 defer cancel() 265 266 // Setup client 267 go func() { 268 client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) 269 c <- result{client, err} 270 }() 271 272 // Setup server 273 server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) 274 if err != nil { 275 return nil, nil, err 276 } 277 278 // Receive client 279 res := <-c 280 if res.err != nil { 281 _ = server.Close() 282 return nil, nil, res.err 283 } 284 285 return res.c, server, nil 286 } 287 288 func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { 289 if generateCertificate { 290 clientCert, err := selfsign.GenerateSelfSigned() 291 if err != nil { 292 return nil, err 293 } 294 cfg.Certificates = []tls.Certificate{clientCert} 295 } 296 
cfg.InsecureSkipVerify = true 297 return ClientWithContext(ctx, c, cfg) 298 } 299 300 func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { 301 if generateCertificate { 302 serverCert, err := selfsign.GenerateSelfSigned() 303 if err != nil { 304 return nil, err 305 } 306 cfg.Certificates = []tls.Certificate{serverCert} 307 } 308 return ServerWithContext(ctx, c, cfg) 309 } 310 311 func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { 312 packet, err := (&recordlayer.RecordLayer{ 313 Header: recordlayer.Header{ 314 Version: protocol.Version1_2, 315 SequenceNumber: sequenceNumber, 316 }, 317 Content: &handshake.Handshake{ 318 Header: handshake.Header{ 319 MessageSequence: uint16(sequenceNumber), 320 }, 321 Message: &handshake.MessageClientHello{ 322 Version: protocol.Version1_2, 323 Cookie: cookie, 324 CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), 325 CompressionMethods: defaultCompressionMethods(), 326 Extensions: extensions, 327 }, 328 }, 329 }).Marshal() 330 if err != nil { 331 return err 332 } 333 334 if _, err = ca.Write(packet); err != nil { 335 return err 336 } 337 return nil 338 } 339 340 func TestHandshakeWithAlert(t *testing.T) { 341 // Limit runtime in case of deadlocks 342 lim := test.TimeOut(time.Second * 20) 343 defer lim.Stop() 344 345 // Check for leaking routines 346 report := test.CheckRoutines(t) 347 defer report() 348 349 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 350 defer cancel() 351 352 cases := map[string]struct { 353 configServer, configClient *Config 354 errServer, errClient error 355 }{ 356 "CipherSuiteNoIntersection": { 357 configServer: &Config{ 358 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 359 }, 360 configClient: &Config{ 361 CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, 362 }, 363 errServer: errCipherSuiteNoIntersection, 364 errClient: 
&alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, 365 }, 366 "SignatureSchemesNoIntersection": { 367 configServer: &Config{ 368 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 369 SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256}, 370 }, 371 configClient: &Config{ 372 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 373 SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, 374 }, 375 errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, 376 errClient: errNoAvailableSignatureSchemes, 377 }, 378 } 379 380 for name, testCase := range cases { 381 testCase := testCase 382 t.Run(name, func(t *testing.T) { 383 clientErr := make(chan error, 1) 384 385 ca, cb := dpipe.Pipe() 386 go func() { 387 _, err := testClient(ctx, ca, testCase.configClient, true) 388 clientErr <- err 389 }() 390 391 _, errServer := testServer(ctx, cb, testCase.configServer, true) 392 if !errors.Is(errServer, testCase.errServer) { 393 t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) 394 } 395 396 errClient := <-clientErr 397 if !errors.Is(errClient, testCase.errClient) { 398 t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient) 399 } 400 }) 401 } 402 } 403 404 func TestHandshakeWithInvalidRecord(t *testing.T) { 405 // Limit runtime in case of deadlocks 406 lim := test.TimeOut(time.Second * 20) 407 defer lim.Stop() 408 409 // Check for leaking routines 410 report := test.CheckRoutines(t) 411 defer report() 412 413 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 414 defer cancel() 415 416 type result struct { 417 c *Conn 418 err error 419 } 420 clientErr := make(chan result, 1) 421 ca, cb := dpipe.Pipe() 422 caWithInvalidRecord := &connWithCallback{Conn: ca} 423 424 var msgSeq atomic.Int32 425 // Send invalid record after first message 426 caWithInvalidRecord.onWrite = 
func(b []byte) { 427 if msgSeq.Add(1) == 2 { 428 if _, err := ca.Write([]byte{0x01, 0x02}); err != nil { 429 t.Fatal(err) 430 } 431 } 432 } 433 go func() { 434 client, err := testClient(ctx, caWithInvalidRecord, &Config{ 435 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 436 }, true) 437 clientErr <- result{client, err} 438 }() 439 440 server, errServer := testServer(ctx, cb, &Config{ 441 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 442 }, true) 443 444 errClient := <-clientErr 445 446 defer func() { 447 if server != nil { 448 if err := server.Close(); err != nil { 449 t.Fatal(err) 450 } 451 } 452 453 if errClient.c != nil { 454 if err := errClient.c.Close(); err != nil { 455 t.Fatal(err) 456 } 457 } 458 }() 459 460 if errServer != nil { 461 t.Fatalf("Server failed(%v)", errServer) 462 } 463 464 if errClient.err != nil { 465 t.Fatalf("Client failed(%v)", errClient.err) 466 } 467 } 468 469 func TestExportKeyingMaterial(t *testing.T) { 470 // Check for leaking routines 471 report := test.CheckRoutines(t) 472 defer report() 473 474 var rand [28]byte 475 exportLabel := "EXTRACTOR-dtls_srtp" 476 477 expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b} 478 expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77} 479 480 c := &Conn{ 481 state: State{ 482 localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}, 483 remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand}, 484 localSequenceNumber: []uint64{0, 0}, 485 cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, 486 }, 487 } 488 c.setLocalEpoch(0) 489 c.setRemoteEpoch(0) 490 491 state := c.ConnectionState() 492 _, err := state.ExportKeyingMaterial(exportLabel, nil, 0) 493 if !errors.Is(err, errHandshakeInProgress) { 494 t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err) 495 } 496 497 
c.setLocalEpoch(1) 498 state = c.ConnectionState() 499 _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0) 500 if !errors.Is(err, errContextUnsupported) { 501 t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err) 502 } 503 504 for k := range invalidKeyingLabels() { 505 state = c.ConnectionState() 506 _, err = state.ExportKeyingMaterial(k, nil, 0) 507 if !errors.Is(err, errReservedExportKeyingMaterial) { 508 t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err) 509 } 510 } 511 512 state = c.ConnectionState() 513 keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10) 514 if err != nil { 515 t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) 516 } else if !bytes.Equal(keyingMaterial, expectedServerKey) { 517 t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial) 518 } 519 520 c.state.isClient = true 521 state = c.ConnectionState() 522 keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10) 523 if err != nil { 524 t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) 525 } else if !bytes.Equal(keyingMaterial, expectedClientKey) { 526 t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial) 527 } 528 } 529 530 func TestPSK(t *testing.T) { 531 // Limit runtime in case of deadlocks 532 lim := test.TimeOut(time.Second * 20) 533 defer lim.Stop() 534 535 // Check for leaking routines 536 report := test.CheckRoutines(t) 537 defer report() 538 539 for _, test := range []struct { 540 Name string 541 ServerIdentity []byte 542 CipherSuites []CipherSuiteID 543 ClientVerifyConnection func(*State) error 544 ServerVerifyConnection func(*State) error 545 WantFail bool 546 ExpectedServerErr string 547 ExpectedClientErr string 548 }{ 549 { 550 Name: "Server identity specified", 551 
ServerIdentity: []byte("Test Identity"), 552 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, 553 }, 554 { 555 Name: "Server identity specified - Server verify connection fails", 556 ServerIdentity: []byte("Test Identity"), 557 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, 558 ServerVerifyConnection: func(s *State) error { 559 return errExample 560 }, 561 WantFail: true, 562 ExpectedServerErr: errExample.Error(), 563 ExpectedClientErr: alert.BadCertificate.String(), 564 }, 565 { 566 Name: "Server identity specified - Client verify connection fails", 567 ServerIdentity: []byte("Test Identity"), 568 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, 569 ClientVerifyConnection: func(s *State) error { 570 return errExample 571 }, 572 WantFail: true, 573 ExpectedServerErr: alert.BadCertificate.String(), 574 ExpectedClientErr: errExample.Error(), 575 }, 576 { 577 Name: "Server identity nil", 578 ServerIdentity: nil, 579 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, 580 }, 581 { 582 Name: "TLS_PSK_WITH_AES_128_CBC_SHA256", 583 ServerIdentity: nil, 584 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256}, 585 }, 586 { 587 Name: "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256", 588 ServerIdentity: nil, 589 CipherSuites: []CipherSuiteID{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256}, 590 }, 591 } { 592 test := test 593 t.Run(test.Name, func(t *testing.T) { 594 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 595 defer cancel() 596 597 clientIdentity := []byte("Client Identity") 598 type result struct { 599 c *Conn 600 err error 601 } 602 clientRes := make(chan result, 1) 603 604 ca, cb := dpipe.Pipe() 605 go func() { 606 conf := &Config{ 607 PSK: func(hint []byte) ([]byte, error) { 608 if !bytes.Equal(test.ServerIdentity, hint) { 609 return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113 610 } 611 612 return []byte{0xAB, 0xC1, 0x23}, 
nil 613 }, 614 PSKIdentityHint: clientIdentity, 615 CipherSuites: test.CipherSuites, 616 VerifyConnection: test.ClientVerifyConnection, 617 } 618 619 c, err := testClient(ctx, ca, conf, false) 620 clientRes <- result{c, err} 621 }() 622 623 config := &Config{ 624 PSK: func(hint []byte) ([]byte, error) { 625 if !bytes.Equal(clientIdentity, hint) { 626 return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint) 627 } 628 return []byte{0xAB, 0xC1, 0x23}, nil 629 }, 630 PSKIdentityHint: test.ServerIdentity, 631 CipherSuites: test.CipherSuites, 632 VerifyConnection: test.ServerVerifyConnection, 633 } 634 635 server, err := testServer(ctx, cb, config, false) 636 if test.WantFail { 637 res := <-clientRes 638 if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) { 639 t.Fatalf("TestPSK: Server expected(%v) actual(%v)", test.ExpectedServerErr, err) 640 } 641 if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) { 642 t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err) 643 } 644 return 645 } 646 if err != nil { 647 t.Fatalf("TestPSK: Server failed(%v)", err) 648 } 649 650 actualPSKIdentityHint := server.ConnectionState().IdentityHint 651 if !bytes.Equal(actualPSKIdentityHint, clientIdentity) { 652 t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, clientIdentity, actualPSKIdentityHint) 653 } 654 655 defer func() { 656 _ = server.Close() 657 }() 658 659 res := <-clientRes 660 if res.err != nil { 661 t.Fatal(res.err) 662 } 663 _ = res.c.Close() 664 }) 665 } 666 } 667 668 func TestPSKHintFail(t *testing.T) { 669 // Check for leaking routines 670 report := test.CheckRoutines(t) 671 defer report() 672 673 serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}} 674 pskRejected := errPSKRejected 675 676 // Limit runtime in case of deadlocks 677 lim := 
test.TimeOut(time.Second * 20) 678 defer lim.Stop() 679 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 680 defer cancel() 681 682 clientErr := make(chan error, 1) 683 684 ca, cb := dpipe.Pipe() 685 go func() { 686 conf := &Config{ 687 PSK: func(hint []byte) ([]byte, error) { 688 return nil, pskRejected 689 }, 690 PSKIdentityHint: []byte{}, 691 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, 692 } 693 694 _, err := testClient(ctx, ca, conf, false) 695 clientErr <- err 696 }() 697 698 config := &Config{ 699 PSK: func(hint []byte) ([]byte, error) { 700 return nil, pskRejected 701 }, 702 PSKIdentityHint: []byte{}, 703 CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, 704 } 705 706 if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) { 707 t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err) 708 } 709 710 if err := <-clientErr; !errors.Is(err, pskRejected) { 711 t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err) 712 } 713 } 714 715 func TestClientTimeout(t *testing.T) { 716 // Limit runtime in case of deadlocks 717 lim := test.TimeOut(time.Second * 20) 718 defer lim.Stop() 719 720 // Check for leaking routines 721 report := test.CheckRoutines(t) 722 defer report() 723 724 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 725 defer cancel() 726 727 clientErr := make(chan error, 1) 728 729 ca, _ := dpipe.Pipe() 730 go func() { 731 conf := &Config{} 732 733 c, err := testClient(ctx, ca, conf, true) 734 if err == nil { 735 _ = c.Close() //nolint:contextcheck 736 } 737 clientErr <- err 738 }() 739 740 // no server! 
741 err := <-clientErr 742 var netErr net.Error 743 if !errors.As(err, &netErr) || !netErr.Timeout() { 744 t.Fatalf("Client error exp(Temporary network error) failed(%v)", err) 745 } 746 } 747 748 func TestSRTPConfiguration(t *testing.T) { 749 // Check for leaking routines 750 report := test.CheckRoutines(t) 751 defer report() 752 753 for _, test := range []struct { 754 Name string 755 ClientSRTP []SRTPProtectionProfile 756 ServerSRTP []SRTPProtectionProfile 757 ExpectedProfile SRTPProtectionProfile 758 WantClientError error 759 WantServerError error 760 }{ 761 { 762 Name: "No SRTP in use", 763 ClientSRTP: nil, 764 ServerSRTP: nil, 765 ExpectedProfile: 0, 766 WantClientError: nil, 767 WantServerError: nil, 768 }, 769 { 770 Name: "SRTP both ends", 771 ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, 772 ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, 773 ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, 774 WantClientError: nil, 775 WantServerError: nil, 776 }, 777 { 778 Name: "SRTP client only", 779 ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, 780 ServerSRTP: nil, 781 ExpectedProfile: 0, 782 WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, 783 WantServerError: errServerNoMatchingSRTPProfile, 784 }, 785 { 786 Name: "SRTP server only", 787 ClientSRTP: nil, 788 ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, 789 ExpectedProfile: 0, 790 WantClientError: nil, 791 WantServerError: nil, 792 }, 793 { 794 Name: "Multiple Suites", 795 ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, 796 ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, 797 ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, 798 WantClientError: nil, 799 WantServerError: nil, 800 }, 801 { 802 Name: "Multiple Suites, Client Chooses", 803 ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, 
SRTP_AES128_CM_HMAC_SHA1_32}, 804 ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80}, 805 ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, 806 WantClientError: nil, 807 WantServerError: nil, 808 }, 809 } { 810 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 811 defer cancel() 812 813 ca, cb := dpipe.Pipe() 814 type result struct { 815 c *Conn 816 err error 817 } 818 c := make(chan result) 819 820 go func() { 821 client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) 822 c <- result{client, err} 823 }() 824 825 server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) 826 if !errors.Is(err, test.WantServerError) { 827 t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) 828 } 829 if err == nil { 830 defer func() { 831 _ = server.Close() 832 }() 833 } 834 835 res := <-c 836 if res.err == nil { 837 defer func() { 838 _ = res.c.Close() 839 }() 840 } 841 if !errors.Is(res.err, test.WantClientError) { 842 t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) 843 } 844 if res.c == nil { 845 return 846 } 847 848 actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() 849 if actualClientSRTP != test.ExpectedProfile { 850 t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP) 851 } 852 853 actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() 854 if actualServerSRTP != test.ExpectedProfile { 855 t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP) 856 } 857 } 858 } 859 860 func TestClientCertificate(t *testing.T) { 861 // Check for leaking routines 862 report := test.CheckRoutines(t) 863 
defer report() 864 865 srvCert, err := selfsign.GenerateSelfSigned() 866 if err != nil { 867 t.Fatal(err) 868 } 869 srvCAPool := x509.NewCertPool() 870 srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0]) 871 if err != nil { 872 t.Fatal(err) 873 } 874 srvCAPool.AddCert(srvCertificate) 875 876 cert, err := selfsign.GenerateSelfSigned() 877 if err != nil { 878 t.Fatal(err) 879 } 880 certificate, err := x509.ParseCertificate(cert.Certificate[0]) 881 if err != nil { 882 t.Fatal(err) 883 } 884 caPool := x509.NewCertPool() 885 caPool.AddCert(certificate) 886 887 t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak 888 tests := map[string]struct { 889 clientCfg *Config 890 serverCfg *Config 891 wantErr bool 892 }{ 893 "NoClientCert": { 894 clientCfg: &Config{RootCAs: srvCAPool}, 895 serverCfg: &Config{ 896 Certificates: []tls.Certificate{srvCert}, 897 ClientAuth: NoClientCert, 898 ClientCAs: caPool, 899 }, 900 }, 901 "NoClientCert_ServerVerifyConnectionFails": { 902 clientCfg: &Config{RootCAs: srvCAPool}, 903 serverCfg: &Config{ 904 Certificates: []tls.Certificate{srvCert}, 905 ClientAuth: NoClientCert, 906 ClientCAs: caPool, 907 VerifyConnection: func(s *State) error { 908 return errExample 909 }, 910 }, 911 wantErr: true, 912 }, 913 "NoClientCert_ClientVerifyConnectionFails": { 914 clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(s *State) error { 915 return errExample 916 }}, 917 serverCfg: &Config{ 918 Certificates: []tls.Certificate{srvCert}, 919 ClientAuth: NoClientCert, 920 ClientCAs: caPool, 921 }, 922 wantErr: true, 923 }, 924 "NoClientCert_cert": { 925 clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, 926 serverCfg: &Config{ 927 Certificates: []tls.Certificate{srvCert}, 928 ClientAuth: RequireAnyClientCert, 929 }, 930 }, 931 "RequestClientCert_cert": { 932 clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, 933 serverCfg: &Config{ 934 Certificates: 
[]tls.Certificate{srvCert}, 935 ClientAuth: RequestClientCert, 936 }, 937 }, 938 "RequestClientCert_no_cert": { 939 clientCfg: &Config{RootCAs: srvCAPool}, 940 serverCfg: &Config{ 941 Certificates: []tls.Certificate{srvCert}, 942 ClientAuth: RequestClientCert, 943 ClientCAs: caPool, 944 }, 945 }, 946 "RequireAnyClientCert": { 947 clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, 948 serverCfg: &Config{ 949 Certificates: []tls.Certificate{srvCert}, 950 ClientAuth: RequireAnyClientCert, 951 }, 952 }, 953 "RequireAnyClientCert_error": { 954 clientCfg: &Config{RootCAs: srvCAPool}, 955 serverCfg: &Config{ 956 Certificates: []tls.Certificate{srvCert}, 957 ClientAuth: RequireAnyClientCert, 958 }, 959 wantErr: true, 960 }, 961 "VerifyClientCertIfGiven_no_cert": { 962 clientCfg: &Config{RootCAs: srvCAPool}, 963 serverCfg: &Config{ 964 Certificates: []tls.Certificate{srvCert}, 965 ClientAuth: VerifyClientCertIfGiven, 966 ClientCAs: caPool, 967 }, 968 }, 969 "VerifyClientCertIfGiven_cert": { 970 clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, 971 serverCfg: &Config{ 972 Certificates: []tls.Certificate{srvCert}, 973 ClientAuth: VerifyClientCertIfGiven, 974 ClientCAs: caPool, 975 }, 976 }, 977 "VerifyClientCertIfGiven_error": { 978 clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, 979 serverCfg: &Config{ 980 Certificates: []tls.Certificate{srvCert}, 981 ClientAuth: VerifyClientCertIfGiven, 982 }, 983 wantErr: true, 984 }, 985 "RequireAndVerifyClientCert": { 986 clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error { 987 if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { 988 return errExample 989 } 990 return nil 991 }}, 992 serverCfg: &Config{ 993 Certificates: []tls.Certificate{srvCert}, 994 ClientAuth: RequireAndVerifyClientCert, 995 ClientCAs: caPool, 996 VerifyConnection: func(s *State) error { 997 if ok 
:= bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok { 998 return errExample 999 } 1000 return nil 1001 }, 1002 }, 1003 }, 1004 "RequireAndVerifyClientCert_callbacks": { 1005 clientCfg: &Config{ 1006 RootCAs: srvCAPool, 1007 // Certificates: []tls.Certificate{cert}, 1008 GetClientCertificate: func(cri *CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil }, 1009 }, 1010 serverCfg: &Config{ 1011 GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil }, 1012 // Certificates: []tls.Certificate{srvCert}, 1013 ClientAuth: RequireAndVerifyClientCert, 1014 ClientCAs: caPool, 1015 }, 1016 }, 1017 } 1018 for name, tt := range tests { 1019 tt := tt 1020 t.Run(name, func(t *testing.T) { 1021 ca, cb := dpipe.Pipe() 1022 type result struct { 1023 c *Conn 1024 err error 1025 } 1026 c := make(chan result) 1027 1028 go func() { 1029 client, err := Client(ca, tt.clientCfg) 1030 c <- result{client, err} 1031 }() 1032 1033 server, err := Server(cb, tt.serverCfg) 1034 res := <-c 1035 defer func() { 1036 if err == nil { 1037 _ = server.Close() 1038 } 1039 if res.err == nil { 1040 _ = res.c.Close() 1041 } 1042 }() 1043 1044 if tt.wantErr { 1045 if err != nil { 1046 // Error expected, test succeeded 1047 return 1048 } 1049 t.Error("Error expected") 1050 } 1051 if err != nil { 1052 t.Errorf("Server failed(%v)", err) 1053 } 1054 1055 if res.err != nil { 1056 t.Errorf("Client failed(%v)", res.err) 1057 } 1058 1059 actualClientCert := server.ConnectionState().PeerCertificates 1060 if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { 1061 if actualClientCert == nil { 1062 t.Errorf("Client did not provide a certificate") 1063 } 1064 1065 var cfgCert [][]byte 1066 if len(tt.clientCfg.Certificates) > 0 { 1067 cfgCert = tt.clientCfg.Certificates[0].Certificate 1068 } 1069 if tt.clientCfg.GetClientCertificate != nil { 1070 crt, err := 
tt.clientCfg.GetClientCertificate(&CertificateRequestInfo{})
						if err != nil {
							t.Errorf("Server configuration did not provide a certificate")
						}
						cfgCert = crt.Certificate
					}
					if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualClientCert[0]) {
						t.Errorf("Client certificate was not communicated correctly")
					}
				}
				if tt.serverCfg.ClientAuth == NoClientCert {
					if actualClientCert != nil {
						t.Errorf("Client certificate wasn't expected")
					}
				}

				actualServerCert := res.c.ConnectionState().PeerCertificates
				if actualServerCert == nil {
					t.Errorf("Server did not provide a certificate")
				}
				var cfgCert [][]byte
				if len(tt.serverCfg.Certificates) > 0 {
					cfgCert = tt.serverCfg.Certificates[0].Certificate
				}
				if tt.serverCfg.GetCertificate != nil {
					crt, err := tt.serverCfg.GetCertificate(&ClientHelloInfo{})
					if err != nil {
						t.Errorf("Server configuration did not provide a certificate")
					}
					cfgCert = crt.Certificate
				}
				if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualServerCert[0]) {
					t.Errorf("Server certificate was not communicated correctly")
				}
			})
		}
	})
}

// TestExtendedMasterSecret runs a client/server handshake over an in-memory
// pipe for every combination of the Request/Require/Disable
// ExtendedMasterSecret settings and checks the error returned on each side.
func TestExtendedMasterSecret(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	tests := map[string]struct {
		clientCfg         *Config
		serverCfg         *Config
		expectedClientErr error
		expectedServerErr error
	}{
		"Request_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Request_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Request_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: errClientRequiredButNoServerEMS,
			expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
		},
		"Disable_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Disable_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			expectedServerErr: errServerRequiredButNoClientEMS,
		},
		"Disable_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
	}
	for name, tt := range tests {
		tt := tt
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			// The client handshakes concurrently; its outcome is collected below.
			go func() {
				client, err := testClient(ctx, ca, tt.clientCfg, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, tt.serverCfg, true)
			res := <-c
			// Only close the sides whose handshake succeeded.
			defer func() {
				if err == nil {
					_ = server.Close()
				}
				if res.err == nil {
					_ = res.c.Close()
				}
			}()

			if !errors.Is(res.err, tt.expectedClientErr) {
				t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
			}

			if !errors.Is(err, tt.expectedServerErr) {
				t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
			}
		})
	}
}

// TestServerCertificate handshakes a client against a server that presents a
// self-signed certificate and checks, per case, whether the client handshake
// succeeds or fails (root CA present or absent, InsecureSkipVerify, custom
// VerifyPeerCertificate callbacks, ServerName match/mismatch).
func TestServerCertificate(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	certificate, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	caPool := x509.NewCertPool()
	caPool.AddCert(certificate)

	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
		tests := map[string]struct {
			clientCfg *Config
			serverCfg *Config
			wantErr   bool
		}{
			"no_ca": {
				clientCfg: &Config{},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"good_ca": {
				clientCfg: &Config{RootCAs: caPool},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"no_ca_skip_verify": {
				clientCfg: &Config{InsecureSkipVerify: true},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"good_ca_skip_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
					if len(chain) != 0 {
						return errNotExpectedChain
					}
					return nil
				}},
			},
			"good_ca_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
					if len(chain) == 0 {
						return errExpecedChain
					}
					return nil
				}},
			},
			"good_ca_custom_verify_peer": {
				clientCfg: &Config{
					RootCAs: caPool,
					VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
						return errWrongCert
					},
				},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"server_name": {
				clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"server_name_error": {
				clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
		}
		for name, tt := range tests {
			tt := tt
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()

				type result struct {
					c   *Conn
					err error
				}
				srvCh := make(chan result)
				go func() {
					s, err := Server(cb, tt.serverCfg)
					srvCh <- result{s, err}
				}()

				cli, err := Client(ca, tt.clientCfg)
				if err == nil {
					_ = cli.Close()
				}
				if !tt.wantErr && err != nil {
					t.Errorf("Client failed(%v)", err)
				}
				if tt.wantErr && err == nil {
					// Use Error, not Fatal: the server result below must still be
					// drained, otherwise the server goroutine stays blocked on its
					// send to the unbuffered channel.
					t.Error("Error expected")
				}

				srv := <-srvCh
				if srv.err == nil {
					_ = srv.c.Close()
				}
			})
		}
	})
}

// TestCipherSuiteConfiguration handshakes with the configured client and
// server cipher suite lists and checks the error returned on each side and,
// where a suite is expected, the suite recorded in the client's state.
func TestCipherSuiteConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                    string
		ClientCipherSuites      []CipherSuiteID
		ServerCipherSuites      []CipherSuiteID
		WantClientError         error
		WantServerError         error
		WantSelectedCipherSuite CipherSuiteID
	}{
		{
			Name:               "No CipherSuites specified",
			ClientCipherSuites: nil,
			ServerCipherSuites: nil,
			WantClientError:    nil,
			WantServerError:    nil,
		},
		{
			Name:               "Invalid CipherSuite",
			ClientCipherSuites: []CipherSuiteID{0x00},
			ServerCipherSuites: []CipherSuiteID{0x00},
			WantClientError:    &invalidCipherSuiteError{0x00},
			WantServerError:    &invalidCipherSuiteError{0x00},
		},
		{
			Name:                    "Valid CipherSuites specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		},
		{
			Name:               "CipherSuites mismatch",
			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:    &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			WantServerError:    errCipherSuiteNoIntersection,
		},
		{
			Name:                    "Valid CipherSuites CCM specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
		},
		{
			Name:                    "Valid CipherSuites CCM-8 specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
		},
		{
			Name:                    "Server supports subset of client suites",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true)
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			}

			if !errors.Is(err, test.WantServerError) {
				t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
			}

			res := <-c
			if res.err == nil {
				// The server conn is only usable when its own handshake succeeded.
				if err == nil {
					_ = server.Close()
				}
				_ = res.c.Close()
			}
			if !errors.Is(res.err, test.WantClientError) {
				t.Errorf("TestCipherSuiteConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
			}
			// Only inspect the client state when the client handshake succeeded;
			// a failed handshake has already been reported above.
			if test.WantSelectedCipherSuite != 0x00 && res.err == nil && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite {
				t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID())
			}
		})
	}
}

// TestCertificateAndPSKServer configures a server with both a certificate
// cipher suite and a PSK cipher suite and checks that the handshake succeeds
// for a client using either one.
func TestCertificateAndPSKServer(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name      string
		ClientPSK bool
	}{
		{
			Name:      "Client uses PKI",
			ClientPSK: false,
		},
		{
			Name:      "Client uses PSK",
			ClientPSK: true,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
				if test.ClientPSK {
					config.PSK = func([]byte) ([]byte, error) {
						return []byte{0x00, 0x01, 0x02}, nil
					}
					config.PSKIdentityHint = []byte{0x00}
					config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256}
				}

				client, err := testClient(ctx, ca, config, false)
				c <- result{client, err}
			}()

			config := &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
				PSK: func([]byte) ([]byte, error) {
					return []byte{0x00, 0x01, 0x02}, nil
				},
			}

			server, err := testServer(ctx, cb, config, true)
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			} else {
				t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err)
			}

			res := <-c
			if res.err == nil {
				// The server conn is only usable when its own handshake succeeded.
				if err == nil {
					_ = server.Close()
				}
				_ = res.c.Close()
			} else {
				t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, res.err)
			}
		})
	}
}

// TestPSKConfiguration checks the errors returned by client and server for
// PSK / identity-hint / certificate configurations that are expected to be
// rejected. Errors are compared by message text.
func TestPSKConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// sameError reports whether got and want are both nil, or both non-nil
	// with identical messages.
	sameError := func(got, want error) bool {
		if got == nil || want == nil {
			return got == nil && want == nil
		}
		return got.Error() == want.Error()
	}

	for _, test := range []struct {
		Name                 string
		ClientHasCertificate bool
		ServerHasCertificate bool
		ClientPSK            PSKCallback
		ServerPSK            PSKCallback
		ClientPSKIdentity    []byte
		ServerPSKIdentity    []byte
		WantClientError      error
		WantServerError      error
	}{
		{
			Name:                 "PSK and no certificate specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and certificate specified",
			ClientHasCertificate: true,
			ServerHasCertificate: true,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and no identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    nil,
			ServerPSKIdentity:    nil,
			WantClientError:      errPSKAndIdentityMustBeSetForClient,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "No PSK and identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            nil,
			ServerPSK:            nil,
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errIdentityNoPSK,
			WantServerError:      errIdentityNoPSK,
		},
	} {
		test := test
		// One subtest per case so that the deferred cancel runs at the end of
		// each case instead of piling up until the whole test returns.
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate)
			// Collect the client result before any Fatalf so the client goroutine
			// is never left blocked on its send to the unbuffered channel.
			res := <-c
			defer func() {
				if err == nil {
					_ = server.Close()
				}
				if res.err == nil {
					_ = res.c.Close()
				}
			}()

			if !sameError(err, test.WantServerError) {
				t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
			}
			if !sameError(res.err, test.WantClientError) {
				t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
			}
		})
	}
}

// TestServerTimeout repeatedly sends a hand-built ClientHello to a server
// whose handshake context expires after 50ms, checks that the server returns
// a timeout net.Error, and then checks that no further packets arrive from
// the server after it has returned.
func TestServerTimeout(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookie := make([]byte, 20)
	_, err := rand.Read(cookie)
	if err != nil {
		t.Fatal(err)
	}

	// Zero-valued random bytes; named so it does not shadow the rand package.
	var randomBytes [28]byte
	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: randomBytes}

	cipherSuites := []CipherSuite{
		&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
		&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
	}

	extensions := []extension.Extension{
		&extension.SupportedSignatureAlgorithms{
			SignatureHashAlgorithms: []signaturehash.Algorithm{
				{Hash: hash.SHA256, Signature: signature.ECDSA},
				{Hash: hash.SHA384, Signature: signature.ECDSA},
				{Hash: hash.SHA512, Signature: signature.ECDSA},
				{Hash: hash.SHA256, Signature: signature.RSA},
				{Hash: hash.SHA384, Signature: signature.RSA},
				{Hash: hash.SHA512, Signature: signature.RSA},
			},
		},
		&extension.SupportedEllipticCurves{
			EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
		},
		&extension.SupportedPointFormats{
			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
		},
	}

	record := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			SequenceNumber: 0,
			Version:        protocol.Version1_2,
		},
		Content: &handshake.Handshake{
			// sequenceNumber and messageSequence line up, may need to be re-evaluated
			Header: handshake.Header{
				MessageSequence: 0,
			},
			Message: &handshake.MessageClientHello{
				Version:            protocol.Version1_2,
				Cookie:             cookie,
				Random:             random,
				CipherSuiteIDs:     cipherSuiteIDs(cipherSuites),
				CompressionMethods: defaultCompressionMethods(),
				Extensions:         extensions,
			},
		},
	}

	packet, err := record.Marshal()
	if err != nil {
		t.Fatal(err)
	}

	ca, cb := dpipe.Pipe()
	defer func() {
		err := ca.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()

	// Client reader
	caReadChan := make(chan []byte, 1000)
	go func() {
		for {
			data := make([]byte, 8192)
			n, err := ca.Read(data)
			if err != nil {
				return
			}

			caReadChan <- data[:n]
		}
	}()

	// Start sending ClientHello packets until server responds with first packet
	go func() {
		for {
			select {
			case <-time.After(10 * time.Millisecond):
				_, err := ca.Write(packet)
				if err != nil {
					return
				}
			case <-caReadChan:
				// Once we receive the first reply from the server, stop
				return
			}
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	config := &Config{
		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
		FlightInterval: 100 * time.Millisecond,
	}

	_, serverErr := testServer(ctx, cb, config, true)
	var netErr net.Error
	if !errors.As(serverErr, &netErr) || !netErr.Timeout() {
		t.Fatalf("Server error exp(Timeout network error) failed(%v)", serverErr)
	}

	// Wait a little longer to ensure no additional messages have been sent by the server
	time.Sleep(300 * time.Millisecond)
	select {
	case msg := <-caReadChan:
		t.Fatalf("Expected no additional messages from server, got: %+v", msg)
	default:
	}
}

// TestProtocolVersionValidation feeds hand-built handshake records carrying
// protocol version 0xfeff to a server and to a client, and checks that the
// peer fails with errUnsupportedProtocolVersion and answers with an alert
// record.
func TestProtocolVersionValidation(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookie := make([]byte, 20)
	if _, err := rand.Read(cookie); err != nil {
		t.Fatal(err)
	}

	// Zero-valued random bytes; named so it does not shadow the rand package.
	var randomBytes [28]byte
	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: randomBytes}

	config := &Config{
		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
		FlightInterval: 100 * time.Millisecond,
	}

	t.Run("Server", func(t *testing.T) {
		serverCases := map[string]struct {
			records []*recordlayer.RecordLayer
		}{
			"ClientHelloVersion": {
				records: []*recordlayer.RecordLayer{
					{
						Header: recordlayer.Header{
							Version: protocol.Version1_2,
						},
						Content: &handshake.Handshake{
							Message: &handshake.MessageClientHello{
								Version:            protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
								Cookie:             cookie,
								Random:             random,
								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
								CompressionMethods: defaultCompressionMethods(),
							},
						},
					},
				},
			},
			"SecondsClientHelloVersion": {
				records: []*recordlayer.RecordLayer{
					{
						Header: recordlayer.Header{
							Version: protocol.Version1_2,
						},
						Content: &handshake.Handshake{
							Message: &handshake.MessageClientHello{
								Version:            protocol.Version1_2,
								Cookie:             cookie,
								Random:             random,
								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
								CompressionMethods: defaultCompressionMethods(),
							},
						},
					},
					{
						Header: recordlayer.Header{
							Version:        protocol.Version1_2,
							SequenceNumber: 1,
						},
						Content: &handshake.Handshake{
							Header: handshake.Header{
								MessageSequence: 1,
							},
							Message: &handshake.MessageClientHello{
								Version:            protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
								Cookie:             cookie,
								Random:             random,
								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
								CompressionMethods: defaultCompressionMethods(),
							},
						},
					},
				},
			},
		}
		for name, c := range serverCases {
			c := c
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()
				defer func() {
					err := ca.Close()
					if err != nil {
						t.Error(err)
					}
				}()

				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
				defer cancel()

				var wg sync.WaitGroup
				wg.Add(1)
				defer wg.Wait()
				go func() {
					defer wg.Done()
					// This goroutine runs the server under test, so a mismatch is a server error.
					if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
						t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
					}
				}()

				time.Sleep(50 * time.Millisecond)

				resp := make([]byte, 1024)
				for _, record := range c.records {
					packet, err := record.Marshal()
					if err != nil {
						t.Fatal(err)
					}
					if _, werr := ca.Write(packet); werr != nil {
						t.Fatal(werr)
					}
					n, rerr := ca.Read(resp[:cap(resp)])
					if rerr != nil {
						t.Fatal(rerr)
					}
					resp = resp[:n]
				}

				// resp now holds the reply to the last record written.
				h := &recordlayer.Header{}
				if err := h.Unmarshal(resp); err != nil {
					t.Fatal("Failed to unmarshal response")
				}
				if h.ContentType != protocol.ContentTypeAlert {
					t.Errorf("Peer must return alert to unsupported protocol version")
				}
			})
		}
	})

	t.Run("Client", func(t *testing.T) {
		clientCases := map[string]struct {
			records []*recordlayer.RecordLayer
		}{
			"ServerHelloVersion": {
				records: []*recordlayer.RecordLayer{
					{
						Header: recordlayer.Header{
							Version: protocol.Version1_2,
						},
						Content: &handshake.Handshake{
							Message: &handshake.MessageHelloVerifyRequest{
								Version: protocol.Version1_2,
								Cookie:  cookie,
							},
						},
					},
					{
						Header: recordlayer.Header{
							Version:        protocol.Version1_2,
							SequenceNumber: 1,
						},
						Content: &handshake.Handshake{
							Header: handshake.Header{
								MessageSequence: 1,
							},
							Message: &handshake.MessageServerHello{
								Version:           protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
								Random:            random,
								CipherSuiteID:     func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(),
								CompressionMethod: defaultCompressionMethods()[0],
							},
						},
					},
				},
			},
		}
		for name, c := range clientCases {
			c := c
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()
				defer func() {
					err := ca.Close()
					if err != nil {
						t.Error(err)
					}
				}()

				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
				defer cancel()

				var wg sync.WaitGroup
				wg.Add(1)
				defer wg.Wait()
				go func() {
					defer wg.Done()
					// This goroutine runs the client under test, so a mismatch is a client error.
					if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
						t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
					}
				}()

				time.Sleep(50 * time.Millisecond)

				for _, record := range c.records {
					// Consume the client's message before answering it.
					if _, err := ca.Read(make([]byte, 1024)); err != nil {
						t.Fatal(err)
					}

					packet, err := record.Marshal()
					if err != nil {
						t.Fatal(err)
					}
					if _, err := ca.Write(packet); err != nil {
						t.Fatal(err)
					}
				}
				resp := make([]byte, 1024)
				n, err := ca.Read(resp)
				if err != nil {
					t.Fatal(err)
				}
				resp = resp[:n]

				h := &recordlayer.Header{}
				if err := h.Unmarshal(resp); err != nil {
					t.Fatal("Failed to unmarshal response")
				}
				if h.ContentType != protocol.ContentTypeAlert {
					t.Errorf("Peer must return alert to unsupported protocol version")
				}
			})
		}
	})
}

// TestMultipleHelloVerifyRequest answers a client with two consecutive
// HelloVerifyRequest messages carrying different cookies and checks that each
// following ClientHello echoes the most recently sent cookie.
func TestMultipleHelloVerifyRequest(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookies := [][]byte{
		// first clientHello contains an empty cookie
		{},
	}
	var packets [][]byte
	for i := 0; i < 2; i++ {
		cookie := make([]byte, 20)
		if _, err := rand.Read(cookie); err != nil {
			t.Fatal(err)
		}
		cookies = append(cookies, cookie)

		record := &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				SequenceNumber: uint64(i),
				Version:        protocol.Version1_2,
			},
			Content: &handshake.Handshake{
				Header: handshake.Header{
					MessageSequence: uint16(i),
				},
				Message: &handshake.MessageHelloVerifyRequest{
					Version: protocol.Version1_2,
					Cookie:  cookie,
				},
			},
		}
		packet, err := record.Marshal()
		if err != nil {
			t.Fatal(err)
		}
		packets = append(packets, packet)
	}

	ca, cb := dpipe.Pipe()
	defer func() {
		err := ca.Close()
		if err != nil {
			t.Error(err)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(1)
	defer wg.Wait()
	go func() {
		defer wg.Done()
		_, _ = testClient(ctx, ca, &Config{}, false)
	}()

	for i, cookie := range cookies {
		// read client hello
		resp := make([]byte, 1024)
		n, err := cb.Read(resp)
		if err != nil {
			t.Fatal(err)
		}
		record := &recordlayer.RecordLayer{}
		if err := record.Unmarshal(resp[:n]); err != nil {
			t.Fatal(err)
		}
		// Check both assertions so an unexpected record fails the test instead of panicking.
		hs, ok := record.Content.(*handshake.Handshake)
		if !ok {
			t.Fatal("Failed to cast handshake.Handshake")
		}
		clientHello, ok := hs.Message.(*handshake.MessageClientHello)
		if !ok {
			t.Fatal("Failed to cast MessageClientHello")
		}

		if !bytes.Equal(clientHello.Cookie, cookie) {
			// cookie is the value we sent (expected); clientHello.Cookie is what came back.
			t.Fatalf("Wrong cookie, expected: %x, got: %x", cookie, clientHello.Cookie)
		}
		if len(packets) <= i {
			break
		}
		// write hello verify request
		if _, err := cb.Write(packets[i]); err != nil {
			t.Fatal(err)
		}
	}
	cancel()
}

// TestRenegotationInfo asserts that a DTLS server always includes the
// RenegotiationInfo extension in its ServerHello, whether or not the
// ClientHello contained that extension.
func TestRenegotationInfo(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(10 * time.Second)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	resp := make([]byte, 1024)

	for _, testCase := range []struct {
		Name                  string
		SendRenegotiationInfo bool
	}{
		{
			"Include RenegotiationInfo",
			true,
		},
		{
			"No RenegotiationInfo",
			false,
		},
	} {
		test := testCase
		t.Run(test.Name, func(t *testing.T) {
			ca, cb := dpipe.Pipe()
			defer func() {
				if err := ca.Close(); err != nil {
					t.Error(err)
				}
			}()

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			go func() {
				if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) {
					t.Error(err)
				}
			}()

			time.Sleep(50 * time.Millisecond)

			extensions := []extension.Extension{}
			if test.SendRenegotiationInfo {
				extensions = append(extensions, &extension.RenegotiationInfo{
					RenegotiatedConnection: 0,
				})
			}
			err := sendClientHello([]byte{}, ca, 0, extensions)
			if err != nil {
				t.Fatal(err)
			}
			n, err := ca.Read(resp)
			if err != nil {
				t.Fatal(err)
			}
			r := &recordlayer.RecordLayer{}
			if err = r.Unmarshal(resp[:n]); err != nil {
				t.Fatal(err)
			}

			// Check both assertions so an unexpected record fails the test instead of panicking.
			hs, ok := r.Content.(*handshake.Handshake)
			if !ok {
				t.Fatal("Failed to cast handshake.Handshake")
			}
			helloVerifyRequest, ok := hs.Message.(*handshake.MessageHelloVerifyRequest)
			if !ok {
				t.Fatal("Failed to cast MessageHelloVerifyRequest")
			}

			// Second ClientHello echoes the cookie from the HelloVerifyRequest.
			err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
			if err != nil {
				t.Fatal(err)
			}
			if n, err = ca.Read(resp); err != nil {
				t.Fatal(err)
			}

			messages, err := recordlayer.UnpackDatagram(resp[:n])
			if err != nil {
				t.Fatal(err)
			}

			if err := r.Unmarshal(messages[0]); err != nil {
				t.Fatal(err)
			}

			hs, ok = r.Content.(*handshake.Handshake)
			if !ok {
				t.Fatal("Failed to cast handshake.Handshake")
			}
			serverHello, ok := hs.Message.(*handshake.MessageServerHello)
			if !ok {
				t.Fatal("Failed to cast MessageServerHello")
			}

			gotRenegotiationInfo := false
			for _, v := range serverHello.Extensions {
				if _, ok := v.(*extension.RenegotiationInfo); ok {
					gotRenegotiationInfo = true
				}
			}

			if !gotRenegotiationInfo {
				t.Fatalf("Received ServerHello without RenegotiationInfo")
			}
		})
	}
}

// TestServerNameIndicationExtension captures the first ClientHello a client
// sends for a given Config.ServerName and checks whether it carries a
// ServerName extension and, if so, which name.
func TestServerNameIndicationExtension(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name       string
		ServerName string
		Expected   []byte
		IncludeSNI bool
	}{
		{
			Name:       "Server name is a valid hostname",
			ServerName: "example.com",
			Expected:   []byte("example.com"),
			IncludeSNI: true,
		},
		{
			Name:       "Server name is an IP literal",
			ServerName: "1.2.3.4",
			Expected:   []byte(""),
			IncludeSNI: false,
		},
		{
			Name:       "Server name is empty",
			ServerName: "",
			Expected:   []byte(""),
			IncludeSNI: false,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			go func() {
				conf := &Config{
					ServerName: test.ServerName,
				}

				_, _ = testClient(ctx, ca, conf, false)
			}()

			// Receive ClientHello
			resp := make([]byte, 1024)
			n, err := cb.Read(resp)
			if err != nil {
				t.Fatal(err)
			}
			r := &recordlayer.RecordLayer{}
			if err = r.Unmarshal(resp[:n]); err != nil {
				t.Fatal(err)
			}

			// Check both assertions so an unexpected record fails the test instead of panicking.
			hs, ok := r.Content.(*handshake.Handshake)
			if !ok {
				t.Fatal("Failed to cast handshake.Handshake")
			}
			clientHello, ok := hs.Message.(*handshake.MessageClientHello)
			if !ok {
				t.Fatal("Failed to cast MessageClientHello")
			}

			gotSNI := false
			var actualServerName string
			for _, v := range clientHello.Extensions {
				// A single type assertion both detects and extracts the extension.
				if sni, ok := v.(*extension.ServerName); ok {
					gotSNI = true
					actualServerName = sni.ServerName
				}
			}

			if gotSNI != test.IncludeSNI {
				t.Errorf("TestSNI: unexpected SNI inclusion '%s': expected(%v) actual(%v)", test.Name, test.IncludeSNI, gotSNI)
			}

			if !bytes.Equal([]byte(actualServerName), test.Expected) {
				// Print the expected name as text rather than as a list of byte values.
				t.Errorf("TestSNI: server name mismatch '%s': expected(%v) actual(%v)", test.Name, string(test.Expected), actualServerName)
			}
		})
	}
}

func TestALPNExtension(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                   string
		ClientProtocolNameList []string
		ServerProtocolNameList []string
		ExpectedProtocol       string
		ExpectAlertFromClient  bool
		ExpectAlertFromServer  bool
		Alert                  alert.Description
	}{
		{
			Name:                   "Negotiate a protocol",
			ClientProtocolNameList: []string{"http/1.1", "spd/1"},
			ServerProtocolNameList: []string{"spd/1"},
			ExpectedProtocol:       "spd/1",
			ExpectAlertFromClient:  false,
ExpectAlertFromServer: false, 2301 Alert: 0, 2302 }, 2303 { 2304 Name: "Server doesn't support any", 2305 ClientProtocolNameList: []string{"http/1.1", "spd/1"}, 2306 ServerProtocolNameList: []string{}, 2307 ExpectedProtocol: "", 2308 ExpectAlertFromClient: false, 2309 ExpectAlertFromServer: false, 2310 Alert: 0, 2311 }, 2312 { 2313 Name: "Negotiate with higher server precedence", 2314 ClientProtocolNameList: []string{"http/1.1", "spd/1", "http/3"}, 2315 ServerProtocolNameList: []string{"ssh/2", "http/3", "spd/1"}, 2316 ExpectedProtocol: "http/3", 2317 ExpectAlertFromClient: false, 2318 ExpectAlertFromServer: false, 2319 Alert: 0, 2320 }, 2321 { 2322 Name: "Empty intersection", 2323 ClientProtocolNameList: []string{"http/1.1", "http/3"}, 2324 ServerProtocolNameList: []string{"ssh/2", "spd/1"}, 2325 ExpectedProtocol: "", 2326 ExpectAlertFromClient: false, 2327 ExpectAlertFromServer: true, 2328 Alert: alert.NoApplicationProtocol, 2329 }, 2330 { 2331 Name: "Multiple protocols in ServerHello", 2332 ClientProtocolNameList: []string{"http/1.1"}, 2333 ServerProtocolNameList: []string{"http/1.1"}, 2334 ExpectedProtocol: "http/1.1", 2335 ExpectAlertFromClient: true, 2336 ExpectAlertFromServer: false, 2337 Alert: alert.InternalError, 2338 }, 2339 } { 2340 test := test 2341 t.Run(test.Name, func(t *testing.T) { 2342 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 2343 defer cancel() 2344 2345 ca, cb := dpipe.Pipe() 2346 go func() { 2347 conf := &Config{ 2348 SupportedProtocols: test.ClientProtocolNameList, 2349 } 2350 _, _ = testClient(ctx, ca, conf, false) 2351 }() 2352 2353 // Receive ClientHello 2354 resp := make([]byte, 1024) 2355 n, err := cb.Read(resp) 2356 if err != nil { 2357 t.Fatal(err) 2358 } 2359 2360 ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) 2361 defer cancel2() 2362 2363 ca2, cb2 := dpipe.Pipe() 2364 go func() { 2365 conf := &Config{ 2366 SupportedProtocols: test.ServerProtocolNameList, 2367 } 2368 if _, 
err2 := testServer(ctx2, cb2, conf, true); !errors.Is(err2, context.Canceled) { 2369 if test.ExpectAlertFromServer { //nolint 2370 // Assert the error type? 2371 } else { 2372 t.Error(err2) 2373 } 2374 } 2375 }() 2376 2377 time.Sleep(50 * time.Millisecond) 2378 2379 // Forward ClientHello 2380 if _, err = ca2.Write(resp[:n]); err != nil { 2381 t.Fatal(err) 2382 } 2383 2384 // Receive HelloVerify 2385 resp2 := make([]byte, 1024) 2386 n, err = ca2.Read(resp2) 2387 if err != nil { 2388 t.Fatal(err) 2389 } 2390 2391 // Forward HelloVerify 2392 if _, err = cb.Write(resp2[:n]); err != nil { 2393 t.Fatal(err) 2394 } 2395 2396 // Receive ClientHello 2397 resp3 := make([]byte, 1024) 2398 n, err = cb.Read(resp3) 2399 if err != nil { 2400 t.Fatal(err) 2401 } 2402 2403 // Forward ClientHello 2404 if _, err = ca2.Write(resp3[:n]); err != nil { 2405 t.Fatal(err) 2406 } 2407 2408 // Receive ServerHello 2409 resp4 := make([]byte, 1024) 2410 n, err = ca2.Read(resp4) 2411 if err != nil { 2412 t.Fatal(err) 2413 } 2414 2415 messages, err := recordlayer.UnpackDatagram(resp4[:n]) 2416 if err != nil { 2417 t.Fatal(err) 2418 } 2419 2420 r := &recordlayer.RecordLayer{} 2421 if err := r.Unmarshal(messages[0]); err != nil { 2422 t.Fatal(err) 2423 } 2424 2425 if test.ExpectAlertFromServer { 2426 a, ok := r.Content.(*alert.Alert) 2427 if !ok { 2428 t.Fatal("Failed to cast alert.Alert") 2429 } 2430 2431 if a.Description != test.Alert { 2432 t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) 2433 } 2434 } else { 2435 serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) 2436 if !ok { 2437 t.Fatal("Failed to cast handshake.MessageServerHello") 2438 } 2439 2440 var negotiatedProtocol string 2441 for _, v := range serverHello.Extensions { 2442 if _, ok := v.(*extension.ALPN); ok { 2443 e, ok := v.(*extension.ALPN) 2444 if !ok { 2445 t.Fatal("Failed to cast extension.ALPN") 2446 } 2447 2448 negotiatedProtocol = 
e.ProtocolNameList[0] 2449 2450 // Manipulate ServerHello 2451 if test.ExpectAlertFromClient { 2452 e.ProtocolNameList = append(e.ProtocolNameList, "oops") 2453 } 2454 } 2455 } 2456 2457 if negotiatedProtocol != test.ExpectedProtocol { 2458 t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.ExpectedProtocol, negotiatedProtocol) 2459 } 2460 2461 s, err := r.Marshal() 2462 if err != nil { 2463 t.Fatal(err) 2464 } 2465 2466 // Forward ServerHello 2467 if _, err = cb.Write(s); err != nil { 2468 t.Fatal(err) 2469 } 2470 2471 if test.ExpectAlertFromClient { 2472 resp5 := make([]byte, 1024) 2473 n, err = cb.Read(resp5) 2474 if err != nil { 2475 t.Fatal(err) 2476 } 2477 2478 r2 := &recordlayer.RecordLayer{} 2479 if err := r2.Unmarshal(resp5[:n]); err != nil { 2480 t.Fatal(err) 2481 } 2482 2483 a, ok := r2.Content.(*alert.Alert) 2484 if !ok { 2485 t.Fatal("Failed to cast alert.Alert") 2486 } 2487 2488 if a.Description != test.Alert { 2489 t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) 2490 } 2491 } 2492 } 2493 2494 time.Sleep(50 * time.Millisecond) // Give some time for returned errors 2495 }) 2496 } 2497 } 2498 2499 // Make sure the supported_groups extension is not included in the ServerHello 2500 func TestSupportedGroupsExtension(t *testing.T) { 2501 // Limit runtime in case of deadlocks 2502 lim := test.TimeOut(time.Second * 20) 2503 defer lim.Stop() 2504 2505 // Check for leaking routines 2506 report := test.CheckRoutines(t) 2507 defer report() 2508 2509 t.Run("ServerHello Supported Groups", func(t *testing.T) { 2510 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 2511 defer cancel() 2512 2513 ca, cb := dpipe.Pipe() 2514 go func() { 2515 if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { 2516 t.Error(err) 2517 } 2518 }() 2519 extensions := []extension.Extension{ 2520 &extension.SupportedEllipticCurves{ 2521 EllipticCurves: []elliptic.Curve{elliptic.X25519, 
elliptic.P256, elliptic.P384}, 2522 }, 2523 &extension.SupportedPointFormats{ 2524 PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, 2525 }, 2526 } 2527 2528 time.Sleep(50 * time.Millisecond) 2529 2530 resp := make([]byte, 1024) 2531 err := sendClientHello([]byte{}, ca, 0, extensions) 2532 if err != nil { 2533 t.Fatal(err) 2534 } 2535 2536 // Receive ServerHello 2537 n, err := ca.Read(resp) 2538 if err != nil { 2539 t.Fatal(err) 2540 } 2541 r := &recordlayer.RecordLayer{} 2542 if err = r.Unmarshal(resp[:n]); err != nil { 2543 t.Fatal(err) 2544 } 2545 2546 helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) 2547 if !ok { 2548 t.Fatal("Failed to cast MessageHelloVerifyRequest") 2549 } 2550 2551 err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) 2552 if err != nil { 2553 t.Fatal(err) 2554 } 2555 if n, err = ca.Read(resp); err != nil { 2556 t.Fatal(err) 2557 } 2558 2559 messages, err := recordlayer.UnpackDatagram(resp[:n]) 2560 if err != nil { 2561 t.Fatal(err) 2562 } 2563 2564 if err := r.Unmarshal(messages[0]); err != nil { 2565 t.Fatal(err) 2566 } 2567 2568 serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) 2569 if !ok { 2570 t.Fatal("Failed to cast MessageServerHello") 2571 } 2572 2573 gotGroups := false 2574 for _, v := range serverHello.Extensions { 2575 if _, ok := v.(*extension.SupportedEllipticCurves); ok { 2576 gotGroups = true 2577 } 2578 } 2579 2580 if gotGroups { 2581 t.Errorf("TestSupportedGroups: supported_groups extension was sent in ServerHello") 2582 } 2583 }) 2584 } 2585 2586 func TestSessionResume(t *testing.T) { 2587 // Limit runtime in case of deadlocks 2588 lim := test.TimeOut(time.Second * 20) 2589 defer lim.Stop() 2590 2591 // Check for leaking routines 2592 report := test.CheckRoutines(t) 2593 defer report() 2594 2595 t.Run("resumed", func(t *testing.T) { 2596 ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) 2597 defer cancel() 2598 2599 type result struct { 2600 c *Conn 2601 err error 2602 } 2603 clientRes := make(chan result, 1) 2604 2605 ss := &memSessStore{} 2606 2607 id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") 2608 secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") 2609 2610 s := Session{ID: id, Secret: secret} 2611 2612 ca, cb := dpipe.Pipe() 2613 2614 _ = ss.Set(id, s) 2615 _ = ss.Set([]byte(ca.RemoteAddr().String()+"_example.com"), s) 2616 2617 go func() { 2618 config := &Config{ 2619 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 2620 ServerName: "example.com", 2621 SessionStore: ss, 2622 MTU: 100, 2623 } 2624 c, err := testClient(ctx, ca, config, false) 2625 clientRes <- result{c, err} 2626 }() 2627 2628 config := &Config{ 2629 CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 2630 ServerName: "example.com", 2631 SessionStore: ss, 2632 MTU: 100, 2633 } 2634 server, err := testServer(ctx, cb, config, true) 2635 if err != nil { 2636 t.Fatalf("TestSessionResume: Server failed(%v)", err) 2637 } 2638 2639 actualSessionID := server.ConnectionState().SessionID 2640 actualMasterSecret := server.ConnectionState().masterSecret 2641 if !bytes.Equal(actualSessionID, id) { 2642 t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) 2643 } 2644 if !bytes.Equal(actualMasterSecret, secret) { 2645 t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret) 2646 } 2647 2648 defer func() { 2649 _ = server.Close() 2650 }() 2651 2652 res := <-clientRes 2653 if res.err != nil { 2654 t.Fatal(res.err) 2655 } 2656 _ = res.c.Close() 2657 }) 2658 2659 t.Run("new session", func(t *testing.T) { 2660 ctx, cancel := context.WithTimeout(context.Background(), 
10*time.Second) 2661 defer cancel() 2662 2663 type result struct { 2664 c *Conn 2665 err error 2666 } 2667 clientRes := make(chan result, 1) 2668 2669 s1 := &memSessStore{} 2670 s2 := &memSessStore{} 2671 2672 ca, cb := dpipe.Pipe() 2673 go func() { 2674 config := &Config{ 2675 ServerName: "example.com", 2676 SessionStore: s1, 2677 } 2678 c, err := testClient(ctx, ca, config, false) 2679 clientRes <- result{c, err} 2680 }() 2681 2682 config := &Config{ 2683 SessionStore: s2, 2684 } 2685 server, err := testServer(ctx, cb, config, true) 2686 if err != nil { 2687 t.Fatalf("TestSessionResumetion: Server failed(%v)", err) 2688 } 2689 2690 actualSessionID := server.ConnectionState().SessionID 2691 actualMasterSecret := server.ConnectionState().masterSecret 2692 ss, _ := s2.Get(actualSessionID) 2693 if !bytes.Equal(actualMasterSecret, ss.Secret) { 2694 t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) 2695 } 2696 2697 defer func() { 2698 _ = server.Close() 2699 }() 2700 2701 res := <-clientRes 2702 if res.err != nil { 2703 t.Fatal(res.err) 2704 } 2705 cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com")) 2706 if !bytes.Equal(actualMasterSecret, cs.Secret) { 2707 t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) 2708 } 2709 _ = res.c.Close() 2710 }) 2711 } 2712 2713 type memSessStore struct { 2714 sync.Map 2715 } 2716 2717 func (ms *memSessStore) Set(key []byte, s Session) error { 2718 k := hex.EncodeToString(key) 2719 ms.Store(k, s) 2720 2721 return nil 2722 } 2723 2724 func (ms *memSessStore) Get(key []byte) (Session, error) { 2725 k := hex.EncodeToString(key) 2726 2727 v, ok := ms.Load(k) 2728 if !ok { 2729 return Session{}, nil 2730 } 2731 2732 s, ok := v.(Session) 2733 if !ok { 2734 return Session{}, nil 2735 } 2736 2737 return s, nil 2738 } 2739 2740 func (ms *memSessStore) Del(key []byte) error { 2741 k := hex.EncodeToString(key) 
2742 ms.Delete(k) 2743 2744 return nil 2745 } 2746 2747 // Assert that the server only uses CipherSuites with a hash+signature that matches 2748 // the certificate. As specified in rfc5246#section-7.4.3 2749 func TestCipherSuiteMatchesCertificateType(t *testing.T) { 2750 // Limit runtime in case of deadlocks 2751 lim := test.TimeOut(time.Second * 20) 2752 defer lim.Stop() 2753 2754 // Check for leaking routines 2755 report := test.CheckRoutines(t) 2756 defer report() 2757 2758 for _, test := range []struct { 2759 Name string 2760 cipherList []CipherSuiteID 2761 expectedCipher CipherSuiteID 2762 generateRSA bool 2763 }{ 2764 { 2765 Name: "ECDSA Certificate with RSA CipherSuite first", 2766 cipherList: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 2767 expectedCipher: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 2768 }, 2769 { 2770 Name: "RSA Certificate with ECDSA CipherSuite first", 2771 cipherList: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, 2772 expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 2773 generateRSA: true, 2774 }, 2775 } { 2776 test := test 2777 t.Run(test.Name, func(t *testing.T) { 2778 clientErr := make(chan error, 1) 2779 client := make(chan *Conn, 1) 2780 2781 ca, cb := dpipe.Pipe() 2782 go func() { 2783 c, err := testClient(context.TODO(), ca, &Config{CipherSuites: test.cipherList}, false) 2784 clientErr <- err 2785 client <- c 2786 }() 2787 2788 var ( 2789 priv crypto.PrivateKey 2790 err error 2791 ) 2792 2793 if test.generateRSA { 2794 if priv, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { 2795 t.Fatal(err) 2796 } 2797 } else { 2798 if priv, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { 2799 t.Fatal(err) 2800 } 2801 } 2802 2803 serverCert, err := selfsign.SelfSign(priv) 2804 if err != nil { 2805 t.Fatal(err) 2806 } 2807 2808 if s, err := testServer(context.TODO(), cb, &Config{ 2809 CipherSuites: 
test.cipherList, 2810 Certificates: []tls.Certificate{serverCert}, 2811 }, false); err != nil { 2812 t.Fatal(err) 2813 } else if err = s.Close(); err != nil { 2814 t.Fatal(err) 2815 } 2816 2817 if c, err := <-client, <-clientErr; err != nil { 2818 t.Fatal(err) 2819 } else if err := c.Close(); err != nil { 2820 t.Fatal(err) 2821 } else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher { 2822 t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID()) 2823 } 2824 }) 2825 } 2826 } 2827 2828 // Test that we return the proper certificate if we are serving multiple ServerNames on a single Server 2829 func TestMultipleServerCertificates(t *testing.T) { 2830 fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo") 2831 if err != nil { 2832 t.Fatal(err) 2833 } 2834 2835 barCert, err := selfsign.GenerateSelfSignedWithDNS("bar") 2836 if err != nil { 2837 t.Fatal(err) 2838 } 2839 2840 caPool := x509.NewCertPool() 2841 for _, cert := range []tls.Certificate{fooCert, barCert} { 2842 certificate, err := x509.ParseCertificate(cert.Certificate[0]) 2843 if err != nil { 2844 t.Fatal(err) 2845 } 2846 caPool.AddCert(certificate) 2847 } 2848 2849 for _, test := range []struct { 2850 RequestServerName string 2851 ExpectedDNSName string 2852 }{ 2853 { 2854 "foo", 2855 "foo", 2856 }, 2857 { 2858 "bar", 2859 "bar", 2860 }, 2861 { 2862 "invalid", 2863 "foo", 2864 }, 2865 } { 2866 test := test 2867 t.Run(test.RequestServerName, func(t *testing.T) { 2868 clientErr := make(chan error, 2) 2869 client := make(chan *Conn, 1) 2870 2871 ca, cb := dpipe.Pipe() 2872 go func() { 2873 c, err := testClient(context.TODO(), ca, &Config{ 2874 RootCAs: caPool, 2875 ServerName: test.RequestServerName, 2876 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 2877 certificate, err := x509.ParseCertificate(rawCerts[0]) 2878 if err != nil { 2879 return err 2880 } 2881 2882 if 
certificate.DNSNames[0] != test.ExpectedDNSName { 2883 return errWrongCert 2884 } 2885 2886 return nil 2887 }, 2888 }, false) 2889 clientErr <- err 2890 client <- c 2891 }() 2892 2893 if s, err := testServer(context.TODO(), cb, &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { 2894 t.Fatal(err) 2895 } else if err = s.Close(); err != nil { 2896 t.Fatal(err) 2897 } 2898 2899 if c, err := <-client, <-clientErr; err != nil { 2900 t.Fatal(err) 2901 } else if err := c.Close(); err != nil { 2902 t.Fatal(err) 2903 } 2904 }) 2905 } 2906 } 2907 2908 func TestEllipticCurveConfiguration(t *testing.T) { 2909 // Check for leaking routines 2910 report := test.CheckRoutines(t) 2911 defer report() 2912 2913 for _, test := range []struct { 2914 Name string 2915 ConfigCurves []elliptic.Curve 2916 HadnshakeCurves []elliptic.Curve 2917 }{ 2918 { 2919 Name: "Curve defaulting", 2920 ConfigCurves: nil, 2921 HadnshakeCurves: defaultCurves, 2922 }, 2923 { 2924 Name: "Single curve", 2925 ConfigCurves: []elliptic.Curve{elliptic.X25519}, 2926 HadnshakeCurves: []elliptic.Curve{elliptic.X25519}, 2927 }, 2928 { 2929 Name: "Multiple curves", 2930 ConfigCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, 2931 HadnshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, 2932 }, 2933 } { 2934 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 2935 defer cancel() 2936 2937 ca, cb := dpipe.Pipe() 2938 type result struct { 2939 c *Conn 2940 err error 2941 } 2942 c := make(chan result) 2943 2944 go func() { 2945 client, err := testClient(ctx, ca, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) 2946 c <- result{client, err} 2947 }() 2948 2949 server, err := testServer(ctx, cb, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) 2950 if err != nil { 2951 t.Fatalf("Server error: %v", err) 2952 } 2953 
2954 if len(test.ConfigCurves) == 0 && len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) { 2955 t.Fatalf("Failed to default Elliptic curves, expected %d, got: %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves)) 2956 } 2957 2958 if len(test.ConfigCurves) != 0 { 2959 if len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) { 2960 t.Fatalf("Failed to configure Elliptic curves, expect %d, got %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves)) 2961 } 2962 for i, c := range test.ConfigCurves { 2963 if c != server.fsm.cfg.ellipticCurves[i] { 2964 t.Fatalf("Failed to maintain Elliptic curve order, expected %s, got %s", c, server.fsm.cfg.ellipticCurves[i]) 2965 } 2966 } 2967 } 2968 2969 res := <-c 2970 if res.err != nil { 2971 t.Fatalf("Client error; %v", err) 2972 } 2973 2974 defer func() { 2975 err = server.Close() 2976 if err != nil { 2977 t.Fatal(err) 2978 } 2979 err = res.c.Close() 2980 if err != nil { 2981 t.Fatal(err) 2982 } 2983 }() 2984 } 2985 } 2986 2987 func TestSkipHelloVerify(t *testing.T) { 2988 report := test.CheckRoutines(t) 2989 defer report() 2990 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 2991 defer cancel() 2992 2993 ca, cb := dpipe.Pipe() 2994 certificate, err := selfsign.GenerateSelfSigned() 2995 if err != nil { 2996 t.Fatal(err) 2997 } 2998 gotHello := make(chan struct{}) 2999 3000 go func() { 3001 server, sErr := testServer(ctx, cb, &Config{ 3002 Certificates: []tls.Certificate{certificate}, 3003 LoggerFactory: logging.NewDefaultLoggerFactory(), 3004 InsecureSkipVerifyHello: true, 3005 }, false) 3006 if sErr != nil { 3007 t.Error(sErr) 3008 return 3009 } 3010 buf := make([]byte, 1024) 3011 if _, sErr = server.Read(buf); sErr != nil { 3012 t.Error(sErr) 3013 } 3014 gotHello <- struct{}{} 3015 if sErr = server.Close(); sErr != nil { //nolint:contextcheck 3016 t.Error(sErr) 3017 } 3018 }() 3019 3020 client, err := testClient(ctx, ca, &Config{ 3021 LoggerFactory: 
logging.NewDefaultLoggerFactory(), 3022 InsecureSkipVerify: true, 3023 }, false) 3024 if err != nil { 3025 t.Fatal(err) 3026 } 3027 if _, err = client.Write([]byte("hello")); err != nil { 3028 t.Error(err) 3029 } 3030 select { 3031 case <-gotHello: 3032 // OK 3033 case <-time.After(time.Second * 5): 3034 t.Error("timeout") 3035 } 3036 3037 if err = client.Close(); err != nil { 3038 t.Error(err) 3039 } 3040 } 3041 3042 type connWithCallback struct { 3043 net.Conn 3044 onWrite func([]byte) 3045 } 3046 3047 func (c *connWithCallback) Write(b []byte) (int, error) { 3048 if c.onWrite != nil { 3049 c.onWrite(b) 3050 } 3051 return c.Conn.Write(b) 3052 } 3053 3054 func TestApplicationDataQueueLimited(t *testing.T) { 3055 // Limit runtime in case of deadlocks 3056 lim := test.TimeOut(time.Second * 20) 3057 defer lim.Stop() 3058 3059 // Check for leaking routines 3060 report := test.CheckRoutines(t) 3061 defer report() 3062 3063 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 3064 defer cancel() 3065 3066 ca, cb := dpipe.Pipe() 3067 defer ca.Close() 3068 defer cb.Close() 3069 3070 done := make(chan struct{}) 3071 go func() { 3072 serverCert, err := selfsign.GenerateSelfSigned() 3073 if err != nil { 3074 t.Error(err) 3075 return 3076 } 3077 cfg := &Config{} 3078 cfg.Certificates = []tls.Certificate{serverCert} 3079 3080 dconn, err := createConn(cb, cfg, false) 3081 if err != nil { 3082 t.Error(err) 3083 return 3084 } 3085 go func() { 3086 for i := 0; i < 5; i++ { 3087 dconn.lock.RLock() 3088 qlen := len(dconn.encryptedPackets) 3089 dconn.lock.RUnlock() 3090 if qlen > maxAppDataPacketQueueSize { 3091 t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets)) 3092 } 3093 t.Log(qlen) 3094 time.Sleep(1 * time.Second) 3095 } 3096 3097 }() 3098 if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil { 3099 t.Error("expected handshake to fail") 3100 } 3101 close(done) 3102 }() 3103 extensions := []extension.Extension{} 3104 3105 
time.Sleep(50 * time.Millisecond) 3106 3107 err := sendClientHello([]byte{}, ca, 0, extensions) 3108 if err != nil { 3109 t.Fatal(err) 3110 } 3111 3112 time.Sleep(50 * time.Millisecond) 3113 3114 for i := 0; i < 1000; i++ { 3115 // Send an application data packet 3116 packet, err := (&recordlayer.RecordLayer{ 3117 Header: recordlayer.Header{ 3118 Version: protocol.Version1_2, 3119 SequenceNumber: uint64(3), 3120 Epoch: 1, // use an epoch greater than 0 3121 }, 3122 Content: &protocol.ApplicationData{ 3123 Data: []byte{1, 2, 3, 4}, 3124 }, 3125 }).Marshal() 3126 if err != nil { 3127 t.Fatal(err) 3128 } 3129 ca.Write(packet) 3130 if i%100 == 0 { 3131 time.Sleep(10 * time.Millisecond) 3132 } 3133 } 3134 time.Sleep(1 * time.Second) 3135 ca.Close() 3136 <-done 3137 } 3138 3139 func TestApplicationDataWithClientHelloRejected(t *testing.T) { 3140 // Limit runtime in case of deadlocks 3141 lim := test.TimeOut(time.Second * 20) 3142 defer lim.Stop() 3143 3144 // Check for leaking routines 3145 report := test.CheckRoutines(t) 3146 defer report() 3147 3148 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 3149 defer cancel() 3150 3151 ca, cb := dpipe.Pipe() 3152 defer ca.Close() 3153 defer cb.Close() 3154 3155 done := make(chan struct{}) 3156 go func() { 3157 if _, err := testServer(ctx, cb, &Config{}, true); err == nil { 3158 t.Error("expected handshake to fail") 3159 } 3160 close(done) 3161 }() 3162 extensions := []extension.Extension{} 3163 3164 time.Sleep(50 * time.Millisecond) 3165 3166 err := sendClientHello([]byte{}, ca, 0, extensions) 3167 if err != nil { 3168 t.Fatal(err) 3169 } 3170 3171 // Send an application data packet 3172 packet, err := (&recordlayer.RecordLayer{ 3173 Header: recordlayer.Header{ 3174 Version: protocol.Version1_2, 3175 SequenceNumber: uint64(3), 3176 Epoch: 0, 3177 }, 3178 Content: &protocol.ApplicationData{ 3179 Data: []byte{1, 2, 3, 4}, 3180 }, 3181 }).Marshal() 3182 if err != nil { 3183 t.Fatal(err) 3184 } 3185 
ca.Write(packet) 3186 <-done 3187 }