github.com/pion/dtls/v2@v2.2.12/state.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "encoding/gob" 9 "sync/atomic" 10 11 "github.com/pion/dtls/v2/pkg/crypto/elliptic" 12 "github.com/pion/dtls/v2/pkg/crypto/prf" 13 "github.com/pion/dtls/v2/pkg/protocol/handshake" 14 "github.com/pion/transport/v2/replaydetector" 15 ) 16 17 // State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler 18 type State struct { 19 localEpoch, remoteEpoch atomic.Value 20 localSequenceNumber []uint64 // uint48 21 localRandom, remoteRandom handshake.Random 22 masterSecret []byte 23 cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen 24 25 srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile 26 PeerCertificates [][]byte 27 IdentityHint []byte 28 SessionID []byte 29 30 isClient bool 31 32 preMasterSecret []byte 33 extendedMasterSecret bool 34 35 namedCurve elliptic.Curve 36 localKeypair *elliptic.Keypair 37 cookie []byte 38 handshakeSendSequence int 39 handshakeRecvSequence int 40 serverName string 41 remoteRequestedCertificate bool // Did we get a CertificateRequest 42 localCertificatesVerify []byte // cache CertificateVerify 43 localVerifyData []byte // cached VerifyData 44 localKeySignature []byte // cached keySignature 45 peerCertificatesVerified bool 46 47 replayDetector []replaydetector.ReplayDetector 48 49 peerSupportedProtocols []string 50 NegotiatedProtocol string 51 } 52 53 type serializedState struct { 54 LocalEpoch uint16 55 RemoteEpoch uint16 56 LocalRandom [handshake.RandomLength]byte 57 RemoteRandom [handshake.RandomLength]byte 58 CipherSuiteID uint16 59 MasterSecret []byte 60 SequenceNumber uint64 61 SRTPProtectionProfile uint16 62 PeerCertificates [][]byte 63 IdentityHint []byte 64 SessionID []byte 65 IsClient bool 66 } 67 68 func (s *State) clone() *State { 69 serialized := s.serialize() 70 state := &State{} 71 state.deserialize(*serialized) 72 73 return state 74 } 75 76 func (s *State) serialize() *serializedState { 77 // Marshal random values 78 localRnd := s.localRandom.MarshalFixed() 79 remoteRnd := s.remoteRandom.MarshalFixed() 80 81 epoch := s.getLocalEpoch() 82 return &serializedState{ 83 LocalEpoch: s.getLocalEpoch(), 84 RemoteEpoch: s.getRemoteEpoch(), 85 CipherSuiteID: uint16(s.cipherSuite.ID()), 86 MasterSecret: s.masterSecret, 87 SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]), 88 LocalRandom: localRnd, 89 RemoteRandom: remoteRnd, 90 SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()), 91 PeerCertificates: s.PeerCertificates, 92 IdentityHint: s.IdentityHint, 93 SessionID: s.SessionID, 94 IsClient: s.isClient, 95 } 96 } 97 98 func (s *State) deserialize(serialized serializedState) { 99 // Set epoch values 100 epoch := serialized.LocalEpoch 101 s.localEpoch.Store(serialized.LocalEpoch) 102 s.remoteEpoch.Store(serialized.RemoteEpoch) 103 104 for len(s.localSequenceNumber) <= int(epoch) { 105 s.localSequenceNumber = append(s.localSequenceNumber, uint64(0)) 106 } 107 108 // Set random values 109 localRandom := &handshake.Random{} 110 localRandom.UnmarshalFixed(serialized.LocalRandom) 111 s.localRandom = *localRandom 112 113 remoteRandom := &handshake.Random{} 114 remoteRandom.UnmarshalFixed(serialized.RemoteRandom) 115 s.remoteRandom = *remoteRandom 116 117 s.isClient = serialized.IsClient 118 119 // Set master secret 120 s.masterSecret = serialized.MasterSecret 121 122 // Set cipher suite 123 s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil) 124 125 atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber) 126 s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile)) 127 128 // Set remote certificate 129 s.PeerCertificates = serialized.PeerCertificates 130 s.IdentityHint = serialized.IdentityHint 131 s.SessionID = serialized.SessionID 132 } 133 134 func (s *State) initCipherSuite() error { 135 if s.cipherSuite.IsInitialized() { 136 return nil 137 } 138 139 localRandom := s.localRandom.MarshalFixed() 140 remoteRandom := s.remoteRandom.MarshalFixed() 141 142 var err error 143 if s.isClient { 144 err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true) 145 } else { 146 err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false) 147 } 148 if err != nil { 149 return err 150 } 151 return nil 152 } 153 154 // MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation 155 func (s *State) MarshalBinary() ([]byte, error) { 156 serialized := s.serialize() 157 158 var buf bytes.Buffer 159 enc := gob.NewEncoder(&buf) 160 if err := enc.Encode(*serialized); err != nil { 161 return nil, err 162 } 163 return buf.Bytes(), nil 164 } 165 166 // UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation 167 func (s *State) UnmarshalBinary(data []byte) error { 168 enc := gob.NewDecoder(bytes.NewBuffer(data)) 169 var serialized serializedState 170 if err := enc.Decode(&serialized); err != nil { 171 return err 172 } 173 174 s.deserialize(serialized) 175 176 return s.initCipherSuite() 177 } 178 179 // ExportKeyingMaterial returns length bytes of exported key material in a new 180 // slice as defined in RFC 5705. 181 // This allows protocols to use DTLS for key establishment, but 182 // then use some of the keying material for their own purposes 183 func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { 184 if s.getLocalEpoch() == 0 { 185 return nil, errHandshakeInProgress 186 } else if len(context) != 0 { 187 return nil, errContextUnsupported 188 } else if _, ok := invalidKeyingLabels()[label]; ok { 189 return nil, errReservedExportKeyingMaterial 190 } 191 192 localRandom := s.localRandom.MarshalFixed() 193 remoteRandom := s.remoteRandom.MarshalFixed() 194 195 seed := []byte(label) 196 if s.isClient { 197 seed = append(append(seed, localRandom[:]...), remoteRandom[:]...) 198 } else { 199 seed = append(append(seed, remoteRandom[:]...), localRandom[:]...) 200 } 201 return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc()) 202 } 203 204 func (s *State) getRemoteEpoch() uint16 { 205 if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok { 206 return remoteEpoch 207 } 208 return 0 209 } 210 211 func (s *State) getLocalEpoch() uint16 { 212 if localEpoch, ok := s.localEpoch.Load().(uint16); ok { 213 return localEpoch 214 } 215 return 0 216 } 217 218 func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) { 219 s.srtpProtectionProfile.Store(profile) 220 } 221 222 func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile { 223 if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok { 224 return val 225 } 226 227 return 0 228 }