github.com/pion/dtls/v2@v2.2.12/handshaker.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"fmt"
    11  	"io"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    16  	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
    17  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    18  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    19  	"github.com/pion/logging"
    20  )
    21  
    22  // [RFC6347 Section-4.2.4]
    23  //                      +-----------+
    24  //                +---> | PREPARING | <--------------------+
    25  //                |     +-----------+                      |
    26  //                |           |                            |
    27  //                |           | Buffer next flight         |
    28  //                |           |                            |
    29  //                |          \|/                           |
    30  //                |     +-----------+                      |
    31  //                |     |  SENDING  |<------------------+  | Send
    32  //                |     +-----------+                   |  | HelloRequest
    33  //        Receive |           |                         |  |
    34  //           next |           | Send flight             |  | or
    35  //         flight |  +--------+                         |  |
    36  //                |  |        | Set retransmit timer    |  | Receive
    37  //                |  |       \|/                        |  | HelloRequest
    38  //                |  |  +-----------+                   |  | Send
    39  //                +--)--|  WAITING  |-------------------+  | ClientHello
    40  //                |  |  +-----------+   Timer expires   |  |
    41  //                |  |         |                        |  |
    42  //                |  |         +------------------------+  |
    43  //        Receive |  | Send           Read retransmit      |
    44  //           last |  | last                                |
    45  //         flight |  | flight                              |
    46  //                |  |                                     |
    47  //               \|/\|/                                    |
    48  //            +-----------+                                |
    49  //            | FINISHED  | -------------------------------+
    50  //            +-----------+
    51  //                 |  /|\
    52  //                 |   |
    53  //                 +---+
    54  //              Read retransmit
    55  //           Retransmit last flight
    56  
    57  type handshakeState uint8
    58  
    59  const (
    60  	handshakeErrored handshakeState = iota
    61  	handshakePreparing
    62  	handshakeSending
    63  	handshakeWaiting
    64  	handshakeFinished
    65  )
    66  
    67  func (s handshakeState) String() string {
    68  	switch s {
    69  	case handshakeErrored:
    70  		return "Errored"
    71  	case handshakePreparing:
    72  		return "Preparing"
    73  	case handshakeSending:
    74  		return "Sending"
    75  	case handshakeWaiting:
    76  		return "Waiting"
    77  	case handshakeFinished:
    78  		return "Finished"
    79  	default:
    80  		return "Unknown"
    81  	}
    82  }
    83  
// handshakeFSM runs the RFC 6347 Section 4.2.4 handshake state machine
// (PREPARING -> SENDING -> WAITING / FINISHED) one flight at a time.
type handshakeFSM struct {
	// currentFlight is the flight being prepared, sent or awaited; wait
	// advances it once the peer's next flight has been parsed.
	currentFlight flightVal
	// flights holds the packets built by prepare and written by send.
	flights []*packet
	// retransmit records whether the current flight is resent when the
	// retransmit timer in wait expires (taken from getFlightGenerator).
	retransmit bool
	// state is handed to flight generators and parsers; Run/prepare also read
	// isClient and advance handshakeSendSequence on it.
	state *State
	// cache is handed to flight generators and parsers.
	cache *handshakeCache
	// cfg supplies settings, the logger and callbacks for the handshake.
	cfg *handshakeConfig
	// closed is closed when Run returns; Done exposes it.
	closed chan struct{}
}
    93  
