github.com/pion/dtls/v2@v2.2.12/handshake_cache.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "sync" 8 9 "github.com/pion/dtls/v2/pkg/crypto/prf" 10 "github.com/pion/dtls/v2/pkg/protocol/handshake" 11 ) 12 13 type handshakeCacheItem struct { 14 typ handshake.Type 15 isClient bool 16 epoch uint16 17 messageSequence uint16 18 data []byte 19 } 20 21 type handshakeCachePullRule struct { 22 typ handshake.Type 23 epoch uint16 24 isClient bool 25 optional bool 26 } 27 28 type handshakeCache struct { 29 cache []*handshakeCacheItem 30 mu sync.Mutex 31 } 32 33 func newHandshakeCache() *handshakeCache { 34 return &handshakeCache{} 35 } 36 37 func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) { 38 h.mu.Lock() 39 defer h.mu.Unlock() 40 41 h.cache = append(h.cache, &handshakeCacheItem{ 42 data: append([]byte{}, data...), 43 epoch: epoch, 44 messageSequence: messageSequence, 45 typ: typ, 46 isClient: isClient, 47 }) 48 } 49 50 // returns a list handshakes that match the requested rules 51 // the list will contain null entries for rules that can't be satisfied 52 // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies) 53 func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem { 54 h.mu.Lock() 55 defer h.mu.Unlock() 56 57 out := make([]*handshakeCacheItem, len(rules)) 58 for i, r := range rules { 59 for _, c := range h.cache { 60 if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { 61 switch { 62 case out[i] == nil: 63 out[i] = c 64 case out[i].messageSequence < c.messageSequence: 65 out[i] = c 66 } 67 } 68 } 69 } 70 71 return out 72 } 73 74 // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map. 75 func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) { 76 h.mu.Lock() 77 defer h.mu.Unlock() 78 79 ci := make(map[handshake.Type]*handshakeCacheItem) 80 for _, r := range rules { 81 var item *handshakeCacheItem 82 for _, c := range h.cache { 83 if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { 84 switch { 85 case item == nil: 86 item = c 87 case item.messageSequence < c.messageSequence: 88 item = c 89 } 90 } 91 } 92 if !r.optional && item == nil { 93 // Missing mandatory message. 94 return startSeq, nil, false 95 } 96 ci[r.typ] = item 97 } 98 out := make(map[handshake.Type]handshake.Message) 99 seq := startSeq 100 for _, r := range rules { 101 t := r.typ 102 i := ci[t] 103 if i == nil { 104 continue 105 } 106 var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm 107 if cipherSuite != nil { 108 keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm() 109 } 110 rawHandshake := &handshake.Handshake{ 111 KeyExchangeAlgorithm: keyExchangeAlgorithm, 112 } 113 if err := rawHandshake.Unmarshal(i.data); err != nil { 114 return startSeq, nil, false 115 } 116 if uint16(seq) != rawHandshake.Header.MessageSequence { 117 // There is a gap. Some messages are not arrived. 118 return startSeq, nil, false 119 } 120 seq++ 121 out[t] = rawHandshake.Message 122 } 123 return seq, out, true 124 } 125 126 // pullAndMerge calls pull and then merges the results, ignoring any null entries 127 func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { 128 merged := []byte{} 129 130 for _, p := range h.pull(rules...) { 131 if p != nil { 132 merged = append(merged, p.data...) 133 } 134 } 135 return merged 136 } 137 138 // sessionHash returns the session hash for Extended Master Secret support 139 // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4 140 func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) { 141 merged := []byte{} 142 143 // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3 144 handshakeBuffer := h.pull( 145 handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false}, 146 handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false}, 147 handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false}, 148 handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false}, 149 handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false}, 150 handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false}, 151 handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false}, 152 handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false}, 153 ) 154 155 for _, p := range handshakeBuffer { 156 if p == nil { 157 continue 158 } 159 160 merged = append(merged, p.data...) 161 } 162 for _, a := range additional { 163 merged = append(merged, a...) 164 } 165 166 hash := hf() 167 if _, err := hash.Write(merged); err != nil { 168 return []byte{}, err 169 } 170 171 return hash.Sum(nil), nil 172 }