github.com/pion/dtls/v2@v2.2.12/conn.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "context" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 "github.com/pion/dtls/v2/internal/closer" 17 "github.com/pion/dtls/v2/pkg/crypto/elliptic" 18 "github.com/pion/dtls/v2/pkg/crypto/signaturehash" 19 "github.com/pion/dtls/v2/pkg/protocol" 20 "github.com/pion/dtls/v2/pkg/protocol/alert" 21 "github.com/pion/dtls/v2/pkg/protocol/handshake" 22 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 23 "github.com/pion/logging" 24 "github.com/pion/transport/v2/connctx" 25 "github.com/pion/transport/v2/deadline" 26 "github.com/pion/transport/v2/replaydetector" 27 ) 28 29 const ( 30 initialTickerInterval = time.Second 31 cookieLength = 20 32 sessionLength = 32 33 defaultNamedCurve = elliptic.X25519 34 inboundBufferSize = 8192 35 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 36 defaultReplayProtectionWindow = 64 37 // maxAppDataPacketQueueSize is the maximum number of app data packets we will 38 // enqueue before the handshake is completed 39 maxAppDataPacketQueueSize = 100 40 ) 41 42 func invalidKeyingLabels() map[string]bool { 43 return map[string]bool{ 44 "client finished": true, 45 "server finished": true, 46 "master secret": true, 47 "key expansion": true, 48 } 49 } 50 51 // Conn represents a DTLS connection 52 type Conn struct { 53 lock sync.RWMutex // Internal lock (must not be public) 54 nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from 55 fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling 56 handshakeCache *handshakeCache // caching of handshake messages for verifyData generation 57 decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read` 58 59 state State // Internal state 60 61 maximumTransmissionUnit int 62 63 handshakeCompletedSuccessfully atomic.Value 64 65 encryptedPackets [][]byte 66 67 connectionClosedByUser bool 68 closeLock sync.Mutex 69 closed *closer.Closer 70 handshakeLoopsFinished sync.WaitGroup 71 72 readDeadline *deadline.Deadline 73 writeDeadline *deadline.Deadline 74 75 log logging.LeveledLogger 76 77 reading chan struct{} 78 handshakeRecv chan chan struct{} 79 cancelHandshaker func() 80 cancelHandshakeReader func() 81 82 fsm *handshakeFSM 83 84 replayProtectionWindow uint 85 } 86 87 func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) { 88 err := validateConfig(config) 89 if err != nil { 90 return nil, err 91 } 92 93 if nextConn == nil { 94 return nil, errNilNextConn 95 } 96 97 loggerFactory := config.LoggerFactory 98 if loggerFactory == nil { 99 loggerFactory = logging.NewDefaultLoggerFactory() 100 } 101 102 logger := loggerFactory.NewLogger("dtls") 103 104 mtu := config.MTU 105 if mtu <= 0 { 106 mtu = defaultMTU 107 } 108 109 replayProtectionWindow := config.ReplayProtectionWindow 110 if replayProtectionWindow <= 0 { 111 replayProtectionWindow = defaultReplayProtectionWindow 112 } 113 114 c := &Conn{ 115 nextConn: connctx.New(nextConn), 116 fragmentBuffer: newFragmentBuffer(), 117 handshakeCache: newHandshakeCache(), 118 maximumTransmissionUnit: mtu, 119 120 decrypted: make(chan interface{}, 1), 121 log: logger, 122 123 readDeadline: deadline.New(), 124 writeDeadline: deadline.New(), 125 126 reading: make(chan struct{}, 1), 127 handshakeRecv: make(chan chan struct{}), 128 closed: closer.NewCloser(), 129 cancelHandshaker: func() {}, 130 131 replayProtectionWindow: uint(replayProtectionWindow), 132 133 state: State{ 134 isClient: isClient, 135 }, 136 } 137 138 c.setRemoteEpoch(0) 139 c.setLocalEpoch(0) 140 return c, nil 141 } 142 143 func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { 144 if conn == nil { 145 return nil, errNilNextConn 146 } 147 148 cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) 149 if err != nil { 150 return nil, err 151 } 152 153 signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) 154 if err != nil { 155 return nil, err 156 } 157 158 workerInterval := initialTickerInterval 159 if config.FlightInterval != 0 { 160 workerInterval = config.FlightInterval 161 } 162 163 serverName := config.ServerName 164 // Do not allow the use of an IP address literal as an SNI value. 165 // See RFC 6066, Section 3. 166 if net.ParseIP(serverName) != nil { 167 serverName = "" 168 } 169 170 curves := config.EllipticCurves 171 if len(curves) == 0 { 172 curves = defaultCurves 173 } 174 175 hsCfg := &handshakeConfig{ 176 localPSKCallback: config.PSK, 177 localPSKIdentityHint: config.PSKIdentityHint, 178 localCipherSuites: cipherSuites, 179 localSignatureSchemes: signatureSchemes, 180 extendedMasterSecret: config.ExtendedMasterSecret, 181 localSRTPProtectionProfiles: config.SRTPProtectionProfiles, 182 serverName: serverName, 183 supportedProtocols: config.SupportedProtocols, 184 clientAuth: config.ClientAuth, 185 localCertificates: config.Certificates, 186 insecureSkipVerify: config.InsecureSkipVerify, 187 verifyPeerCertificate: config.VerifyPeerCertificate, 188 verifyConnection: config.VerifyConnection, 189 rootCAs: config.RootCAs, 190 clientCAs: config.ClientCAs, 191 customCipherSuites: config.CustomCipherSuites, 192 retransmitInterval: workerInterval, 193 log: conn.log, 194 initialEpoch: 0, 195 keyLogWriter: config.KeyLogWriter, 196 sessionStore: config.SessionStore, 197 ellipticCurves: curves, 198 localGetCertificate: config.GetCertificate, 199 localGetClientCertificate: config.GetClientCertificate, 200 insecureSkipHelloVerify: config.InsecureSkipVerifyHello, 201 } 202 203 // rfc5246#section-7.4.3 204 // In addition, the hash and signature algorithms MUST be compatible 205 // with the key in the server's end-entity certificate. 206 if !isClient { 207 cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) 208 if err != nil && !errors.Is(err, errNoCertificates) { 209 return nil, err 210 } 211 hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) 212 } 213 214 var initialFlight flightVal 215 var initialFSMState handshakeState 216 217 if initialState != nil { 218 if conn.state.isClient { 219 initialFlight = flight5 220 } else { 221 initialFlight = flight6 222 } 223 initialFSMState = handshakeFinished 224 225 conn.state = *initialState 226 } else { 227 if conn.state.isClient { 228 initialFlight = flight1 229 } else { 230 initialFlight = flight0 231 } 232 initialFSMState = handshakePreparing 233 } 234 // Do handshake 235 if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { 236 return nil, err 237 } 238 239 conn.log.Trace("Handshake Completed") 240 241 return conn, nil 242 } 243 244 // Dial connects to the given network address and establishes a DTLS connection on top. 245 // Connection handshake will timeout using ConnectContextMaker in the Config. 246 // If you want to specify the timeout duration, use DialWithContext() instead. 247 func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { 248 ctx, cancel := config.connectContextMaker() 249 defer cancel() 250 251 return DialWithContext(ctx, network, raddr, config) 252 } 253 254 // Client establishes a DTLS connection over an existing connection. 255 // Connection handshake will timeout using ConnectContextMaker in the Config. 256 // If you want to specify the timeout duration, use ClientWithContext() instead. 257 func Client(conn net.Conn, config *Config) (*Conn, error) { 258 ctx, cancel := config.connectContextMaker() 259 defer cancel() 260 261 return ClientWithContext(ctx, conn, config) 262 } 263 264 // Server listens for incoming DTLS connections. 265 // Connection handshake will timeout using ConnectContextMaker in the Config. 266 // If you want to specify the timeout duration, use ServerWithContext() instead. 267 func Server(conn net.Conn, config *Config) (*Conn, error) { 268 ctx, cancel := config.connectContextMaker() 269 defer cancel() 270 271 return ServerWithContext(ctx, conn, config) 272 } 273 274 // DialWithContext connects to the given network address and establishes a DTLS connection on top. 275 func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { 276 pConn, err := net.DialUDP(network, nil, raddr) 277 if err != nil { 278 return nil, err 279 } 280 return ClientWithContext(ctx, pConn, config) 281 } 282 283 // ClientWithContext establishes a DTLS connection over an existing connection. 284 func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { 285 switch { 286 case config == nil: 287 return nil, errNoConfigProvided 288 case config.PSK != nil && config.PSKIdentityHint == nil: 289 return nil, errPSKAndIdentityMustBeSetForClient 290 } 291 292 dconn, err := createConn(conn, config, true) 293 if err != nil { 294 return nil, err 295 } 296 297 return handshakeConn(ctx, dconn, config, true, nil) 298 } 299 300 // ServerWithContext listens for incoming DTLS connections. 301 func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { 302 if config == nil { 303 return nil, errNoConfigProvided 304 } 305 dconn, err := createConn(conn, config, false) 306 if err != nil { 307 return nil, err 308 } 309 return handshakeConn(ctx, dconn, config, false, nil) 310 } 311 312 // Read reads data from the connection. 313 func (c *Conn) Read(p []byte) (n int, err error) { 314 if !c.isHandshakeCompletedSuccessfully() { 315 return 0, errHandshakeInProgress 316 } 317 318 select { 319 case <-c.readDeadline.Done(): 320 return 0, errDeadlineExceeded 321 default: 322 } 323 324 for { 325 select { 326 case <-c.readDeadline.Done(): 327 return 0, errDeadlineExceeded 328 case out, ok := <-c.decrypted: 329 if !ok { 330 return 0, io.EOF 331 } 332 switch val := out.(type) { 333 case ([]byte): 334 if len(p) < len(val) { 335 return 0, errBufferTooSmall 336 } 337 copy(p, val) 338 return len(val), nil 339 case (error): 340 return 0, val 341 } 342 } 343 } 344 } 345 346 // Write writes len(p) bytes from p to the DTLS connection 347 func (c *Conn) Write(p []byte) (int, error) { 348 if c.isConnectionClosed() { 349 return 0, ErrConnClosed 350 } 351 352 select { 353 case <-c.writeDeadline.Done(): 354 return 0, errDeadlineExceeded 355 default: 356 } 357 358 if !c.isHandshakeCompletedSuccessfully() { 359 return 0, errHandshakeInProgress 360 } 361 362 return len(p), c.writePackets(c.writeDeadline, []*packet{ 363 { 364 record: &recordlayer.RecordLayer{ 365 Header: recordlayer.Header{ 366 Epoch: c.state.getLocalEpoch(), 367 Version: protocol.Version1_2, 368 }, 369 Content: &protocol.ApplicationData{ 370 Data: p, 371 }, 372 }, 373 shouldEncrypt: true, 374 }, 375 }) 376 } 377 378 // Close closes the connection. 379 func (c *Conn) Close() error { 380 err := c.close(true) //nolint:contextcheck 381 c.handshakeLoopsFinished.Wait() 382 return err 383 } 384 385 // ConnectionState returns basic DTLS details about the connection. 386 // Note that this replaced the `Export` function of v1. 387 func (c *Conn) ConnectionState() State { 388 c.lock.RLock() 389 defer c.lock.RUnlock() 390 return *c.state.clone() 391 } 392 393 // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile 394 func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { 395 profile := c.state.getSRTPProtectionProfile() 396 if profile == 0 { 397 return 0, false 398 } 399 400 return profile, true 401 } 402 403 func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { 404 c.lock.Lock() 405 defer c.lock.Unlock() 406 407 var rawPackets [][]byte 408 409 for _, p := range pkts { 410 if h, ok := p.record.Content.(*handshake.Handshake); ok { 411 handshakeRaw, err := p.record.Marshal() 412 if err != nil { 413 return err 414 } 415 416 c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", 417 srvCliStr(c.state.isClient), h.Header.Type.String(), 418 p.record.Header.Epoch, h.Header.MessageSequence) 419 c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) 420 421 rawHandshakePackets, err := c.processHandshakePacket(p, h) 422 if err != nil { 423 return err 424 } 425 rawPackets = append(rawPackets, rawHandshakePackets...) 426 } else { 427 rawPacket, err := c.processPacket(p) 428 if err != nil { 429 return err 430 } 431 rawPackets = append(rawPackets, rawPacket) 432 } 433 } 434 if len(rawPackets) == 0 { 435 return nil 436 } 437 compactedRawPackets := c.compactRawPackets(rawPackets) 438 439 for _, compactedRawPackets := range compactedRawPackets { 440 if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil { 441 return netError(err) 442 } 443 } 444 445 return nil 446 } 447 448 func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { 449 // avoid a useless copy in the common case 450 if len(rawPackets) == 1 { 451 return rawPackets 452 } 453 454 combinedRawPackets := make([][]byte, 0) 455 currentCombinedRawPacket := make([]byte, 0) 456 457 for _, rawPacket := range rawPackets { 458 if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit { 459 combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) 460 currentCombinedRawPacket = []byte{} 461 } 462 currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...) 463 } 464 465 combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) 466 467 return combinedRawPackets 468 } 469 470 func (c *Conn) processPacket(p *packet) ([]byte, error) { 471 epoch := p.record.Header.Epoch 472 for len(c.state.localSequenceNumber) <= int(epoch) { 473 c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) 474 } 475 seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 476 if seq > recordlayer.MaxSequenceNumber { 477 // RFC 6347 Section 4.1.0 478 // The implementation must either abandon an association or rehandshake 479 // prior to allowing the sequence number to wrap. 480 return nil, errSequenceNumberOverflow 481 } 482 p.record.Header.SequenceNumber = seq 483 484 rawPacket, err := p.record.Marshal() 485 if err != nil { 486 return nil, err 487 } 488 489 if p.shouldEncrypt { 490 var err error 491 rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) 492 if err != nil { 493 return nil, err 494 } 495 } 496 497 return rawPacket, nil 498 } 499 500 func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) { 501 rawPackets := make([][]byte, 0) 502 503 handshakeFragments, err := c.fragmentHandshake(h) 504 if err != nil { 505 return nil, err 506 } 507 epoch := p.record.Header.Epoch 508 for len(c.state.localSequenceNumber) <= int(epoch) { 509 c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) 510 } 511 512 for _, handshakeFragment := range handshakeFragments { 513 seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 514 if seq > recordlayer.MaxSequenceNumber { 515 return nil, errSequenceNumberOverflow 516 } 517 518 recordlayerHeader := &recordlayer.Header{ 519 Version: p.record.Header.Version, 520 ContentType: p.record.Header.ContentType, 521 ContentLen: uint16(len(handshakeFragment)), 522 Epoch: p.record.Header.Epoch, 523 SequenceNumber: seq, 524 } 525 526 rawPacket, err := recordlayerHeader.Marshal() 527 if err != nil { 528 return nil, err 529 } 530 531 p.record.Header = *recordlayerHeader 532 533 rawPacket = append(rawPacket, handshakeFragment...) 534 if p.shouldEncrypt { 535 var err error 536 rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) 537 if err != nil { 538 return nil, err 539 } 540 } 541 542 rawPackets = append(rawPackets, rawPacket) 543 } 544 545 return rawPackets, nil 546 } 547 548 func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { 549 content, err := h.Message.Marshal() 550 if err != nil { 551 return nil, err 552 } 553 554 fragmentedHandshakes := make([][]byte, 0) 555 556 contentFragments := splitBytes(content, c.maximumTransmissionUnit) 557 if len(contentFragments) == 0 { 558 contentFragments = [][]byte{ 559 {}, 560 } 561 } 562 563 offset := 0 564 for _, contentFragment := range contentFragments { 565 contentFragmentLen := len(contentFragment) 566 567 headerFragment := &handshake.Header{ 568 Type: h.Header.Type, 569 Length: h.Header.Length, 570 MessageSequence: h.Header.MessageSequence, 571 FragmentOffset: uint32(offset), 572 FragmentLength: uint32(contentFragmentLen), 573 } 574 575 offset += contentFragmentLen 576 577 fragmentedHandshake, err := headerFragment.Marshal() 578 if err != nil { 579 return nil, err 580 } 581 582 fragmentedHandshake = append(fragmentedHandshake, contentFragment...) 583 fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake) 584 } 585 586 return fragmentedHandshakes, nil 587 } 588 589 var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals 590 New: func() interface{} { 591 b := make([]byte, inboundBufferSize) 592 return &b 593 }, 594 } 595 596 func (c *Conn) readAndBuffer(ctx context.Context) error { 597 bufptr, ok := poolReadBuffer.Get().(*[]byte) 598 if !ok { 599 return errFailedToAccessPoolReadBuffer 600 } 601 defer poolReadBuffer.Put(bufptr) 602 603 b := *bufptr 604 i, err := c.nextConn.ReadContext(ctx, b) 605 if err != nil { 606 return netError(err) 607 } 608 609 pkts, err := recordlayer.UnpackDatagram(b[:i]) 610 if err != nil { 611 return err 612 } 613 614 var hasHandshake bool 615 for _, p := range pkts { 616 hs, alert, err := c.handleIncomingPacket(ctx, p, true) 617 if alert != nil { 618 if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { 619 if err == nil { 620 err = alertErr 621 } 622 } 623 } 624 if hs { 625 hasHandshake = true 626 } 627 628 if err != nil { 629 return err 630 } 631 } 632 if hasHandshake { 633 done := make(chan struct{}) 634 select { 635 case c.handshakeRecv <- done: 636 // If the other party may retransmit the flight, 637 // we should respond even if it not a new message. 638 <-done 639 case <-c.fsm.Done(): 640 } 641 } 642 return nil 643 } 644 645 func (c *Conn) handleQueuedPackets(ctx context.Context) error { 646 pkts := c.encryptedPackets 647 c.encryptedPackets = nil 648 649 for _, p := range pkts { 650 _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue 651 if alert != nil { 652 if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { 653 if err == nil { 654 err = alertErr 655 } 656 } 657 } 658 var e *alertError 659 if errors.As(err, &e) { 660 if e.IsFatalOrCloseNotify() { 661 return e 662 } 663 } else if err != nil { 664 return err 665 } 666 } 667 return nil 668 } 669 670 func (c *Conn) enqueueEncryptedPackets(packet []byte) bool { 671 if len(c.encryptedPackets) < maxAppDataPacketQueueSize { 672 c.encryptedPackets = append(c.encryptedPackets, packet) 673 return true 674 } 675 return false 676 } 677 678 func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit 679 h := &recordlayer.Header{} 680 if err := h.Unmarshal(buf); err != nil { 681 // Decode error must be silently discarded 682 // [RFC6347 Section-4.1.2.7] 683 c.log.Debugf("discarded broken packet: %v", err) 684 return false, nil, nil 685 } 686 // Validate epoch 687 remoteEpoch := c.state.getRemoteEpoch() 688 if h.Epoch > remoteEpoch { 689 if h.Epoch > remoteEpoch+1 { 690 c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", 691 h.Epoch, h.SequenceNumber, 692 ) 693 return false, nil, nil 694 } 695 if enqueue { 696 if ok := c.enqueueEncryptedPackets(buf); ok { 697 c.log.Debug("received packet of next epoch, queuing packet") 698 } 699 } 700 return false, nil, nil 701 } 702 703 // Anti-replay protection 704 for len(c.state.replayDetector) <= int(h.Epoch) { 705 c.state.replayDetector = append(c.state.replayDetector, 706 replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), 707 ) 708 } 709 markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber) 710 if !ok { 711 c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", 712 h.Epoch, h.SequenceNumber, 713 ) 714 return false, nil, nil 715 } 716 717 // Decrypt 718 if h.Epoch != 0 { 719 if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { 720 if enqueue { 721 if ok := c.enqueueEncryptedPackets(buf); ok { 722 c.log.Debug("handshake not finished, queuing packet") 723 } 724 } 725 return false, nil, nil 726 } 727 728 var err error 729 buf, err = c.state.cipherSuite.Decrypt(buf) 730 if err != nil { 731 c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) 732 return false, nil, nil 733 } 734 } 735 736 isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) 737 if err != nil { 738 // Decode error must be silently discarded 739 // [RFC6347 Section-4.1.2.7] 740 c.log.Debugf("defragment failed: %s", err) 741 return false, nil, nil 742 } else if isHandshake { 743 markPacketAsValid() 744 for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { 745 header := &handshake.Header{} 746 if err := header.Unmarshal(out); err != nil { 747 c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) 748 continue 749 } 750 c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) 751 } 752 753 return true, nil, nil 754 } 755 756 r := &recordlayer.RecordLayer{} 757 if err := r.Unmarshal(buf); err != nil { 758 return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err 759 } 760 761 switch content := r.Content.(type) { 762 case *alert.Alert: 763 c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) 764 var a *alert.Alert 765 if content.Description == alert.CloseNotify { 766 // Respond with a close_notify [RFC5246 Section 7.2.1] 767 a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} 768 } 769 markPacketAsValid() 770 return false, a, &alertError{content} 771 case *protocol.ChangeCipherSpec: 772 if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { 773 if enqueue { 774 if ok := c.enqueueEncryptedPackets(buf); ok { 775 c.log.Debugf("CipherSuite not initialized, queuing packet") 776 } 777 } 778 return false, nil, nil 779 } 780 781 newRemoteEpoch := h.Epoch + 1 782 c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) 783 784 if c.state.getRemoteEpoch()+1 == newRemoteEpoch { 785 c.setRemoteEpoch(newRemoteEpoch) 786 markPacketAsValid() 787 } 788 case *protocol.ApplicationData: 789 if h.Epoch == 0 { 790 return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero 791 } 792 793 markPacketAsValid() 794 795 select { 796 case c.decrypted <- content.Data: 797 case <-c.closed.Done(): 798 case <-ctx.Done(): 799 } 800 801 default: 802 return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) 803 } 804 return false, nil, nil 805 } 806 807 func (c *Conn) recvHandshake() <-chan chan struct{} { 808 return c.handshakeRecv 809 } 810 811 func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { 812 if level == alert.Fatal && len(c.state.SessionID) > 0 { 813 // According to the RFC, we need to delete the stored session. 814 // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 815 if ss := c.fsm.cfg.sessionStore; ss != nil { 816 c.log.Tracef("clean invalid session: %s", c.state.SessionID) 817 if err := ss.Del(c.sessionKey()); err != nil { 818 return err 819 } 820 } 821 } 822 return c.writePackets(ctx, []*packet{ 823 { 824 record: &recordlayer.RecordLayer{ 825 Header: recordlayer.Header{ 826 Epoch: c.state.getLocalEpoch(), 827 Version: protocol.Version1_2, 828 }, 829 Content: &alert.Alert{ 830 Level: level, 831 Description: desc, 832 }, 833 }, 834 shouldEncrypt: c.isHandshakeCompletedSuccessfully(), 835 }, 836 }) 837 } 838 839 func (c *Conn) setHandshakeCompletedSuccessfully() { 840 c.handshakeCompletedSuccessfully.Store(struct{ bool }{true}) 841 } 842 843 func (c *Conn) isHandshakeCompletedSuccessfully() bool { 844 boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool }) 845 return boolean.bool 846 } 847 848 func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit 849 c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) 850 851 done := make(chan struct{}) 852 ctxRead, cancelRead := context.WithCancel(context.Background()) 853 c.cancelHandshakeReader = cancelRead 854 cfg.onFlightState = func(f flightVal, s handshakeState) { 855 if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() { 856 c.setHandshakeCompletedSuccessfully() 857 close(done) 858 } 859 } 860 861 ctxHs, cancel := context.WithCancel(context.Background()) 862 c.cancelHandshaker = cancel 863 864 firstErr := make(chan error, 1) 865 866 c.handshakeLoopsFinished.Add(2) 867 868 // Handshake routine should be live until close. 869 // The other party may request retransmission of the last flight to cope with packet drop. 870 go func() { 871 defer c.handshakeLoopsFinished.Done() 872 err := c.fsm.Run(ctxHs, c, initialState) 873 if !errors.Is(err, context.Canceled) { 874 select { 875 case firstErr <- err: 876 default: 877 } 878 } 879 }() 880 go func() { 881 defer func() { 882 // Escaping read loop. 883 // It's safe to close decrypted channnel now. 884 close(c.decrypted) 885 886 // Force stop handshaker when the underlying connection is closed. 887 cancel() 888 }() 889 defer c.handshakeLoopsFinished.Done() 890 for { 891 if err := c.readAndBuffer(ctxRead); err != nil { 892 var e *alertError 893 if errors.As(err, &e) { 894 if !e.IsFatalOrCloseNotify() { 895 if c.isHandshakeCompletedSuccessfully() { 896 // Pass the error to Read() 897 select { 898 case c.decrypted <- err: 899 case <-c.closed.Done(): 900 case <-ctxRead.Done(): 901 } 902 } 903 continue // non-fatal alert must not stop read loop 904 } 905 } else { 906 switch { 907 case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): 908 case errors.Is(err, recordlayer.ErrInvalidPacketLength): 909 // Decode error must be silently discarded 910 // [RFC6347 Section-4.1.2.7] 911 continue 912 default: 913 if c.isHandshakeCompletedSuccessfully() { 914 // Keep read loop and pass the read error to Read() 915 select { 916 case c.decrypted <- err: 917 case <-c.closed.Done(): 918 case <-ctxRead.Done(): 919 } 920 continue // non-fatal alert must not stop read loop 921 } 922 } 923 } 924 925 select { 926 case firstErr <- err: 927 default: 928 } 929 930 if e != nil { 931 if e.IsFatalOrCloseNotify() { 932 _ = c.close(false) //nolint:contextcheck 933 } 934 } 935 if !c.isConnectionClosed() && errors.Is(err, context.Canceled) { 936 c.log.Trace("handshake timeouts - closing underline connection") 937 _ = c.close(false) //nolint:contextcheck 938 } 939 return 940 } 941 } 942 }() 943 944 select { 945 case err := <-firstErr: 946 cancelRead() 947 cancel() 948 c.handshakeLoopsFinished.Wait() 949 return c.translateHandshakeCtxError(err) 950 case <-ctx.Done(): 951 cancelRead() 952 cancel() 953 c.handshakeLoopsFinished.Wait() 954 return c.translateHandshakeCtxError(ctx.Err()) 955 case <-done: 956 return nil 957 } 958 } 959 960 func (c *Conn) translateHandshakeCtxError(err error) error { 961 if err == nil { 962 return nil 963 } 964 if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() { 965 return nil 966 } 967 return &HandshakeError{Err: err} 968 } 969 970 func (c *Conn) close(byUser bool) error { 971 c.cancelHandshaker() 972 c.cancelHandshakeReader() 973 974 if c.isHandshakeCompletedSuccessfully() && byUser { 975 // Discard error from notify() to return non-error on the first user call of Close() 976 // even if the underlying connection is already closed. 977 _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify) 978 } 979 980 c.closeLock.Lock() 981 // Don't return ErrConnClosed at the first time of the call from user. 982 closedByUser := c.connectionClosedByUser 983 if byUser { 984 c.connectionClosedByUser = true 985 } 986 isClosed := c.isConnectionClosed() 987 c.closed.Close() 988 c.closeLock.Unlock() 989 990 if closedByUser { 991 return ErrConnClosed 992 } 993 994 if isClosed { 995 return nil 996 } 997 998 return c.nextConn.Close() 999 } 1000 1001 func (c *Conn) isConnectionClosed() bool { 1002 select { 1003 case <-c.closed.Done(): 1004 return true 1005 default: 1006 return false 1007 } 1008 } 1009 1010 func (c *Conn) setLocalEpoch(epoch uint16) { 1011 c.state.localEpoch.Store(epoch) 1012 } 1013 1014 func (c *Conn) setRemoteEpoch(epoch uint16) { 1015 c.state.remoteEpoch.Store(epoch) 1016 } 1017 1018 // LocalAddr implements net.Conn.LocalAddr 1019 func (c *Conn) LocalAddr() net.Addr { 1020 return c.nextConn.LocalAddr() 1021 } 1022 1023 // RemoteAddr implements net.Conn.RemoteAddr 1024 func (c *Conn) RemoteAddr() net.Addr { 1025 return c.nextConn.RemoteAddr() 1026 } 1027 1028 func (c *Conn) sessionKey() []byte { 1029 if c.state.isClient { 1030 // As ServerName can be like 0.example.com, it's better to add 1031 // delimiter character which is not allowed to be in 1032 // neither address or domain name. 1033 return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName) 1034 } 1035 return c.state.SessionID 1036 } 1037 1038 // SetDeadline implements net.Conn.SetDeadline 1039 func (c *Conn) SetDeadline(t time.Time) error { 1040 c.readDeadline.Set(t) 1041 return c.SetWriteDeadline(t) 1042 } 1043 1044 // SetReadDeadline implements net.Conn.SetReadDeadline 1045 func (c *Conn) SetReadDeadline(t time.Time) error { 1046 c.readDeadline.Set(t) 1047 // Read deadline is fully managed by this layer. 1048 // Don't set read deadline to underlying connection. 1049 return nil 1050 } 1051 1052 // SetWriteDeadline implements net.Conn.SetWriteDeadline 1053 func (c *Conn) SetWriteDeadline(t time.Time) error { 1054 c.writeDeadline.Set(t) 1055 // Write deadline is also fully managed by this layer. 1056 return nil 1057 }