github.com/pion/dtls/v2@v2.2.12/conn.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/pion/dtls/v2/internal/closer"
    17  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    18  	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
    19  	"github.com/pion/dtls/v2/pkg/protocol"
    20  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    21  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    22  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    23  	"github.com/pion/logging"
    24  	"github.com/pion/transport/v2/connctx"
    25  	"github.com/pion/transport/v2/deadline"
    26  	"github.com/pion/transport/v2/replaydetector"
    27  )
    28  
    29  const (
    30  	initialTickerInterval = time.Second
    31  	cookieLength          = 20
    32  	sessionLength         = 32
    33  	defaultNamedCurve     = elliptic.X25519
    34  	inboundBufferSize     = 8192
    35  	// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
    36  	defaultReplayProtectionWindow = 64
    37  	// maxAppDataPacketQueueSize is the maximum number of app data packets we will
    38  	// enqueue before the handshake is completed
    39  	maxAppDataPacketQueueSize = 100
    40  )
    41  
    42  func invalidKeyingLabels() map[string]bool {
    43  	return map[string]bool{
    44  		"client finished": true,
    45  		"server finished": true,
    46  		"master secret":   true,
    47  		"key expansion":   true,
    48  	}
    49  }
    50  
    51  // Conn represents a DTLS connection
    52  type Conn struct {
    53  	lock           sync.RWMutex     // Internal lock (must not be public)
    54  	nextConn       connctx.ConnCtx  // Embedded Conn, typically a udpconn we read/write from
    55  	fragmentBuffer *fragmentBuffer  // out-of-order and missing fragment handling
    56  	handshakeCache *handshakeCache  // caching of handshake messages for verifyData generation
    57  	decrypted      chan interface{} // Decrypted Application Data or error, pull by calling `Read`
    58  
    59  	state State // Internal state
    60  
    61  	maximumTransmissionUnit int
    62  
    63  	handshakeCompletedSuccessfully atomic.Value
    64  
    65  	encryptedPackets [][]byte
    66  
    67  	connectionClosedByUser bool
    68  	closeLock              sync.Mutex
    69  	closed                 *closer.Closer
    70  	handshakeLoopsFinished sync.WaitGroup
    71  
    72  	readDeadline  *deadline.Deadline
    73  	writeDeadline *deadline.Deadline
    74  
    75  	log logging.LeveledLogger
    76  
    77  	reading               chan struct{}
    78  	handshakeRecv         chan chan struct{}
    79  	cancelHandshaker      func()
    80  	cancelHandshakeReader func()
    81  
    82  	fsm *handshakeFSM
    83  
    84  	replayProtectionWindow uint
    85  }
    86  
    87  func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) {
    88  	err := validateConfig(config)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	if nextConn == nil {
    94  		return nil, errNilNextConn
    95  	}
    96  
    97  	loggerFactory := config.LoggerFactory
    98  	if loggerFactory == nil {
    99  		loggerFactory = logging.NewDefaultLoggerFactory()
   100  	}
   101  
   102  	logger := loggerFactory.NewLogger("dtls")
   103  
   104  	mtu := config.MTU
   105  	if mtu <= 0 {
   106  		mtu = defaultMTU
   107  	}
   108  
   109  	replayProtectionWindow := config.ReplayProtectionWindow
   110  	if replayProtectionWindow <= 0 {
   111  		replayProtectionWindow = defaultReplayProtectionWindow
   112  	}
   113  
   114  	c := &Conn{
   115  		nextConn:                connctx.New(nextConn),
   116  		fragmentBuffer:          newFragmentBuffer(),
   117  		handshakeCache:          newHandshakeCache(),
   118  		maximumTransmissionUnit: mtu,
   119  
   120  		decrypted: make(chan interface{}, 1),
   121  		log:       logger,
   122  
   123  		readDeadline:  deadline.New(),
   124  		writeDeadline: deadline.New(),
   125  
   126  		reading:          make(chan struct{}, 1),
   127  		handshakeRecv:    make(chan chan struct{}),
   128  		closed:           closer.NewCloser(),
   129  		cancelHandshaker: func() {},
   130  
   131  		replayProtectionWindow: uint(replayProtectionWindow),
   132  
   133  		state: State{
   134  			isClient: isClient,
   135  		},
   136  	}
   137  
   138  	c.setRemoteEpoch(0)
   139  	c.setLocalEpoch(0)
   140  	return c, nil
   141  }
   142  
   143  func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
   144  	if conn == nil {
   145  		return nil, errNilNextConn
   146  	}
   147  
   148  	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	workerInterval := initialTickerInterval
   159  	if config.FlightInterval != 0 {
   160  		workerInterval = config.FlightInterval
   161  	}
   162  
   163  	serverName := config.ServerName
   164  	// Do not allow the use of an IP address literal as an SNI value.
   165  	// See RFC 6066, Section 3.
   166  	if net.ParseIP(serverName) != nil {
   167  		serverName = ""
   168  	}
   169  
   170  	curves := config.EllipticCurves
   171  	if len(curves) == 0 {
   172  		curves = defaultCurves
   173  	}
   174  
   175  	hsCfg := &handshakeConfig{
   176  		localPSKCallback:            config.PSK,
   177  		localPSKIdentityHint:        config.PSKIdentityHint,
   178  		localCipherSuites:           cipherSuites,
   179  		localSignatureSchemes:       signatureSchemes,
   180  		extendedMasterSecret:        config.ExtendedMasterSecret,
   181  		localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
   182  		serverName:                  serverName,
   183  		supportedProtocols:          config.SupportedProtocols,
   184  		clientAuth:                  config.ClientAuth,
   185  		localCertificates:           config.Certificates,
   186  		insecureSkipVerify:          config.InsecureSkipVerify,
   187  		verifyPeerCertificate:       config.VerifyPeerCertificate,
   188  		verifyConnection:            config.VerifyConnection,
   189  		rootCAs:                     config.RootCAs,
   190  		clientCAs:                   config.ClientCAs,
   191  		customCipherSuites:          config.CustomCipherSuites,
   192  		retransmitInterval:          workerInterval,
   193  		log:                         conn.log,
   194  		initialEpoch:                0,
   195  		keyLogWriter:                config.KeyLogWriter,
   196  		sessionStore:                config.SessionStore,
   197  		ellipticCurves:              curves,
   198  		localGetCertificate:         config.GetCertificate,
   199  		localGetClientCertificate:   config.GetClientCertificate,
   200  		insecureSkipHelloVerify:     config.InsecureSkipVerifyHello,
   201  	}
   202  
   203  	// rfc5246#section-7.4.3
   204  	// In addition, the hash and signature algorithms MUST be compatible
   205  	// with the key in the server's end-entity certificate.
   206  	if !isClient {
   207  		cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
   208  		if err != nil && !errors.Is(err, errNoCertificates) {
   209  			return nil, err
   210  		}
   211  		hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
   212  	}
   213  
   214  	var initialFlight flightVal
   215  	var initialFSMState handshakeState
   216  
   217  	if initialState != nil {
   218  		if conn.state.isClient {
   219  			initialFlight = flight5
   220  		} else {
   221  			initialFlight = flight6
   222  		}
   223  		initialFSMState = handshakeFinished
   224  
   225  		conn.state = *initialState
   226  	} else {
   227  		if conn.state.isClient {
   228  			initialFlight = flight1
   229  		} else {
   230  			initialFlight = flight0
   231  		}
   232  		initialFSMState = handshakePreparing
   233  	}
   234  	// Do handshake
   235  	if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	conn.log.Trace("Handshake Completed")
   240  
   241  	return conn, nil
   242  }
   243  
   244  // Dial connects to the given network address and establishes a DTLS connection on top.
   245  // Connection handshake will timeout using ConnectContextMaker in the Config.
   246  // If you want to specify the timeout duration, use DialWithContext() instead.
   247  func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
   248  	ctx, cancel := config.connectContextMaker()
   249  	defer cancel()
   250  
   251  	return DialWithContext(ctx, network, raddr, config)
   252  }
   253  
   254  // Client establishes a DTLS connection over an existing connection.
   255  // Connection handshake will timeout using ConnectContextMaker in the Config.
   256  // If you want to specify the timeout duration, use ClientWithContext() instead.
   257  func Client(conn net.Conn, config *Config) (*Conn, error) {
   258  	ctx, cancel := config.connectContextMaker()
   259  	defer cancel()
   260  
   261  	return ClientWithContext(ctx, conn, config)
   262  }
   263  
   264  // Server listens for incoming DTLS connections.
   265  // Connection handshake will timeout using ConnectContextMaker in the Config.
   266  // If you want to specify the timeout duration, use ServerWithContext() instead.
   267  func Server(conn net.Conn, config *Config) (*Conn, error) {
   268  	ctx, cancel := config.connectContextMaker()
   269  	defer cancel()
   270  
   271  	return ServerWithContext(ctx, conn, config)
   272  }
   273  
   274  // DialWithContext connects to the given network address and establishes a DTLS connection on top.
   275  func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
   276  	pConn, err := net.DialUDP(network, nil, raddr)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  	return ClientWithContext(ctx, pConn, config)
   281  }
   282  
   283  // ClientWithContext establishes a DTLS connection over an existing connection.
   284  func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
   285  	switch {
   286  	case config == nil:
   287  		return nil, errNoConfigProvided
   288  	case config.PSK != nil && config.PSKIdentityHint == nil:
   289  		return nil, errPSKAndIdentityMustBeSetForClient
   290  	}
   291  
   292  	dconn, err := createConn(conn, config, true)
   293  	if err != nil {
   294  		return nil, err
   295  	}
   296  
   297  	return handshakeConn(ctx, dconn, config, true, nil)
   298  }
   299  
   300  // ServerWithContext listens for incoming DTLS connections.
   301  func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
   302  	if config == nil {
   303  		return nil, errNoConfigProvided
   304  	}
   305  	dconn, err := createConn(conn, config, false)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  	return handshakeConn(ctx, dconn, config, false, nil)
   310  }
   311  
   312  // Read reads data from the connection.
   313  func (c *Conn) Read(p []byte) (n int, err error) {
   314  	if !c.isHandshakeCompletedSuccessfully() {
   315  		return 0, errHandshakeInProgress
   316  	}
   317  
   318  	select {
   319  	case <-c.readDeadline.Done():
   320  		return 0, errDeadlineExceeded
   321  	default:
   322  	}
   323  
   324  	for {
   325  		select {
   326  		case <-c.readDeadline.Done():
   327  			return 0, errDeadlineExceeded
   328  		case out, ok := <-c.decrypted:
   329  			if !ok {
   330  				return 0, io.EOF
   331  			}
   332  			switch val := out.(type) {
   333  			case ([]byte):
   334  				if len(p) < len(val) {
   335  					return 0, errBufferTooSmall
   336  				}
   337  				copy(p, val)
   338  				return len(val), nil
   339  			case (error):
   340  				return 0, val
   341  			}
   342  		}
   343  	}
   344  }
   345  
   346  // Write writes len(p) bytes from p to the DTLS connection
   347  func (c *Conn) Write(p []byte) (int, error) {
   348  	if c.isConnectionClosed() {
   349  		return 0, ErrConnClosed
   350  	}
   351  
   352  	select {
   353  	case <-c.writeDeadline.Done():
   354  		return 0, errDeadlineExceeded
   355  	default:
   356  	}
   357  
   358  	if !c.isHandshakeCompletedSuccessfully() {
   359  		return 0, errHandshakeInProgress
   360  	}
   361  
   362  	return len(p), c.writePackets(c.writeDeadline, []*packet{
   363  		{
   364  			record: &recordlayer.RecordLayer{
   365  				Header: recordlayer.Header{
   366  					Epoch:   c.state.getLocalEpoch(),
   367  					Version: protocol.Version1_2,
   368  				},
   369  				Content: &protocol.ApplicationData{
   370  					Data: p,
   371  				},
   372  			},
   373  			shouldEncrypt: true,
   374  		},
   375  	})
   376  }
   377  
   378  // Close closes the connection.
   379  func (c *Conn) Close() error {
   380  	err := c.close(true) //nolint:contextcheck
   381  	c.handshakeLoopsFinished.Wait()
   382  	return err
   383  }
   384  
   385  // ConnectionState returns basic DTLS details about the connection.
   386  // Note that this replaced the `Export` function of v1.
   387  func (c *Conn) ConnectionState() State {
   388  	c.lock.RLock()
   389  	defer c.lock.RUnlock()
   390  	return *c.state.clone()
   391  }
   392  
   393  // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
   394  func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
   395  	profile := c.state.getSRTPProtectionProfile()
   396  	if profile == 0 {
   397  		return 0, false
   398  	}
   399  
   400  	return profile, true
   401  }
   402  
   403  func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
   404  	c.lock.Lock()
   405  	defer c.lock.Unlock()
   406  
   407  	var rawPackets [][]byte
   408  
   409  	for _, p := range pkts {
   410  		if h, ok := p.record.Content.(*handshake.Handshake); ok {
   411  			handshakeRaw, err := p.record.Marshal()
   412  			if err != nil {
   413  				return err
   414  			}
   415  
   416  			c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
   417  				srvCliStr(c.state.isClient), h.Header.Type.String(),
   418  				p.record.Header.Epoch, h.Header.MessageSequence)
   419  			c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
   420  
   421  			rawHandshakePackets, err := c.processHandshakePacket(p, h)
   422  			if err != nil {
   423  				return err
   424  			}
   425  			rawPackets = append(rawPackets, rawHandshakePackets...)
   426  		} else {
   427  			rawPacket, err := c.processPacket(p)
   428  			if err != nil {
   429  				return err
   430  			}
   431  			rawPackets = append(rawPackets, rawPacket)
   432  		}
   433  	}
   434  	if len(rawPackets) == 0 {
   435  		return nil
   436  	}
   437  	compactedRawPackets := c.compactRawPackets(rawPackets)
   438  
   439  	for _, compactedRawPackets := range compactedRawPackets {
   440  		if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
   441  			return netError(err)
   442  		}
   443  	}
   444  
   445  	return nil
   446  }
   447  
   448  func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
   449  	// avoid a useless copy in the common case
   450  	if len(rawPackets) == 1 {
   451  		return rawPackets
   452  	}
   453  
   454  	combinedRawPackets := make([][]byte, 0)
   455  	currentCombinedRawPacket := make([]byte, 0)
   456  
   457  	for _, rawPacket := range rawPackets {
   458  		if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
   459  			combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
   460  			currentCombinedRawPacket = []byte{}
   461  		}
   462  		currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
   463  	}
   464  
   465  	combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
   466  
   467  	return combinedRawPackets
   468  }
   469  
   470  func (c *Conn) processPacket(p *packet) ([]byte, error) {
   471  	epoch := p.record.Header.Epoch
   472  	for len(c.state.localSequenceNumber) <= int(epoch) {
   473  		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
   474  	}
   475  	seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
   476  	if seq > recordlayer.MaxSequenceNumber {
   477  		// RFC 6347 Section 4.1.0
   478  		// The implementation must either abandon an association or rehandshake
   479  		// prior to allowing the sequence number to wrap.
   480  		return nil, errSequenceNumberOverflow
   481  	}
   482  	p.record.Header.SequenceNumber = seq
   483  
   484  	rawPacket, err := p.record.Marshal()
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  
   489  	if p.shouldEncrypt {
   490  		var err error
   491  		rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
   492  		if err != nil {
   493  			return nil, err
   494  		}
   495  	}
   496  
   497  	return rawPacket, nil
   498  }
   499  
   500  func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
   501  	rawPackets := make([][]byte, 0)
   502  
   503  	handshakeFragments, err := c.fragmentHandshake(h)
   504  	if err != nil {
   505  		return nil, err
   506  	}
   507  	epoch := p.record.Header.Epoch
   508  	for len(c.state.localSequenceNumber) <= int(epoch) {
   509  		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
   510  	}
   511  
   512  	for _, handshakeFragment := range handshakeFragments {
   513  		seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
   514  		if seq > recordlayer.MaxSequenceNumber {
   515  			return nil, errSequenceNumberOverflow
   516  		}
   517  
   518  		recordlayerHeader := &recordlayer.Header{
   519  			Version:        p.record.Header.Version,
   520  			ContentType:    p.record.Header.ContentType,
   521  			ContentLen:     uint16(len(handshakeFragment)),
   522  			Epoch:          p.record.Header.Epoch,
   523  			SequenceNumber: seq,
   524  		}
   525  
   526  		rawPacket, err := recordlayerHeader.Marshal()
   527  		if err != nil {
   528  			return nil, err
   529  		}
   530  
   531  		p.record.Header = *recordlayerHeader
   532  
   533  		rawPacket = append(rawPacket, handshakeFragment...)
   534  		if p.shouldEncrypt {
   535  			var err error
   536  			rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
   537  			if err != nil {
   538  				return nil, err
   539  			}
   540  		}
   541  
   542  		rawPackets = append(rawPackets, rawPacket)
   543  	}
   544  
   545  	return rawPackets, nil
   546  }
   547  
   548  func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
   549  	content, err := h.Message.Marshal()
   550  	if err != nil {
   551  		return nil, err
   552  	}
   553  
   554  	fragmentedHandshakes := make([][]byte, 0)
   555  
   556  	contentFragments := splitBytes(content, c.maximumTransmissionUnit)
   557  	if len(contentFragments) == 0 {
   558  		contentFragments = [][]byte{
   559  			{},
   560  		}
   561  	}
   562  
   563  	offset := 0
   564  	for _, contentFragment := range contentFragments {
   565  		contentFragmentLen := len(contentFragment)
   566  
   567  		headerFragment := &handshake.Header{
   568  			Type:            h.Header.Type,
   569  			Length:          h.Header.Length,
   570  			MessageSequence: h.Header.MessageSequence,
   571  			FragmentOffset:  uint32(offset),
   572  			FragmentLength:  uint32(contentFragmentLen),
   573  		}
   574  
   575  		offset += contentFragmentLen
   576  
   577  		fragmentedHandshake, err := headerFragment.Marshal()
   578  		if err != nil {
   579  			return nil, err
   580  		}
   581  
   582  		fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
   583  		fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
   584  	}
   585  
   586  	return fragmentedHandshakes, nil
   587  }
   588  
   589  var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
   590  	New: func() interface{} {
   591  		b := make([]byte, inboundBufferSize)
   592  		return &b
   593  	},
   594  }
   595  
   596  func (c *Conn) readAndBuffer(ctx context.Context) error {
   597  	bufptr, ok := poolReadBuffer.Get().(*[]byte)
   598  	if !ok {
   599  		return errFailedToAccessPoolReadBuffer
   600  	}
   601  	defer poolReadBuffer.Put(bufptr)
   602  
   603  	b := *bufptr
   604  	i, err := c.nextConn.ReadContext(ctx, b)
   605  	if err != nil {
   606  		return netError(err)
   607  	}
   608  
   609  	pkts, err := recordlayer.UnpackDatagram(b[:i])
   610  	if err != nil {
   611  		return err
   612  	}
   613  
   614  	var hasHandshake bool
   615  	for _, p := range pkts {
   616  		hs, alert, err := c.handleIncomingPacket(ctx, p, true)
   617  		if alert != nil {
   618  			if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
   619  				if err == nil {
   620  					err = alertErr
   621  				}
   622  			}
   623  		}
   624  		if hs {
   625  			hasHandshake = true
   626  		}
   627  
   628  		if err != nil {
   629  			return err
   630  		}
   631  	}
   632  	if hasHandshake {
   633  		done := make(chan struct{})
   634  		select {
   635  		case c.handshakeRecv <- done:
   636  			// If the other party may retransmit the flight,
   637  			// we should respond even if it not a new message.
   638  			<-done
   639  		case <-c.fsm.Done():
   640  		}
   641  	}
   642  	return nil
   643  }
   644  
   645  func (c *Conn) handleQueuedPackets(ctx context.Context) error {
   646  	pkts := c.encryptedPackets
   647  	c.encryptedPackets = nil
   648  
   649  	for _, p := range pkts {
   650  		_, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
   651  		if alert != nil {
   652  			if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
   653  				if err == nil {
   654  					err = alertErr
   655  				}
   656  			}
   657  		}
   658  		var e *alertError
   659  		if errors.As(err, &e) {
   660  			if e.IsFatalOrCloseNotify() {
   661  				return e
   662  			}
   663  		} else if err != nil {
   664  			return err
   665  		}
   666  	}
   667  	return nil
   668  }
   669  
   670  func (c *Conn) enqueueEncryptedPackets(packet []byte) bool {
   671  	if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
   672  		c.encryptedPackets = append(c.encryptedPackets, packet)
   673  		return true
   674  	}
   675  	return false
   676  }
   677  
