github.com/pion/dtls/v2@v2.2.12/fragment_buffer.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"github.com/pion/dtls/v2/pkg/protocol"
     8  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
     9  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    10  )
    11  
    12  // 2 megabytes
    13  const fragmentBufferMaxSize = 2000000
    14  
// fragment is one handshake fragment extracted from a DTLS record by
// fragmentBuffer.push.
//
//   - recordLayerHeader is the header of the record the fragment arrived in;
//     pop takes the epoch of a reassembled message from it.
//   - handshakeHeader is the fragment's own handshake header, carrying the
//     message sequence, total message length, fragment offset and fragment
//     length used during reassembly.
//   - data is a private copy of the fragment body with all headers stripped.
type fragment struct {
	recordLayerHeader recordlayer.Header
	handshakeHeader   handshake.Header
	data              []byte
}
    20  
// fragmentBuffer collects handshake fragments taken from DTLS records and
// reassembles them into complete handshake messages, which pop hands out one
// at a time in message-sequence order. The struct carries no lock;
// NOTE(review): presumably callers serialize access to it — confirm.
type fragmentBuffer struct {
	// map of MessageSequenceNumbers that hold slices of fragments, kept in
	// the order push received them
	cache map[uint16][]*fragment

	// sequence number of the next message pop will try to reassemble;
	// incremented after each successful pop
	currentMessageSequenceNumber uint16
}
    27  
    28  func newFragmentBuffer() *fragmentBuffer {
    29  	return &fragmentBuffer{cache: map[uint16][]*fragment{}}
    30  }
    31  
    32  // current total size of buffer
    33  func (f *fragmentBuffer) size() int {
    34  	size := 0
    35  	for i := range f.cache {
    36  		for j := range f.cache[i] {
    37  			size += len(f.cache[i][j].data)
    38  		}
    39  	}
    40  	return size
    41  }
    42  
    43  // Attempts to push a DTLS packet to the fragmentBuffer
    44  // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
    45  // when an error returns it is fatal, and the DTLS connection should be stopped
    46  func (f *fragmentBuffer) push(buf []byte) (bool, error) {
    47  	if f.size()+len(buf) >= fragmentBufferMaxSize {
    48  		return false, errFragmentBufferOverflow
    49  	}
    50  
    51  	frag := new(fragment)
    52  	if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
    53  		return false, err
    54  	}
    55  
    56  	// fragment isn't a handshake, we don't need to handle it
    57  	if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
    58  		return false, nil
    59  	}
    60  
    61  	for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
    62  		if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
    63  			return false, err
    64  		}
    65  
    66  		if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
    67  			f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
    68  		}
    69  
    70  		// end index should be the length of handshake header but if the handshake
    71  		// was fragmented, we should keep them all
    72  		end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
    73  		if size := len(buf); end > size {
    74  			end = size
    75  		}
    76  
    77  		// Discard all headers, when rebuilding the packet we will re-build
    78  		frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
    79  		f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
    80  		buf = buf[end:]
    81  	}
    82  
    83  	return true, nil
    84  }
    85  
    86  func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
    87  	frags, ok := f.cache[f.currentMessageSequenceNumber]
    88  	if !ok {
    89  		return nil, 0
    90  	}
    91  
    92  	// Go doesn't support recursive lambdas
    93  	var appendMessage func(targetOffset uint32) bool
    94  
    95  	rawMessage := []byte{}
    96  	appendMessage = func(targetOffset uint32) bool {
    97  		for _, f := range frags {
    98  			if f.handshakeHeader.FragmentOffset == targetOffset {
    99  				fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
   100  				if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 {
   101  					if !appendMessage(fragmentEnd) {
   102  						return false
   103  					}
   104  				}
   105  
   106  				rawMessage = append(f.data, rawMessage...)
   107  				return true
   108  			}
   109  		}
   110  		return false
   111  	}
   112  
   113  	// Recursively collect up
   114  	if !appendMessage(0) {
   115  		return nil, 0
   116  	}
   117  
   118  	firstHeader := frags[0].handshakeHeader
   119  	firstHeader.FragmentOffset = 0
   120  	firstHeader.FragmentLength = firstHeader.Length
   121  
   122  	rawHeader, err := firstHeader.Marshal()
   123  	if err != nil {
   124  		return nil, 0
   125  	}
   126  
   127  	messageEpoch := frags[0].recordLayerHeader.Epoch
   128  
   129  	delete(f.cache, f.currentMessageSequenceNumber)
   130  	f.currentMessageSequenceNumber++
   131  	return append(rawHeader, rawMessage...), messageEpoch
   132  }