github.com/pion/dtls/v2@v2.2.12/handshake_cache.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"sync"
     8  
     9  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    10  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    11  )
    12  
    13  type handshakeCacheItem struct {
    14  	typ             handshake.Type
    15  	isClient        bool
    16  	epoch           uint16
    17  	messageSequence uint16
    18  	data            []byte
    19  }
    20  
    21  type handshakeCachePullRule struct {
    22  	typ      handshake.Type
    23  	epoch    uint16
    24  	isClient bool
    25  	optional bool
    26  }
    27  
    28  type handshakeCache struct {
    29  	cache []*handshakeCacheItem
    30  	mu    sync.Mutex
    31  }
    32  
    33  func newHandshakeCache() *handshakeCache {
    34  	return &handshakeCache{}
    35  }
    36  
    37  func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) {
    38  	h.mu.Lock()
    39  	defer h.mu.Unlock()
    40  
    41  	h.cache = append(h.cache, &handshakeCacheItem{
    42  		data:            append([]byte{}, data...),
    43  		epoch:           epoch,
    44  		messageSequence: messageSequence,
    45  		typ:             typ,
    46  		isClient:        isClient,
    47  	})
    48  }
    49  
    50  // returns a list handshakes that match the requested rules
    51  // the list will contain null entries for rules that can't be satisfied
    52  // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies)
    53  func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem {
    54  	h.mu.Lock()
    55  	defer h.mu.Unlock()
    56  
    57  	out := make([]*handshakeCacheItem, len(rules))
    58  	for i, r := range rules {
    59  		for _, c := range h.cache {
    60  			if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
    61  				switch {
    62  				case out[i] == nil:
    63  					out[i] = c
    64  				case out[i].messageSequence < c.messageSequence:
    65  					out[i] = c
    66  				}
    67  			}
    68  		}
    69  	}
    70  
    71  	return out
    72  }
    73  
    74  // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map.
    75  func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) {
    76  	h.mu.Lock()
    77  	defer h.mu.Unlock()
    78  
    79  	ci := make(map[handshake.Type]*handshakeCacheItem)
    80  	for _, r := range rules {
    81  		var item *handshakeCacheItem
    82  		for _, c := range h.cache {
    83  			if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
    84  				switch {
    85  				case item == nil:
    86  					item = c
    87  				case item.messageSequence < c.messageSequence:
    88  					item = c
    89  				}
    90  			}
    91  		}
    92  		if !r.optional && item == nil {
    93  			// Missing mandatory message.
    94  			return startSeq, nil, false
    95  		}
    96  		ci[r.typ] = item
    97  	}
    98  	out := make(map[handshake.Type]handshake.Message)
    99  	seq := startSeq
   100  	for _, r := range rules {
   101  		t := r.typ
   102  		i := ci[t]
   103  		if i == nil {
   104  			continue
   105  		}
   106  		var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm
   107  		if cipherSuite != nil {
   108  			keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm()
   109  		}
   110  		rawHandshake := &handshake.Handshake{
   111  			KeyExchangeAlgorithm: keyExchangeAlgorithm,
   112  		}
   113  		if err := rawHandshake.Unmarshal(i.data); err != nil {
   114  			return startSeq, nil, false
   115  		}
   116  		if uint16(seq) != rawHandshake.Header.MessageSequence {
   117  			// There is a gap. Some messages are not arrived.
   118  			return startSeq, nil, false
   119  		}
   120  		seq++
   121  		out[t] = rawHandshake.Message
   122  	}
   123  	return seq, out, true
   124  }
   125  
   126  // pullAndMerge calls pull and then merges the results, ignoring any null entries
   127  func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte {
   128  	merged := []byte{}
   129  
   130  	for _, p := range h.pull(rules...) {
   131  		if p != nil {
   132  			merged = append(merged, p.data...)
   133  		}
   134  	}
   135  	return merged
   136  }
   137  
   138  // sessionHash returns the session hash for Extended Master Secret support
   139  // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
   140  func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
   141  	merged := []byte{}
   142  
   143  	// Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
   144  	handshakeBuffer := h.pull(
   145  		handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false},
   146  		handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false},
   147  		handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false},
   148  		handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false},
   149  		handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false},
   150  		handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false},
   151  		handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false},
   152  		handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false},
   153  	)
   154  
   155  	for _, p := range handshakeBuffer {
   156  		if p == nil {
   157  			continue
   158  		}
   159  
   160  		merged = append(merged, p.data...)
   161  	}
   162  	for _, a := range additional {
   163  		merged = append(merged, a...)
   164  	}
   165  
   166  	hash := hf()
   167  	if _, err := hash.Write(merged); err != nil {
   168  		return []byte{}, err
   169  	}
   170  
   171  	return hash.Sum(nil), nil
   172  }