github.com/pion/dtls/v2@v2.2.12/state.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/gob"
     9  	"sync/atomic"
    10  
    11  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    12  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    13  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    14  	"github.com/pion/transport/v2/replaydetector"
    15  )
    16  
// State holds the dtls connection state and implements both
// encoding.BinaryMarshaler and encoding.BinaryUnmarshaler.
//
// Several fields are accessed through sync/atomic:
//   - localEpoch and remoteEpoch store uint16 values.
//   - srtpProtectionProfile stores an SRTPProtectionProfile.
//   - localSequenceNumber holds one sequence number per local epoch, indexed by
//     epoch and read/written with atomic.LoadUint64/atomic.StoreUint64.
//
// The epoch and profile getters report 0 until a value has been stored.
//
// isClient decides the order in which the local and remote randoms are
// handed to the cipher suite (initCipherSuite) and to the exporter PRF
// (ExportKeyingMaterial).
//
// Only the fields mirrored by serializedState survive
// MarshalBinary/UnmarshalBinary:
//   - the two epochs and both randoms;
//   - the cipher suite ID and masterSecret;
//   - the sequence number of the current local epoch;
//   - the SRTP protection profile;
//   - PeerCertificates, IdentityHint, SessionID and isClient.
//
// Every other field, including NegotiatedProtocol, has no serializedState
// counterpart and is not serialized.
type State struct {
	localEpoch, remoteEpoch   atomic.Value
	localSequenceNumber       []uint64 // uint48
	localRandom, remoteRandom handshake.Random
	masterSecret              []byte
	cipherSuite               CipherSuite // nil if a cipherSuite hasn't been chosen

	srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
	PeerCertificates      [][]byte
	IdentityHint          []byte
	SessionID             []byte

	isClient bool

	preMasterSecret      []byte
	extendedMasterSecret bool

	namedCurve                 elliptic.Curve
	localKeypair               *elliptic.Keypair
	cookie                     []byte
	handshakeSendSequence      int
	handshakeRecvSequence      int
	serverName                 string
	remoteRequestedCertificate bool   // Did we get a CertificateRequest
	localCertificatesVerify    []byte // cache CertificateVerify
	localVerifyData            []byte // cached VerifyData
	localKeySignature          []byte // cached keySignature
	peerCertificatesVerified   bool

	replayDetector []replaydetector.ReplayDetector

	peerSupportedProtocols []string
	NegotiatedProtocol     string
}
    52  