// handshakeConfig holds the settings consulted while running the handshake.
// A pointer to it is passed to every flight generator and flight parser.
type handshakeConfig struct {
	localPSKCallback            PSKCallback
	localPSKIdentityHint        []byte
	localCipherSuites           []CipherSuite             // Available CipherSuites
	localSignatureSchemes       []signaturehash.Algorithm // Available signature schemes
	extendedMasterSecret        ExtendedMasterSecretType  // Policy for the Extended Master Secret extension
	localSRTPProtectionProfiles []SRTPProtectionProfile   // Available SRTPProtectionProfiles, if empty no SRTP support
	serverName                  string
	supportedProtocols          []string
	// Client-certificate request policy. NOTE(review): requesting a client
	// certificate is normally a server-side decision; confirm which side
	// consults this field.
	clientAuth                  ClientAuthType
	localCertificates           []tls.Certificate
	nameToCertificate           map[string]*tls.Certificate
	insecureSkipVerify          bool
	verifyPeerCertificate       func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
	verifyConnection            func(*State) error
	sessionStore                SessionStore
	rootCAs                     *x509.CertPool
	clientCAs                   *x509.CertPool
	// Duration of the retransmit timer armed by handshakeFSM.wait and finish.
	retransmitInterval          time.Duration
	customCipherSuites          func() []CipherSuite
	ellipticCurves              []elliptic.Curve
	insecureSkipHelloVerify     bool

	// onFlightState, when non-nil, is called by handshakeFSM.Run with the
	// current flight and state before every state-machine step.
	onFlightState func(flightVal, handshakeState)
	// log receives handshake trace output and key-log write failures.
	log logging.LeveledLogger
	// keyLogWriter is the destination of writeKeyLog; nil disables key logging.
	keyLogWriter io.Writer

	localGetCertificate       func(*ClientHelloInfo) (*tls.Certificate, error)
	localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error)

	// initialEpoch is added to the record epoch of every packet built by
	// handshakeFSM.prepare.
	initialEpoch uint16

	// mu serializes writes to keyLogWriter in writeKeyLog.
	mu sync.Mutex
}
   128  
   129  type flightConn interface {
   130  	notify(ctx context.Context, level alert.Level, desc alert.Description) error
   131  	writePackets(context.Context, []*packet) error
   132  	recvHandshake() <-chan chan struct{}
   133  	setLocalEpoch(epoch uint16)
   134  	handleQueuedPackets(context.Context) error
   135  	sessionKey() []byte
   136  }
   137  
   138  func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
   139  	if c.keyLogWriter == nil {
   140  		return
   141  	}
   142  	c.mu.Lock()
   143  	defer c.mu.Unlock()
   144  	_, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)))
   145  	if err != nil {
   146  		c.log.Debugf("failed to write key log file: %s", err)
   147  	}
   148  }
   149  
   150  func srvCliStr(isClient bool) string {
   151  	if isClient {
   152  		return "client"
   153  	}
   154  	return "server"
   155  }
   156  
   157  func newHandshakeFSM(
   158  	s *State, cache *handshakeCache, cfg *handshakeConfig,
   159  	initialFlight flightVal,
   160  ) *handshakeFSM {
   161  	return &handshakeFSM{
   162  		currentFlight: initialFlight,
   163  		state:         s,
   164  		cache:         cache,
   165  		cfg:           cfg,
   166  		closed:        make(chan struct{}),
   167  	}
   168  }
   169  
   170  func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
   171  	state := initialState
   172  	defer func() {
   173  		close(s.closed)
   174  	}()
   175  	for {
   176  		s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
   177  		if s.cfg.onFlightState != nil {
   178  			s.cfg.onFlightState(s.currentFlight, state)
   179  		}
   180  		var err error
   181  		switch state {
   182  		case handshakePreparing:
   183  			state, err = s.prepare(ctx, c)
   184  		case handshakeSending:
   185  			state, err = s.send(ctx, c)
   186  		case handshakeWaiting:
   187  			state, err = s.wait(ctx, c)
   188  		case handshakeFinished:
   189  			state, err = s.finish(ctx, c)
   190  		default:
   191  			return errInvalidFSMTransition
   192  		}
   193  		if err != nil {
   194  			return err
   195  		}
   196  	}
   197  }
   198  
   199  func (s *handshakeFSM) Done() <-chan struct{} {
   200  	return s.closed
   201  }
   202  
   203  func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
   204  	s.flights = nil
   205  	// Prepare flights
   206  	var (
   207  		a    *alert.Alert
   208  		err  error
   209  		pkts []*packet
   210  	)
   211  	gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
   212  	if errFlight != nil {
   213  		err = errFlight
   214  		a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
   215  	} else {
   216  		pkts, a, err = gen(c, s.state, s.cache, s.cfg)
   217  		s.retransmit = retransmit
   218  	}
   219  	if a != nil {
   220  		if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
   221  			if err != nil {
   222  				err = alertErr
   223  			}
   224  		}
   225  	}
   226  	if err != nil {
   227  		return handshakeErrored, err
   228  	}
   229  
   230  	s.flights = pkts
   231  	epoch := s.cfg.initialEpoch
   232  	nextEpoch := epoch
   233  	for _, p := range s.flights {
   234  		p.record.Header.Epoch += epoch
   235  		if p.record.Header.Epoch > nextEpoch {
   236  			nextEpoch = p.record.Header.Epoch
   237  		}
   238  		if h, ok := p.record.Content.(*handshake.Handshake); ok {
   239  			h.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
   240  			s.state.handshakeSendSequence++
   241  		}
   242  	}
   243  	if epoch != nextEpoch {
   244  		s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
   245  		c.setLocalEpoch(nextEpoch)
   246  	}
   247  	return handshakeSending, nil
   248  }
   249  
   250  func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
   251  	// Send flights
   252  	if err := c.writePackets(ctx, s.flights); err != nil {
   253  		return handshakeErrored, err
   254  	}
   255  
   256  	if s.currentFlight.isLastSendFlight() {
   257  		return handshakeFinished, nil
   258  	}
   259  	return handshakeWaiting, nil
   260  }
   261  
   262  func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
   263  	parse, errFlight := s.currentFlight.getFlightParser()
   264  	if errFlight != nil {
   265  		if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
   266  			if errFlight != nil {
   267  				return handshakeErrored, alertErr
   268  			}
   269  		}
   270  		return handshakeErrored, errFlight
   271  	}
   272  
   273  	retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
   274  	for {
   275  		select {
   276  		case done := <-c.recvHandshake():
   277  			nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
   278  			close(done)
   279  			if alert != nil {
   280  				if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
   281  					if err != nil {
   282  						err = alertErr
   283  					}
   284  				}
   285  			}
   286  			if err != nil {
   287  				return handshakeErrored, err
   288  			}
   289  			if nextFlight == 0 {
   290  				break
   291  			}
   292  			s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
   293  			if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
   294  				return handshakeFinished, nil
   295  			}
   296  			s.currentFlight = nextFlight
   297  			return handshakePreparing, nil
   298  
   299  		case <-retransmitTimer.C:
   300  			if !s.retransmit {
   301  				return handshakeWaiting, nil
   302  			}
   303  			return handshakeSending, nil
   304  		case <-ctx.Done():
   305  			return handshakeErrored, ctx.Err()
   306  		}
   307  	}
   308  }
   309  
   310  func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
   311  	parse, errFlight := s.currentFlight.getFlightParser()
   312  	if errFlight != nil {
   313  		if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
   314  			if errFlight != nil {
   315  				return handshakeErrored, alertErr
   316  			}
   317  		}
   318  		return handshakeErrored, errFlight
   319  	}
   320  
   321  	retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
   322  	select {
   323  	case done := <-c.recvHandshake():
   324  		nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
   325  		close(done)
   326  		if alert != nil {
   327  			if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
   328  				if err != nil {
   329  					err = alertErr
   330  				}
   331  			}
   332  		}
   333  		if err != nil {
   334  			return handshakeErrored, err
   335  		}
   336  		if nextFlight == 0 {
   337  			break
   338  		}
   339  		if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
   340  			return handshakeFinished, nil
   341  		}
   342  		<-retransmitTimer.C
   343  		// Retransmit last flight
   344  		return handshakeSending, nil
   345  
   346  	case <-ctx.Done():
   347  		return handshakeErrored, ctx.Err()
   348  	}
   349  	return handshakeFinished, nil
   350  }