// handleIncomingPacket processes one record (buf) taken from a received
// datagram.
//
// Return values:
//   - bool: true when buf carried handshake data that was pushed into the
//     fragment buffer; completed messages are moved into handshakeCache.
//   - *alert.Alert: an alert the caller should send to the peer, or nil.
//   - error: an *alertError wrapping a received alert, or a decode /
//     unexpected-message error.
//
// When enqueue is true, records that cannot be processed yet (next epoch, or
// cipher suite not initialized) are queued via enqueueEncryptedPackets. When
// enqueue is false they are dropped. Malformed, duplicated, too-far-future
// and undecryptable records are discarded silently (nil alert, nil error).
func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
	h := &recordlayer.Header{}
	if err := h.Unmarshal(buf); err != nil {
		// Decode error must be silently discarded
		// [RFC6347 Section-4.1.2.7]
		c.log.Debugf("discarded broken packet: %v", err)
		return false, nil, nil
	}
	// Validate epoch
	remoteEpoch := c.state.getRemoteEpoch()
	if h.Epoch > remoteEpoch {
		// Only the immediately following epoch is kept; anything further
		// ahead is dropped.
		if h.Epoch > remoteEpoch+1 {
			c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
				h.Epoch, h.SequenceNumber,
			)
			return false, nil, nil
		}
		// Next-epoch record: hold it until the epoch advances. The queue is
		// bounded, so the record is silently dropped when it is full.
		if enqueue {
			if ok := c.enqueueEncryptedPackets(buf); ok {
				c.log.Debug("received packet of next epoch, queuing packet")
			}
		}
		return false, nil, nil
	}

	// Anti-replay protection
	// One replay detector per epoch, created lazily up to h.Epoch.
	for len(c.state.replayDetector) <= int(h.Epoch) {
		c.state.replayDetector = append(c.state.replayDetector,
			replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
		)
	}
	// markPacketAsValid is invoked only on the paths below that accept the
	// record. Presumably Check alone does not mark the sequence number as
	// seen, so records that are queued or fail to decrypt are not marked
	// (see replaydetector docs).
	markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
	if !ok {
		c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
			h.Epoch, h.SequenceNumber,
		)
		return false, nil, nil
	}

	// Decrypt
	// Epoch 0 records are plaintext. Later epochs need an initialized cipher
	// suite; until then the record is queued (if allowed) for a retry.
	if h.Epoch != 0 {
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				if ok := c.enqueueEncryptedPackets(buf); ok {
					c.log.Debug("handshake not finished, queuing packet")
				}
			}
			return false, nil, nil
		}

		var err error
		buf, err = c.state.cipherSuite.Decrypt(buf)
		if err != nil {
			c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
			return false, nil, nil
		}
	}

	// The fragment buffer gets its own copy of the record (presumably so it
	// does not alias the caller's buffer).
	isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
	if err != nil {
		// Decode error must be silently discarded
		// [RFC6347 Section-4.1.2.7]
		c.log.Debugf("defragment failed: %s", err)
		return false, nil, nil
	} else if isHandshake {
		markPacketAsValid()
		// Move every fully reassembled handshake message into the cache,
		// tagged as sent by the peer (!c.state.isClient).
		for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
			header := &handshake.Header{}
			if err := header.Unmarshal(out); err != nil {
				c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
				continue
			}
			c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
		}

		return true, nil, nil
	}

	r := &recordlayer.RecordLayer{}
	if err := r.Unmarshal(buf); err != nil {
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
	}

	switch content := r.Content.(type) {
	case *alert.Alert:
		c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
		var a *alert.Alert
		if content.Description == alert.CloseNotify {
			// Respond with a close_notify [RFC5246 Section 7.2.1]
			a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
		}
		markPacketAsValid()
		return false, a, &alertError{content}
	case *protocol.ChangeCipherSpec:
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				if ok := c.enqueueEncryptedPackets(buf); ok {
					c.log.Debugf("CipherSuite not initialized, queuing packet")
				}
			}
			return false, nil, nil
		}

		newRemoteEpoch := h.Epoch + 1
		c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)

		// Advance only by exactly one epoch. A repeated ChangeCipherSpec for
		// an epoch already reached does not move it again.
		if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
			c.setRemoteEpoch(newRemoteEpoch)
			markPacketAsValid()
		}
	case *protocol.ApplicationData:
		if h.Epoch == 0 {
			return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
		}

		markPacketAsValid()

		// Hand the data to Read. This blocks until c.decrypted has room,
		// the connection is closed, or ctx is done.
		select {
		case c.decrypted <- content.Data:
		case <-c.closed.Done():
		case <-ctx.Done():
		}

	default:
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
	}
	return false, nil, nil
}
   806  
   807  func (c *Conn) recvHandshake() <-chan chan struct{} {
   808  	return c.handshakeRecv
   809  }
   810  
   811  func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
   812  	if level == alert.Fatal && len(c.state.SessionID) > 0 {
   813  		// According to the RFC, we need to delete the stored session.
   814  		// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
   815  		if ss := c.fsm.cfg.sessionStore; ss != nil {
   816  			c.log.Tracef("clean invalid session: %s", c.state.SessionID)
   817  			if err := ss.Del(c.sessionKey()); err != nil {
   818  				return err
   819  			}
   820  		}
   821  	}
   822  	return c.writePackets(ctx, []*packet{
   823  		{
   824  			record: &recordlayer.RecordLayer{
   825  				Header: recordlayer.Header{
   826  					Epoch:   c.state.getLocalEpoch(),
   827  					Version: protocol.Version1_2,
   828  				},
   829  				Content: &alert.Alert{
   830  					Level:       level,
   831  					Description: desc,
   832  				},
   833  			},
   834  			shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
   835  		},
   836  	})
   837  }
   838  
   839  func (c *Conn) setHandshakeCompletedSuccessfully() {
   840  	c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
   841  }
   842  
   843  func (c *Conn) isHandshakeCompletedSuccessfully() bool {
   844  	boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
   845  	return boolean.bool
   846  }
   847  
