github.com/pion/dtls/v2@v2.2.12/flight4bhandler.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "context" 9 10 "github.com/pion/dtls/v2/pkg/crypto/prf" 11 "github.com/pion/dtls/v2/pkg/protocol" 12 "github.com/pion/dtls/v2/pkg/protocol/alert" 13 "github.com/pion/dtls/v2/pkg/protocol/extension" 14 "github.com/pion/dtls/v2/pkg/protocol/handshake" 15 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 16 ) 17 18 func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { 19 _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, 20 handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, 21 ) 22 if !ok { 23 // No valid message received. Keep reading 24 return 0, nil, nil 25 } 26 27 var finished *handshake.MessageFinished 28 if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { 29 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil 30 } 31 32 plainText := cache.pullAndMerge( 33 handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, 34 handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, 35 handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, 36 ) 37 38 expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) 39 if err != nil { 40 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 41 } 42 if !bytes.Equal(expectedVerifyData, finished.VerifyData) { 43 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch 44 } 45 46 // Other party may re-transmit the last flight. Keep state to be flight4b. 47 return flight4b, nil, nil 48 } 49 50 func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { 51 var pkts []*packet 52 53 extensions := []extension.Extension{&extension.RenegotiationInfo{ 54 RenegotiatedConnection: 0, 55 }} 56 if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || 57 cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { 58 extensions = append(extensions, &extension.UseExtendedMasterSecret{ 59 Supported: true, 60 }) 61 } 62 if state.getSRTPProtectionProfile() != 0 { 63 extensions = append(extensions, &extension.UseSRTP{ 64 ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, 65 }) 66 } 67 68 selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) 69 if err != nil { 70 return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err 71 } 72 if selectedProto != "" { 73 extensions = append(extensions, &extension.ALPN{ 74 ProtocolNameList: []string{selectedProto}, 75 }) 76 state.NegotiatedProtocol = selectedProto 77 } 78 79 cipherSuiteID := uint16(state.cipherSuite.ID()) 80 serverHello := &handshake.Handshake{ 81 Message: &handshake.MessageServerHello{ 82 Version: protocol.Version1_2, 83 Random: state.localRandom, 84 SessionID: state.SessionID, 85 CipherSuiteID: &cipherSuiteID, 86 CompressionMethod: defaultCompressionMethods()[0], 87 Extensions: extensions, 88 }, 89 } 90 91 serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) 92 93 if len(state.localVerifyData) == 0 { 94 plainText := cache.pullAndMerge( 95 handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, 96 ) 97 raw, err := serverHello.Marshal() 98 if err != nil { 99 return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 100 } 101 plainText = append(plainText, raw...) 102 103 state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) 104 if err != nil { 105 return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 106 } 107 } 108 109 pkts = append(pkts, 110 &packet{ 111 record: &recordlayer.RecordLayer{ 112 Header: recordlayer.Header{ 113 Version: protocol.Version1_2, 114 }, 115 Content: serverHello, 116 }, 117 }, 118 &packet{ 119 record: &recordlayer.RecordLayer{ 120 Header: recordlayer.Header{ 121 Version: protocol.Version1_2, 122 }, 123 Content: &protocol.ChangeCipherSpec{}, 124 }, 125 }, 126 &packet{ 127 record: &recordlayer.RecordLayer{ 128 Header: recordlayer.Header{ 129 Version: protocol.Version1_2, 130 Epoch: 1, 131 }, 132 Content: &handshake.Handshake{ 133 Message: &handshake.MessageFinished{ 134 VerifyData: state.localVerifyData, 135 }, 136 }, 137 }, 138 shouldEncrypt: true, 139 resetLocalSequenceNumber: true, 140 }, 141 ) 142 143 return pkts, nil, nil 144 }