github.com/pion/dtls/v2@v2.2.12/flight3handler.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "context" 9 10 "github.com/pion/dtls/v2/internal/ciphersuite/types" 11 "github.com/pion/dtls/v2/pkg/crypto/elliptic" 12 "github.com/pion/dtls/v2/pkg/crypto/prf" 13 "github.com/pion/dtls/v2/pkg/protocol" 14 "github.com/pion/dtls/v2/pkg/protocol/alert" 15 "github.com/pion/dtls/v2/pkg/protocol/extension" 16 "github.com/pion/dtls/v2/pkg/protocol/handshake" 17 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 18 ) 19 20 func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit 21 // Clients may receive multiple HelloVerifyRequest messages with different cookies. 22 // Clients SHOULD handle this by sending a new ClientHello with a cookie in response 23 // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 24 seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, 25 handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, 26 ) 27 if ok { 28 if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk { 29 // DTLS 1.2 clients must not assume that the server will use the protocol version 30 // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 31 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { 32 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion 33 } 34 state.cookie = append([]byte{}, h.Cookie...) 35 state.handshakeRecvSequence = seq 36 return flight3, nil, nil 37 } 38 } 39 40 _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, 41 handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, 42 ) 43 if !ok { 44 // Don't have enough messages. Keep reading 45 return 0, nil, nil 46 } 47 48 if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { 49 if !h.Version.Equal(protocol.Version1_2) { 50 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion 51 } 52 for _, v := range h.Extensions { 53 switch e := v.(type) { 54 case *extension.UseSRTP: 55 profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) 56 if !found { 57 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile 58 } 59 state.setSRTPProtectionProfile(profile) 60 case *extension.UseExtendedMasterSecret: 61 if cfg.extendedMasterSecret != DisableExtendedMasterSecret { 62 state.extendedMasterSecret = true 63 } 64 case *extension.ALPN: 65 if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling 66 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error? 67 } 68 state.NegotiatedProtocol = e.ProtocolNameList[0] 69 } 70 } 71 if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { 72 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS 73 } 74 if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 { 75 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension 76 } 77 78 remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites) 79 if remoteCipherSuite == nil { 80 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection 81 } 82 83 selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites) 84 if !found { 85 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite 86 } 87 88 state.cipherSuite = selectedCipherSuite 89 state.remoteRandom = h.Random 90 cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) 91 92 if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) { 93 return handleResumption(ctx, c, state, cache, cfg) 94 } 95 96 if len(state.SessionID) > 0 { 97 cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID) 98 if err := cfg.sessionStore.Del(state.SessionID); err != nil { 99 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 100 } 101 } 102 103 if cfg.sessionStore == nil { 104 state.SessionID = []byte{} 105 } else { 106 state.SessionID = h.SessionID 107 } 108 109 state.masterSecret = []byte{} 110 } 111 112 if cfg.localPSKCallback != nil { 113 seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, 114 handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true}, 115 handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, 116 ) 117 } else { 118 seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, 119 handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true}, 120 handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, 121 handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true}, 122 handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, 123 ) 124 } 125 if !ok { 126 // Don't have enough messages. Keep reading 127 return 0, nil, nil 128 } 129 state.handshakeRecvSequence = seq 130 131 if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok { 132 state.PeerCertificates = h.Certificate 133 } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { 134 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate 135 } 136 137 if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { 138 alertPtr, err := handleServerKeyExchange(c, state, cfg, h) 139 if err != nil { 140 return 0, alertPtr, err 141 } 142 } 143 144 if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { 145 state.remoteRequestedCertificate = true 146 } 147 148 return flight5, nil, nil 149 } 150 151 func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { 152 if err := state.initCipherSuite(); err != nil { 153 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 154 } 155 156 // Now, encrypted packets can be handled 157 if err := c.handleQueuedPackets(ctx); err != nil { 158 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 159 } 160 161 _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, 162 handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, 163 ) 164 if !ok { 165 // No valid message received. Keep reading 166 return 0, nil, nil 167 } 168 169 var finished *handshake.MessageFinished 170 if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { 171 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil 172 } 173 plainText := cache.pullAndMerge( 174 handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, 175 handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, 176 ) 177 178 expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) 179 if err != nil { 180 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 181 } 182 if !bytes.Equal(expectedVerifyData, finished.VerifyData) { 183 return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch 184 } 185 186 clientRandom := state.localRandom.MarshalFixed() 187 cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) 188 189 return flight5b, nil, nil 190 } 191 192 func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { 193 var err error 194 if state.cipherSuite == nil { 195 return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite 196 } 197 if cfg.localPSKCallback != nil { 198 var psk []byte 199 if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil { 200 return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 201 } 202 state.IdentityHint = h.IdentityHint 203 switch state.cipherSuite.KeyExchangeAlgorithm() { 204 case types.KeyExchangeAlgorithmPsk: 205 state.preMasterSecret = prf.PSKPreMasterSecret(psk) 206 case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): 207 if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { 208 return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 209 } 210 state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) 211 if err != nil { 212 return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 213 } 214 default: 215 return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite 216 } 217 } else { 218 if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { 219 return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 220 } 221 222 if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { 223 return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 224 } 225 } 226 227 return nil, nil //nolint:nilnil 228 } 229 230 func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { 231 extensions := []extension.Extension{ 232 &extension.SupportedSignatureAlgorithms{ 233 SignatureHashAlgorithms: cfg.localSignatureSchemes, 234 }, 235 &extension.RenegotiationInfo{ 236 RenegotiatedConnection: 0, 237 }, 238 } 239 if state.namedCurve != 0 { 240 extensions = append(extensions, []extension.Extension{ 241 &extension.SupportedEllipticCurves{ 242 EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, 243 }, 244 &extension.SupportedPointFormats{ 245 PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, 246 }, 247 }...) 248 } 249 250 if len(cfg.localSRTPProtectionProfiles) > 0 { 251 extensions = append(extensions, &extension.UseSRTP{ 252 ProtectionProfiles: cfg.localSRTPProtectionProfiles, 253 }) 254 } 255 256 if cfg.extendedMasterSecret == RequestExtendedMasterSecret || 257 cfg.extendedMasterSecret == RequireExtendedMasterSecret { 258 extensions = append(extensions, &extension.UseExtendedMasterSecret{ 259 Supported: true, 260 }) 261 } 262 263 if len(cfg.serverName) > 0 { 264 extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) 265 } 266 267 if len(cfg.supportedProtocols) > 0 { 268 extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) 269 } 270 271 return []*packet{ 272 { 273 record: &recordlayer.RecordLayer{ 274 Header: recordlayer.Header{ 275 Version: protocol.Version1_2, 276 }, 277 Content: &handshake.Handshake{ 278 Message: &handshake.MessageClientHello{ 279 Version: protocol.Version1_2, 280 SessionID: state.SessionID, 281 Cookie: state.cookie, 282 Random: state.localRandom, 283 CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), 284 CompressionMethods: defaultCompressionMethods(), 285 Extensions: extensions, 286 }, 287 }, 288 }, 289 }, 290 }, nil, nil 291 }