// handshake creates the handshake FSM and starts two goroutines that live
// until the connection is closed:
//   - the FSM loop (c.fsm.Run), so the last flight can still be
//     retransmitted after the handshake;
//   - the read loop (readAndBuffer), which feeds records to the FSM and
//     application data to Read.
//
// It blocks until one of these happens:
//   - The handshake reaches handshakeFinished. It returns nil and the
//     goroutines keep running.
//   - Either loop reports an error first.
//   - ctx is done.
//
// In the last two cases both loops are cancelled and waited for, and the
// error is passed through translateHandshakeCtxError.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	// done is closed exactly once, the first time the FSM reports
	// handshakeFinished. The completion flag is set before close, so waiters
	// on done observe it.
	done := make(chan struct{})
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	// Both loop contexts derive from Background, not ctx. ctx bounds only
	// this call, not the lifetime of the connection.
	ctxHs, cancel := context.WithCancel(context.Background())
	c.cancelHandshaker = cancel

	// Buffered with capacity 1 and written with non-blocking sends, so only
	// the first error from either loop is kept.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake routine should be live until close.
	// The other party may request retransmission of the last flight to cope with packet drop.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()
	go func() {
		// Deferred calls run LIFO: handshakeLoopsFinished.Done() (registered
		// second) fires first. Only then is decrypted closed and the FSM
		// context cancelled.
		defer func() {
			// Escaping read loop.
			// It's safe to close decrypted channel now.
			close(c.decrypted)

			// Force stop handshaker when the underlying connection is closed.
			cancel()
		}()
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Pass the error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue // non-fatal alert must not stop read loop
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
						// Terminal conditions: fall through to error reporting
						// and exit below.
					case errors.Is(err, recordlayer.ErrInvalidPacketLength):
						// Decode error must be silently discarded
						// [RFC6347 Section-4.1.2.7]
						continue
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep read loop and pass the read error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue // after the handshake, other read errors are reported via Read and do not stop the read loop
						}
						// During the handshake, any other error is terminal.
					}
				}

				select {
				case firstErr <- err:
				default:
				}

				// A fatal or close_notify alert from the peer closes the
				// connection.
				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false) //nolint:contextcheck
					}
				}
				// ctxRead was cancelled (e.g. handshake failure or timeout,
				// assuming readAndBuffer surfaces it as context.Canceled):
				// close the underlying connection if nothing has yet.
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false) //nolint:contextcheck
				}
				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}
   959  
   960  func (c *Conn) translateHandshakeCtxError(err error) error {
   961  	if err == nil {
   962  		return nil
   963  	}
   964  	if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
   965  		return nil
   966  	}
   967  	return &HandshakeError{Err: err}
   968  }
   969  
