github.com/pion/dtls/v2@v2.2.12/flight4handler.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"context"
     8  	"crypto/rand"
     9  	"crypto/x509"
    10  
    11  	"github.com/pion/dtls/v2/internal/ciphersuite"
    12  	"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
    13  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    14  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    15  	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
    16  	"github.com/pion/dtls/v2/pkg/protocol"
    17  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    18  	"github.com/pion/dtls/v2/pkg/protocol/extension"
    19  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    20  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    21  )
    22  
    23  func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
    24  	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
    25  		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
    26  		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
    27  		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
    28  	)
    29  	if !ok {
    30  		// No valid message received. Keep reading
    31  		return 0, nil, nil
    32  	}
    33  
    34  	// Validate type
    35  	var clientKeyExchange *handshake.MessageClientKeyExchange
    36  	if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
    37  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
    38  	}
    39  
    40  	if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
    41  		state.PeerCertificates = h.Certificate
    42  		// If the client offer its certificate, just disable session resumption.
    43  		// Otherwise, we have to store the certificate identitfication and expire time.
    44  		// And we have to check whether this certificate expired, revoked or changed.
    45  		//
    46  		// https://curl.se/docs/CVE-2016-5419.html
    47  		state.SessionID = nil
    48  	}
    49  
    50  	if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
    51  		if state.PeerCertificates == nil {
    52  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
    53  		}
    54  
    55  		plainText := cache.pullAndMerge(
    56  			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
    57  			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
    58  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
    59  			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
    60  			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
    61  			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
    62  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
    63  			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
    64  		)
    65  
    66  		// Verify that the pair of hash algorithm and signiture is listed.
    67  		var validSignatureScheme bool
    68  		for _, ss := range cfg.localSignatureSchemes {
    69  			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
    70  				validSignatureScheme = true
    71  				break
    72  			}
    73  		}
    74  		if !validSignatureScheme {
    75  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
    76  		}
    77  
    78  		if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
    79  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
    80  		}
    81  		var chains [][]*x509.Certificate
    82  		var err error
    83  		var verified bool
    84  		if cfg.clientAuth >= VerifyClientCertIfGiven {
    85  			if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
    86  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
    87  			}
    88  			verified = true
    89  		}
    90  		if cfg.verifyPeerCertificate != nil {
    91  			if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
    92  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
    93  			}
    94  		}
    95  		state.peerCertificatesVerified = verified
    96  	} else if state.PeerCertificates != nil {
    97  		// A certificate was received, but we haven't seen a CertificateVerify
    98  		// keep reading until we receive one
    99  		return 0, nil, nil
   100  	}
   101  
   102  	if !state.cipherSuite.IsInitialized() {
   103  		serverRandom := state.localRandom.MarshalFixed()
   104  		clientRandom := state.remoteRandom.MarshalFixed()
   105  
   106  		var err error
   107  		var preMasterSecret []byte
   108  		if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
   109  			var psk []byte
   110  			if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
   111  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   112  			}
   113  			state.IdentityHint = clientKeyExchange.IdentityHint
   114  			switch state.cipherSuite.KeyExchangeAlgorithm() {
   115  			case CipherSuiteKeyExchangeAlgorithmPsk:
   116  				preMasterSecret = prf.PSKPreMasterSecret(psk)
   117  			case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe):
   118  				if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
   119  					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   120  				}
   121  			default:
   122  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite
   123  			}
   124  		} else {
   125  			preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
   126  			if err != nil {
   127  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
   128  			}
   129  		}
   130  
   131  		if state.extendedMasterSecret {
   132  			var sessionHash []byte
   133  			sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
   134  			if err != nil {
   135  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   136  			}
   137  
   138  			state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
   139  			if err != nil {
   140  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   141  			}
   142  		} else {
   143  			state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
   144  			if err != nil {
   145  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   146  			}
   147  		}
   148  
   149  		if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
   150  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   151  		}
   152  		cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
   153  	}
   154  
   155  	if len(state.SessionID) > 0 {
   156  		s := Session{
   157  			ID:     state.SessionID,
   158  			Secret: state.masterSecret,
   159  		}
   160  		cfg.log.Tracef("[handshake] save new session: %x", s.ID)
   161  		if err := cfg.sessionStore.Set(state.SessionID, s); err != nil {
   162  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   163  		}
   164  	}
   165  
   166  	// Now, encrypted packets can be handled
   167  	if err := c.handleQueuedPackets(ctx); err != nil {
   168  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   169  	}
   170  
   171  	seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite,
   172  		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
   173  	)
   174  	if !ok {
   175  		// No valid message received. Keep reading
   176  		return 0, nil, nil
   177  	}
   178  	state.handshakeRecvSequence = seq
   179  
   180  	if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
   181  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
   182  	}
   183  
   184  	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
   185  		if cfg.verifyConnection != nil {
   186  			if err := cfg.verifyConnection(state.clone()); err != nil {
   187  				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
   188  			}
   189  		}
   190  		return flight6, nil, nil
   191  	}
   192  
   193  	switch cfg.clientAuth {
   194  	case RequireAnyClientCert:
   195  		if state.PeerCertificates == nil {
   196  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
   197  		}
   198  	case VerifyClientCertIfGiven:
   199  		if state.PeerCertificates != nil && !state.peerCertificatesVerified {
   200  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
   201  		}
   202  	case RequireAndVerifyClientCert:
   203  		if state.PeerCertificates == nil {
   204  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
   205  		}
   206  		if !state.peerCertificatesVerified {
   207  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
   208  		}
   209  	case NoClientCert, RequestClientCert:
   210  		// go to flight6
   211  	}
   212  	if cfg.verifyConnection != nil {
   213  		if err := cfg.verifyConnection(state.clone()); err != nil {
   214  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
   215  		}
   216  	}
   217  
   218  	return flight6, nil, nil
   219  }
   220  
   221  func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
   222  	extensions := []extension.Extension{&extension.RenegotiationInfo{
   223  		RenegotiatedConnection: 0,
   224  	}}
   225  	if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
   226  		cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
   227  		extensions = append(extensions, &extension.UseExtendedMasterSecret{
   228  			Supported: true,
   229  		})
   230  	}
   231  	if state.getSRTPProtectionProfile() != 0 {
   232  		extensions = append(extensions, &extension.UseSRTP{
   233  			ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
   234  		})
   235  	}
   236  	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
   237  		extensions = append(extensions, &extension.SupportedPointFormats{
   238  			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
   239  		})
   240  	}
   241  
   242  	selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
   243  	if err != nil {
   244  		return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
   245  	}
   246  	if selectedProto != "" {
   247  		extensions = append(extensions, &extension.ALPN{
   248  			ProtocolNameList: []string{selectedProto},
   249  		})
   250  		state.NegotiatedProtocol = selectedProto
   251  	}
   252  
   253  	var pkts []*packet
   254  	cipherSuiteID := uint16(state.cipherSuite.ID())
   255  
   256  	if cfg.sessionStore != nil {
   257  		state.SessionID = make([]byte, sessionLength)
   258  		if _, err := rand.Read(state.SessionID); err != nil {
   259  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   260  		}
   261  	}
   262  
   263  	pkts = append(pkts, &packet{
   264  		record: &recordlayer.RecordLayer{
   265  			Header: recordlayer.Header{
   266  				Version: protocol.Version1_2,
   267  			},
   268  			Content: &handshake.Handshake{
   269  				Message: &handshake.MessageServerHello{
   270  					Version:           protocol.Version1_2,
   271  					Random:            state.localRandom,
   272  					SessionID:         state.SessionID,
   273  					CipherSuiteID:     &cipherSuiteID,
   274  					CompressionMethod: defaultCompressionMethods()[0],
   275  					Extensions:        extensions,
   276  				},
   277  			},
   278  		},
   279  	})
   280  
   281  	switch {
   282  	case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
   283  		certificate, err := cfg.getCertificate(&ClientHelloInfo{
   284  			ServerName:   state.serverName,
   285  			CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()},
   286  		})
   287  		if err != nil {
   288  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
   289  		}
   290  
   291  		pkts = append(pkts, &packet{
   292  			record: &recordlayer.RecordLayer{
   293  				Header: recordlayer.Header{
   294  					Version: protocol.Version1_2,
   295  				},
   296  				Content: &handshake.Handshake{
   297  					Message: &handshake.MessageCertificate{
   298  						Certificate: certificate.Certificate,
   299  					},
   300  				},
   301  			},
   302  		})
   303  
   304  		serverRandom := state.localRandom.MarshalFixed()
   305  		clientRandom := state.remoteRandom.MarshalFixed()
   306  
   307  		// Find compatible signature scheme
   308  		signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
   309  		if err != nil {
   310  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
   311  		}
   312  
   313  		signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash)
   314  		if err != nil {
   315  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   316  		}
   317  		state.localKeySignature = signature
   318  
   319  		pkts = append(pkts, &packet{
   320  			record: &recordlayer.RecordLayer{
   321  				Header: recordlayer.Header{
   322  					Version: protocol.Version1_2,
   323  				},
   324  				Content: &handshake.Handshake{
   325  					Message: &handshake.MessageServerKeyExchange{
   326  						EllipticCurveType:  elliptic.CurveTypeNamedCurve,
   327  						NamedCurve:         state.namedCurve,
   328  						PublicKey:          state.localKeypair.PublicKey,
   329  						HashAlgorithm:      signatureHashAlgo.Hash,
   330  						SignatureAlgorithm: signatureHashAlgo.Signature,
   331  						Signature:          state.localKeySignature,
   332  					},
   333  				},
   334  			},
   335  		})
   336  
   337  		if cfg.clientAuth > NoClientCert {
   338  			// An empty list of certificateAuthorities signals to
   339  			// the client that it may send any certificate in response
   340  			// to our request. When we know the CAs we trust, then
   341  			// we can send them down, so that the client can choose
   342  			// an appropriate certificate to give to us.
   343  			var certificateAuthorities [][]byte
   344  			if cfg.clientCAs != nil {
   345  				// nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool and it's ok if certificate authorities is empty.
   346  				certificateAuthorities = cfg.clientCAs.Subjects()
   347  			}
   348  			pkts = append(pkts, &packet{
   349  				record: &recordlayer.RecordLayer{
   350  					Header: recordlayer.Header{
   351  						Version: protocol.Version1_2,
   352  					},
   353  					Content: &handshake.Handshake{
   354  						Message: &handshake.MessageCertificateRequest{
   355  							CertificateTypes:            []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
   356  							SignatureHashAlgorithms:     cfg.localSignatureSchemes,
   357  							CertificateAuthoritiesNames: certificateAuthorities,
   358  						},
   359  					},
   360  				},
   361  			})
   362  		}
   363  	case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe):
   364  		// To help the client in selecting which identity to use, the server
   365  		// can provide a "PSK identity hint" in the ServerKeyExchange message.
   366  		// If no hint is provided and cipher suite doesn't use elliptic curve,
   367  		// the ServerKeyExchange message is omitted.
   368  		//
   369  		// https://tools.ietf.org/html/rfc4279#section-2
   370  		srvExchange := &handshake.MessageServerKeyExchange{
   371  			IdentityHint: cfg.localPSKIdentityHint,
   372  		}
   373  		if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) {
   374  			srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve
   375  			srvExchange.NamedCurve = state.namedCurve
   376  			srvExchange.PublicKey = state.localKeypair.PublicKey
   377  		}
   378  		pkts = append(pkts, &packet{
   379  			record: &recordlayer.RecordLayer{
   380  				Header: recordlayer.Header{
   381  					Version: protocol.Version1_2,
   382  				},
   383  				Content: &handshake.Handshake{
   384  					Message: srvExchange,
   385  				},
   386  			},
   387  		})
   388  	}
   389  
   390  	pkts = append(pkts, &packet{
   391  		record: &recordlayer.RecordLayer{
   392  			Header: recordlayer.Header{
   393  				Version: protocol.Version1_2,
   394  			},
   395  			Content: &handshake.Handshake{
   396  				Message: &handshake.MessageServerHelloDone{},
   397  			},
   398  		},
   399  	})
   400  
   401  	return pkts, nil, nil
   402  }