// serializedState is the portable snapshot of a State.
//
// MarshalBinary encodes it with encoding/gob and UnmarshalBinary decodes it.
// clone round-trips through it in memory. Its fields are exported because
// gob only encodes exported struct fields.
//
// Field notes:
//   - CipherSuiteID is the chosen CipherSuite's ID truncated to uint16.
//   - SequenceNumber is the local sequence number for LocalEpoch only; the
//     sequence numbers of other epochs are not recorded.
//   - SRTPProtectionProfile is the negotiated profile converted to uint16.
type serializedState struct {
	LocalEpoch            uint16
	RemoteEpoch           uint16
	LocalRandom           [handshake.RandomLength]byte
	RemoteRandom          [handshake.RandomLength]byte
	CipherSuiteID         uint16
	MasterSecret          []byte
	SequenceNumber        uint64
	SRTPProtectionProfile uint16
	PeerCertificates      [][]byte
	IdentityHint          []byte
	SessionID             []byte
	IsClient              bool
}
    67  
    68  func (s *State) clone() *State {
    69  	serialized := s.serialize()
    70  	state := &State{}
    71  	state.deserialize(*serialized)
    72  
    73  	return state
    74  }
    75  
    76  func (s *State) serialize() *serializedState {
    77  	// Marshal random values
    78  	localRnd := s.localRandom.MarshalFixed()
    79  	remoteRnd := s.remoteRandom.MarshalFixed()
    80  
    81  	epoch := s.getLocalEpoch()
    82  	return &serializedState{
    83  		LocalEpoch:            s.getLocalEpoch(),
    84  		RemoteEpoch:           s.getRemoteEpoch(),
    85  		CipherSuiteID:         uint16(s.cipherSuite.ID()),
    86  		MasterSecret:          s.masterSecret,
    87  		SequenceNumber:        atomic.LoadUint64(&s.localSequenceNumber[epoch]),
    88  		LocalRandom:           localRnd,
    89  		RemoteRandom:          remoteRnd,
    90  		SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
    91  		PeerCertificates:      s.PeerCertificates,
    92  		IdentityHint:          s.IdentityHint,
    93  		SessionID:             s.SessionID,
    94  		IsClient:              s.isClient,
    95  	}
    96  }
    97  
    98  func (s *State) deserialize(serialized serializedState) {
    99  	// Set epoch values
   100  	epoch := serialized.LocalEpoch
   101  	s.localEpoch.Store(serialized.LocalEpoch)
   102  	s.remoteEpoch.Store(serialized.RemoteEpoch)
   103  
   104  	for len(s.localSequenceNumber) <= int(epoch) {
   105  		s.localSequenceNumber = append(s.localSequenceNumber, uint64(0))
   106  	}
   107  
   108  	// Set random values
   109  	localRandom := &handshake.Random{}
   110  	localRandom.UnmarshalFixed(serialized.LocalRandom)
   111  	s.localRandom = *localRandom
   112  
   113  	remoteRandom := &handshake.Random{}
   114  	remoteRandom.UnmarshalFixed(serialized.RemoteRandom)
   115  	s.remoteRandom = *remoteRandom
   116  
   117  	s.isClient = serialized.IsClient
   118  
   119  	// Set master secret
   120  	s.masterSecret = serialized.MasterSecret
   121  
   122  	// Set cipher suite
   123  	s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil)
   124  
   125  	atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
   126  	s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))
   127  
   128  	// Set remote certificate
   129  	s.PeerCertificates = serialized.PeerCertificates
   130  	s.IdentityHint = serialized.IdentityHint
   131  	s.SessionID = serialized.SessionID
   132  }
   133  
   134  func (s *State) initCipherSuite() error {
   135  	if s.cipherSuite.IsInitialized() {
   136  		return nil
   137  	}
   138  
   139  	localRandom := s.localRandom.MarshalFixed()
   140  	remoteRandom := s.remoteRandom.MarshalFixed()
   141  
   142  	var err error
   143  	if s.isClient {
   144  		err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true)
   145  	} else {
   146  		err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false)
   147  	}
   148  	if err != nil {
   149  		return err
   150  	}
   151  	return nil
   152  }
   153  
   154  // MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation
   155  func (s *State) MarshalBinary() ([]byte, error) {
   156  	serialized := s.serialize()
   157  
   158  	var buf bytes.Buffer
   159  	enc := gob.NewEncoder(&buf)
   160  	if err := enc.Encode(*serialized); err != nil {
   161  		return nil, err
   162  	}
   163  	return buf.Bytes(), nil
   164  }
   165  
   166  // UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation
   167  func (s *State) UnmarshalBinary(data []byte) error {
   168  	enc := gob.NewDecoder(bytes.NewBuffer(data))
   169  	var serialized serializedState
   170  	if err := enc.Decode(&serialized); err != nil {
   171  		return err
   172  	}
   173  
   174  	s.deserialize(serialized)
   175  
   176  	return s.initCipherSuite()
   177  }
   178  
   179  // ExportKeyingMaterial returns length bytes of exported key material in a new
   180  // slice as defined in RFC 5705.
   181  // This allows protocols to use DTLS for key establishment, but
   182  // then use some of the keying material for their own purposes
   183  func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
   184  	if s.getLocalEpoch() == 0 {
   185  		return nil, errHandshakeInProgress
   186  	} else if len(context) != 0 {
   187  		return nil, errContextUnsupported
   188  	} else if _, ok := invalidKeyingLabels()[label]; ok {
   189  		return nil, errReservedExportKeyingMaterial
   190  	}
   191  
   192  	localRandom := s.localRandom.MarshalFixed()
   193  	remoteRandom := s.remoteRandom.MarshalFixed()
   194  
   195  	seed := []byte(label)
   196  	if s.isClient {
   197  		seed = append(append(seed, localRandom[:]...), remoteRandom[:]...)
   198  	} else {
   199  		seed = append(append(seed, remoteRandom[:]...), localRandom[:]...)
   200  	}
   201  	return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
   202  }
   203  
   204  func (s *State) getRemoteEpoch() uint16 {
   205  	if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok {
   206  		return remoteEpoch
   207  	}
   208  	return 0
   209  }
   210  
   211  func (s *State) getLocalEpoch() uint16 {
   212  	if localEpoch, ok := s.localEpoch.Load().(uint16); ok {
   213  		return localEpoch
   214  	}
   215  	return 0
   216  }
   217  
   218  func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) {
   219  	s.srtpProtectionProfile.Store(profile)
   220  }
   221  
   222  func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
   223  	if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok {
   224  		return val
   225  	}
   226  
   227  	return 0
   228  }