github.com/pion/dtls/v2@v2.2.12/flight5handler.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto"
    10  	"crypto/x509"
    11  
    12  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    13  	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
    14  	"github.com/pion/dtls/v2/pkg/protocol"
    15  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    16  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    17  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    18  )
    19  
    20  func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
    21  	_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
    22  		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
    23  	)
    24  	if !ok {
    25  		// No valid message received. Keep reading
    26  		return 0, nil, nil
    27  	}
    28  
    29  	var finished *handshake.MessageFinished
    30  	if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
    31  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
    32  	}
    33  	plainText := cache.pullAndMerge(
    34  		handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
    35  		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
    36  		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
    37  		handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
    38  		handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
    39  		handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
    40  		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
    41  		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
    42  		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
    43  		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
    44  	)
    45  
    46  	expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
    47  	if err != nil {
    48  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
    49  	}
    50  	if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
    51  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
    52  	}
    53  
    54  	if len(state.SessionID) > 0 {
    55  		s := Session{
    56  			ID:     state.SessionID,
    57  			Secret: state.masterSecret,
    58  		}
    59  		cfg.log.Tracef("[handshake] save new session: %x", s.ID)
    60  		if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil {
    61  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
    62  		}
    63  	}
    64  
    65  	return flight5, nil, nil
    66  }
    67  
    68  func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
    69  	var privateKey crypto.PrivateKey
    70  	var pkts []*packet
    71  	if state.remoteRequestedCertificate {
    72  		_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite,
    73  			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false})
    74  		if !ok {
    75  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
    76  		}
    77  		reqInfo := CertificateRequestInfo{}
    78  		if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
    79  			reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames
    80  		} else {
    81  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
    82  		}
    83  		certificate, err := cfg.getClientCertificate(&reqInfo)
    84  		if err != nil {
    85  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
    86  		}
    87  		if certificate == nil {
    88  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain
    89  		}
    90  		if certificate.Certificate != nil {
    91  			privateKey = certificate.PrivateKey
    92  		}
    93  		pkts = append(pkts,
    94  			&packet{
    95  				record: &recordlayer.RecordLayer{
    96  					Header: recordlayer.Header{
    97  						Version: protocol.Version1_2,
    98  					},
    99  					Content: &handshake.Handshake{
   100  						Message: &handshake.MessageCertificate{
   101  							Certificate: certificate.Certificate,
   102  						},
   103  					},
   104  				},
   105  			})
   106  	}
   107  
   108  	clientKeyExchange := &handshake.MessageClientKeyExchange{}
   109  	if cfg.localPSKCallback == nil {
   110  		clientKeyExchange.PublicKey = state.localKeypair.PublicKey
   111  	} else {
   112  		clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
   113  	}
   114  	if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 {
   115  		clientKeyExchange.PublicKey = state.localKeypair.PublicKey
   116  	}
   117  
   118  	pkts = append(pkts,
   119  		&packet{
   120  			record: &recordlayer.RecordLayer{
   121  				Header: recordlayer.Header{
   122  					Version: protocol.Version1_2,
   123  				},
   124  				Content: &handshake.Handshake{
   125  					Message: clientKeyExchange,
   126  				},
   127  			},
   128  		})
   129  
   130  	serverKeyExchangeData := cache.pullAndMerge(
   131  		handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
   132  	)
   133  
   134  	serverKeyExchange := &handshake.MessageServerKeyExchange{}
   135  
   136  	// handshakeMessageServerKeyExchange is optional for PSK
   137  	if len(serverKeyExchangeData) == 0 {
   138  		alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{})
   139  		if err != nil {
   140  			return nil, alertPtr, err
   141  		}
   142  	} else {
   143  		rawHandshake := &handshake.Handshake{
   144  			KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(),
   145  		}
   146  		err := rawHandshake.Unmarshal(serverKeyExchangeData)
   147  		if err != nil {
   148  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
   149  		}
   150  
   151  		switch h := rawHandshake.Message.(type) {
   152  		case *handshake.MessageServerKeyExchange:
   153  			serverKeyExchange = h
   154  		default:
   155  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
   156  		}
   157  	}
   158  
   159  	// Append not-yet-sent packets
   160  	merged := []byte{}
   161  	seqPred := uint16(state.handshakeSendSequence)
   162  	for _, p := range pkts {
   163  		h, ok := p.record.Content.(*handshake.Handshake)
   164  		if !ok {
   165  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
   166  		}
   167  		h.Header.MessageSequence = seqPred
   168  		seqPred++
   169  		raw, err := h.Marshal()
   170  		if err != nil {
   171  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   172  		}
   173  		merged = append(merged, raw...)
   174  	}
   175  
   176  	if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
   177  		return nil, alertPtr, err
   178  	}
   179  
   180  	// If the client has sent a certificate with signing ability, a digitally-signed
   181  	// CertificateVerify message is sent to explicitly verify possession of the
   182  	// private key in the certificate.
   183  	if state.remoteRequestedCertificate && privateKey != nil {
   184  		plainText := append(cache.pullAndMerge(
   185  			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
   186  			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
   187  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
   188  			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
   189  			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
   190  			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
   191  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
   192  			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
   193  		), merged...)
   194  
   195  		// Find compatible signature scheme
   196  		signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey)
   197  		if err != nil {
   198  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
   199  		}
   200  
   201  		certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash)
   202  		if err != nil {
   203  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   204  		}
   205  		state.localCertificatesVerify = certVerify
   206  
   207  		p := &packet{
   208  			record: &recordlayer.RecordLayer{
   209  				Header: recordlayer.Header{
   210  					Version: protocol.Version1_2,
   211  				},
   212  				Content: &handshake.Handshake{
   213  					Message: &handshake.MessageCertificateVerify{
   214  						HashAlgorithm:      signatureHashAlgo.Hash,
   215  						SignatureAlgorithm: signatureHashAlgo.Signature,
   216  						Signature:          state.localCertificatesVerify,
   217  					},
   218  				},
   219  			},
   220  		}
   221  		pkts = append(pkts, p)
   222  
   223  		h, ok := p.record.Content.(*handshake.Handshake)
   224  		if !ok {
   225  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
   226  		}
   227  		h.Header.MessageSequence = seqPred
   228  		// seqPred++ // this is the last use of seqPred
   229  		raw, err := h.Marshal()
   230  		if err != nil {
   231  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   232  		}
   233  		merged = append(merged, raw...)
   234  	}
   235  
   236  	pkts = append(pkts,
   237  		&packet{
   238  			record: &recordlayer.RecordLayer{
   239  				Header: recordlayer.Header{
   240  					Version: protocol.Version1_2,
   241  				},
   242  				Content: &protocol.ChangeCipherSpec{},
   243  			},
   244  		})
   245  
   246  	if len(state.localVerifyData) == 0 {
   247  		plainText := cache.pullAndMerge(
   248  			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
   249  			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
   250  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
   251  			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
   252  			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
   253  			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
   254  			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
   255  			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
   256  			handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
   257  			handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
   258  		)
   259  
   260  		var err error
   261  		state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc())
   262  		if err != nil {
   263  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   264  		}
   265  	}
   266  
   267  	pkts = append(pkts,
   268  		&packet{
   269  			record: &recordlayer.RecordLayer{
   270  				Header: recordlayer.Header{
   271  					Version: protocol.Version1_2,
   272  					Epoch:   1,
   273  				},
   274  				Content: &handshake.Handshake{
   275  					Message: &handshake.MessageFinished{
   276  						VerifyData: state.localVerifyData,
   277  					},
   278  				},
   279  			},
   280  			shouldEncrypt:            true,
   281  			resetLocalSequenceNumber: true,
   282  		})
   283  
   284  	return pkts, nil, nil
   285  }
   286  
   287  func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
   288  	if state.cipherSuite.IsInitialized() {
   289  		return nil, nil //nolint
   290  	}
   291  
   292  	clientRandom := state.localRandom.MarshalFixed()
   293  	serverRandom := state.remoteRandom.MarshalFixed()
   294  
   295  	var err error
   296  
   297  	if state.extendedMasterSecret {
   298  		var sessionHash []byte
   299  		sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
   300  		if err != nil {
   301  			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   302  		}
   303  
   304  		state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
   305  		if err != nil {
   306  			return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
   307  		}
   308  	} else {
   309  		state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
   310  		if err != nil {
   311  			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   312  		}
   313  	}
   314  
   315  	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
   316  		// Verify that the pair of hash algorithm and signiture is listed.
   317  		var validSignatureScheme bool
   318  		for _, ss := range cfg.localSignatureSchemes {
   319  			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
   320  				validSignatureScheme = true
   321  				break
   322  			}
   323  		}
   324  		if !validSignatureScheme {
   325  			return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
   326  		}
   327  
   328  		expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
   329  		if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
   330  			return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
   331  		}
   332  		var chains [][]*x509.Certificate
   333  		if !cfg.insecureSkipVerify {
   334  			if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
   335  				return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
   336  			}
   337  		}
   338  		if cfg.verifyPeerCertificate != nil {
   339  			if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
   340  				return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
   341  			}
   342  		}
   343  	}
   344  	if cfg.verifyConnection != nil {
   345  		if err = cfg.verifyConnection(state.clone()); err != nil {
   346  			return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
   347  		}
   348  	}
   349  
   350  	if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
   351  		return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   352  	}
   353  
   354  	cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
   355  
   356  	return nil, nil //nolint
   357  }