github.com/pion/dtls/v2@v2.2.12/flight3handler.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  
    10  	"github.com/pion/dtls/v2/internal/ciphersuite/types"
    11  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    12  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    13  	"github.com/pion/dtls/v2/pkg/protocol"
    14  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    15  	"github.com/pion/dtls/v2/pkg/protocol/extension"
    16  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    17  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    18  )
    19  
    20  func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
    21  	// Clients may receive multiple HelloVerifyRequest messages with different cookies.
    22  	// Clients SHOULD handle this by sending a new ClientHello with a cookie in response
    23  	// to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
    24  	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
    25  		handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
    26  	)
    27  	if ok {
    28  		if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
    29  			// DTLS 1.2 clients must not assume that the server will use the protocol version
    30  			// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
    31  			if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
    32  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
    33  			}
    34  			state.cookie = append([]byte{}, h.Cookie...)
    35  			state.handshakeRecvSequence = seq
    36  			return flight3, nil, nil
    37  		}
    38  	}
    39  
    40  	_, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
    41  		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
    42  	)
    43  	if !ok {
    44  		// Don't have enough messages. Keep reading
    45  		return 0, nil, nil
    46  	}
    47  
    48  	if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk {
    49  		if !h.Version.Equal(protocol.Version1_2) {
    50  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
    51  		}
    52  		for _, v := range h.Extensions {
    53  			switch e := v.(type) {
    54  			case *extension.UseSRTP:
    55  				profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
    56  				if !found {
    57  					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
    58  				}
    59  				state.setSRTPProtectionProfile(profile)
    60  			case *extension.UseExtendedMasterSecret:
    61  				if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
    62  					state.extendedMasterSecret = true
    63  				}
    64  			case *extension.ALPN:
    65  				if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling
    66  					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error?
    67  				}
    68  				state.NegotiatedProtocol = e.ProtocolNameList[0]
    69  			}
    70  		}
    71  		if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
    72  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
    73  		}
    74  		if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 {
    75  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
    76  		}
    77  
    78  		remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
    79  		if remoteCipherSuite == nil {
    80  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
    81  		}
    82  
    83  		selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
    84  		if !found {
    85  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
    86  		}
    87  
    88  		state.cipherSuite = selectedCipherSuite
    89  		state.remoteRandom = h.Random
    90  		cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
    91  
    92  		if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) {
    93  			return handleResumption(ctx, c, state, cache, cfg)
    94  		}
    95  
    96  		if len(state.SessionID) > 0 {
    97  			cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID)
    98  			if err := cfg.sessionStore.Del(state.SessionID); err != nil {
    99  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   100  			}
   101  		}
   102  
   103  		if cfg.sessionStore == nil {
   104  			state.SessionID = []byte{}
   105  		} else {
   106  			state.SessionID = h.SessionID
   107  		}
   108  
   109  		state.masterSecret = []byte{}
   110  	}
   111  
   112  	if cfg.localPSKCallback != nil {
   113  		seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
   114  			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
   115  			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
   116  		)
   117  	} else {
   118  		seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
   119  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
   120  			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
   121  			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
   122  			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
   123  		)
   124  	}
   125  	if !ok {
   126  		// Don't have enough messages. Keep reading
   127  		return 0, nil, nil
   128  	}
   129  	state.handshakeRecvSequence = seq
   130  
   131  	if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
   132  		state.PeerCertificates = h.Certificate
   133  	} else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
   134  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
   135  	}
   136  
   137  	if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
   138  		alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
   139  		if err != nil {
   140  			return 0, alertPtr, err
   141  		}
   142  	}
   143  
   144  	if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
   145  		state.remoteRequestedCertificate = true
   146  	}
   147  
   148  	return flight5, nil, nil
   149  }
   150  
   151  func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
   152  	if err := state.initCipherSuite(); err != nil {
   153  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   154  	}
   155  
   156  	// Now, encrypted packets can be handled
   157  	if err := c.handleQueuedPackets(ctx); err != nil {
   158  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   159  	}
   160  
   161  	_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
   162  		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
   163  	)
   164  	if !ok {
   165  		// No valid message received. Keep reading
   166  		return 0, nil, nil
   167  	}
   168  
   169  	var finished *handshake.MessageFinished
   170  	if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
   171  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
   172  	}
   173  	plainText := cache.pullAndMerge(
   174  		handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
   175  		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
   176  	)
   177  
   178  	expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
   179  	if err != nil {
   180  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   181  	}
   182  	if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
   183  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
   184  	}
   185  
   186  	clientRandom := state.localRandom.MarshalFixed()
   187  	cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
   188  
   189  	return flight5b, nil, nil
   190  }
   191  
   192  func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
   193  	var err error
   194  	if state.cipherSuite == nil {
   195  		return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
   196  	}
   197  	if cfg.localPSKCallback != nil {
   198  		var psk []byte
   199  		if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
   200  			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   201  		}
   202  		state.IdentityHint = h.IdentityHint
   203  		switch state.cipherSuite.KeyExchangeAlgorithm() {
   204  		case types.KeyExchangeAlgorithmPsk:
   205  			state.preMasterSecret = prf.PSKPreMasterSecret(psk)
   206  		case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk):
   207  			if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
   208  				return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   209  			}
   210  			state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
   211  			if err != nil {
   212  				return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   213  			}
   214  		default:
   215  			return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
   216  		}
   217  	} else {
   218  		if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
   219  			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   220  		}
   221  
   222  		if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
   223  			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   224  		}
   225  	}
   226  
   227  	return nil, nil //nolint:nilnil
   228  }
   229  
   230  func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
   231  	extensions := []extension.Extension{
   232  		&extension.SupportedSignatureAlgorithms{
   233  			SignatureHashAlgorithms: cfg.localSignatureSchemes,
   234  		},
   235  		&extension.RenegotiationInfo{
   236  			RenegotiatedConnection: 0,
   237  		},
   238  	}
   239  	if state.namedCurve != 0 {
   240  		extensions = append(extensions, []extension.Extension{
   241  			&extension.SupportedEllipticCurves{
   242  				EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
   243  			},
   244  			&extension.SupportedPointFormats{
   245  				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
   246  			},
   247  		}...)
   248  	}
   249  
   250  	if len(cfg.localSRTPProtectionProfiles) > 0 {
   251  		extensions = append(extensions, &extension.UseSRTP{
   252  			ProtectionProfiles: cfg.localSRTPProtectionProfiles,
   253  		})
   254  	}
   255  
   256  	if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
   257  		cfg.extendedMasterSecret == RequireExtendedMasterSecret {
   258  		extensions = append(extensions, &extension.UseExtendedMasterSecret{
   259  			Supported: true,
   260  		})
   261  	}
   262  
   263  	if len(cfg.serverName) > 0 {
   264  		extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
   265  	}
   266  
   267  	if len(cfg.supportedProtocols) > 0 {
   268  		extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
   269  	}
   270  
   271  	return []*packet{
   272  		{
   273  			record: &recordlayer.RecordLayer{
   274  				Header: recordlayer.Header{
   275  					Version: protocol.Version1_2,
   276  				},
   277  				Content: &handshake.Handshake{
   278  					Message: &handshake.MessageClientHello{
   279  						Version:            protocol.Version1_2,
   280  						SessionID:          state.SessionID,
   281  						Cookie:             state.cookie,
   282  						Random:             state.localRandom,
   283  						CipherSuiteIDs:     cipherSuiteIDs(cfg.localCipherSuites),
   284  						CompressionMethods: defaultCompressionMethods(),
   285  						Extensions:         extensions,
   286  					},
   287  				},
   288  			},
   289  		},
   290  	}, nil, nil
   291  }