// close tears down the connection.
//
// byUser is true when the call comes from the user's Close(), and false for
// internal shutdowns (the handshake read loop calls close(false)).
//
// Steps, in order:
//  1. Cancel the handshake FSM and reader contexts.
//  2. If the user closes an established connection, send a close_notify
//     warning alert on a best-effort basis (its error is discarded).
//  3. Under closeLock, note whether a previous user call already closed the
//     connection, mark this call if it is from the user, snapshot whether
//     c.closed was already closed, and then close c.closed.
//
// Returns ErrConnClosed if the user had already called Close before; nil if
// c.closed was already closed (the underlying conn is not closed twice);
// otherwise the result of closing c.nextConn.
func (c *Conn) close(byUser bool) error {
	c.cancelHandshaker()
	c.cancelHandshakeReader()

	if c.isHandshakeCompletedSuccessfully() && byUser {
		// Discard error from notify() to return non-error on the first user call of Close()
		// even if the underlying connection is already closed.
		_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
	}

	c.closeLock.Lock()
	// Only a repeated user call gets ErrConnClosed; the user's first Close()
	// must not, even if the connection was already closed internally.
	closedByUser := c.connectionClosedByUser
	if byUser {
		c.connectionClosedByUser = true
	}
	// Snapshot before closing so we know whether this call is the one that
	// actually transitions the connection to closed.
	isClosed := c.isConnectionClosed()
	c.closed.Close()
	c.closeLock.Unlock()

	if closedByUser {
		return ErrConnClosed
	}

	if isClosed {
		return nil
	}

	return c.nextConn.Close()
}
  1000  
  1001  func (c *Conn) isConnectionClosed() bool {
  1002  	select {
  1003  	case <-c.closed.Done():
  1004  		return true
  1005  	default:
  1006  		return false
  1007  	}
  1008  }
  1009  
  1010  func (c *Conn) setLocalEpoch(epoch uint16) {
  1011  	c.state.localEpoch.Store(epoch)
  1012  }
  1013  
  1014  func (c *Conn) setRemoteEpoch(epoch uint16) {
  1015  	c.state.remoteEpoch.Store(epoch)
  1016  }
  1017  
  1018  // LocalAddr implements net.Conn.LocalAddr
  1019  func (c *Conn) LocalAddr() net.Addr {
  1020  	return c.nextConn.LocalAddr()
  1021  }
  1022  
  1023  // RemoteAddr implements net.Conn.RemoteAddr
  1024  func (c *Conn) RemoteAddr() net.Addr {
  1025  	return c.nextConn.RemoteAddr()
  1026  }
  1027  
  1028  func (c *Conn) sessionKey() []byte {
  1029  	if c.state.isClient {
  1030  		// As ServerName can be like 0.example.com, it's better to add
  1031  		// delimiter character which is not allowed to be in
  1032  		// neither address or domain name.
  1033  		return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
  1034  	}
  1035  	return c.state.SessionID
  1036  }
  1037  
  1038  // SetDeadline implements net.Conn.SetDeadline
  1039  func (c *Conn) SetDeadline(t time.Time) error {
  1040  	c.readDeadline.Set(t)
  1041  	return c.SetWriteDeadline(t)
  1042  }
  1043  
  1044  // SetReadDeadline implements net.Conn.SetReadDeadline
  1045  func (c *Conn) SetReadDeadline(t time.Time) error {
  1046  	c.readDeadline.Set(t)
  1047  	// Read deadline is fully managed by this layer.
  1048  	// Don't set read deadline to underlying connection.
  1049  	return nil
  1050  }
  1051  
  1052  // SetWriteDeadline implements net.Conn.SetWriteDeadline
  1053  func (c *Conn) SetWriteDeadline(t time.Time) error {
  1054  	c.writeDeadline.Set(t)
  1055  	// Write deadline is also fully managed by this layer.
  1056  	return nil
  1057  }