github.com/pion/dtls/v2@v2.2.12/flight4bhandler.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  
    10  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    11  	"github.com/pion/dtls/v2/pkg/protocol"
    12  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    13  	"github.com/pion/dtls/v2/pkg/protocol/extension"
    14  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    15  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    16  )
    17  
    18  func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
    19  	_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
    20  		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
    21  	)
    22  	if !ok {
    23  		// No valid message received. Keep reading
    24  		return 0, nil, nil
    25  	}
    26  
    27  	var finished *handshake.MessageFinished
    28  	if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
    29  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
    30  	}
    31  
    32  	plainText := cache.pullAndMerge(
    33  		handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
    34  		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
    35  		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
    36  	)
    37  
    38  	expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc())
    39  	if err != nil {
    40  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
    41  	}
    42  	if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
    43  		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
    44  	}
    45  
    46  	// Other party may re-transmit the last flight. Keep state to be flight4b.
    47  	return flight4b, nil, nil
    48  }
    49  
    50  func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
    51  	var pkts []*packet
    52  
    53  	extensions := []extension.Extension{&extension.RenegotiationInfo{
    54  		RenegotiatedConnection: 0,
    55  	}}
    56  	if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
    57  		cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
    58  		extensions = append(extensions, &extension.UseExtendedMasterSecret{
    59  			Supported: true,
    60  		})
    61  	}
    62  	if state.getSRTPProtectionProfile() != 0 {
    63  		extensions = append(extensions, &extension.UseSRTP{
    64  			ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
    65  		})
    66  	}
    67  
    68  	selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
    69  	if err != nil {
    70  		return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
    71  	}
    72  	if selectedProto != "" {
    73  		extensions = append(extensions, &extension.ALPN{
    74  			ProtocolNameList: []string{selectedProto},
    75  		})
    76  		state.NegotiatedProtocol = selectedProto
    77  	}
    78  
    79  	cipherSuiteID := uint16(state.cipherSuite.ID())
    80  	serverHello := &handshake.Handshake{
    81  		Message: &handshake.MessageServerHello{
    82  			Version:           protocol.Version1_2,
    83  			Random:            state.localRandom,
    84  			SessionID:         state.SessionID,
    85  			CipherSuiteID:     &cipherSuiteID,
    86  			CompressionMethod: defaultCompressionMethods()[0],
    87  			Extensions:        extensions,
    88  		},
    89  	}
    90  
    91  	serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence)
    92  
    93  	if len(state.localVerifyData) == 0 {
    94  		plainText := cache.pullAndMerge(
    95  			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
    96  		)
    97  		raw, err := serverHello.Marshal()
    98  		if err != nil {
    99  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   100  		}
   101  		plainText = append(plainText, raw...)
   102  
   103  		state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
   104  		if err != nil {
   105  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   106  		}
   107  	}
   108  
   109  	pkts = append(pkts,
   110  		&packet{
   111  			record: &recordlayer.RecordLayer{
   112  				Header: recordlayer.Header{
   113  					Version: protocol.Version1_2,
   114  				},
   115  				Content: serverHello,
   116  			},
   117  		},
   118  		&packet{
   119  			record: &recordlayer.RecordLayer{
   120  				Header: recordlayer.Header{
   121  					Version: protocol.Version1_2,
   122  				},
   123  				Content: &protocol.ChangeCipherSpec{},
   124  			},
   125  		},
   126  		&packet{
   127  			record: &recordlayer.RecordLayer{
   128  				Header: recordlayer.Header{
   129  					Version: protocol.Version1_2,
   130  					Epoch:   1,
   131  				},
   132  				Content: &handshake.Handshake{
   133  					Message: &handshake.MessageFinished{
   134  						VerifyData: state.localVerifyData,
   135  					},
   136  				},
   137  			},
   138  			shouldEncrypt:            true,
   139  			resetLocalSequenceNumber: true,
   140  		},
   141  	)
   142  
   143  	return pkts, nil, nil
   144  }