github.com/pion/dtls/v2@v2.2.12/fragment_buffer.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "github.com/pion/dtls/v2/pkg/protocol" 8 "github.com/pion/dtls/v2/pkg/protocol/handshake" 9 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 10 ) 11 12 // 2 megabytes 13 const fragmentBufferMaxSize = 2000000 14 15 type fragment struct { 16 recordLayerHeader recordlayer.Header 17 handshakeHeader handshake.Header 18 data []byte 19 } 20 21 type fragmentBuffer struct { 22 // map of MessageSequenceNumbers that hold slices of fragments 23 cache map[uint16][]*fragment 24 25 currentMessageSequenceNumber uint16 26 } 27 28 func newFragmentBuffer() *fragmentBuffer { 29 return &fragmentBuffer{cache: map[uint16][]*fragment{}} 30 } 31 32 // current total size of buffer 33 func (f *fragmentBuffer) size() int { 34 size := 0 35 for i := range f.cache { 36 for j := range f.cache[i] { 37 size += len(f.cache[i][j].data) 38 } 39 } 40 return size 41 } 42 43 // Attempts to push a DTLS packet to the fragmentBuffer 44 // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled 45 // when an error returns it is fatal, and the DTLS connection should be stopped 46 func (f *fragmentBuffer) push(buf []byte) (bool, error) { 47 if f.size()+len(buf) >= fragmentBufferMaxSize { 48 return false, errFragmentBufferOverflow 49 } 50 51 frag := new(fragment) 52 if err := frag.recordLayerHeader.Unmarshal(buf); err != nil { 53 return false, err 54 } 55 56 // fragment isn't a handshake, we don't need to handle it 57 if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake { 58 return false, nil 59 } 60 61 for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) { 62 if err := frag.handshakeHeader.Unmarshal(buf); err != nil { 63 return false, err 64 } 65 66 if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { 67 f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{} 68 } 69 70 // end index should be the length of handshake header but if the handshake 71 // was fragmented, we should keep them all 72 end := int(handshake.HeaderLength + frag.handshakeHeader.Length) 73 if size := len(buf); end > size { 74 end = size 75 } 76 77 // Discard all headers, when rebuilding the packet we will re-build 78 frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...) 79 f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag) 80 buf = buf[end:] 81 } 82 83 return true, nil 84 } 85 86 func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { 87 frags, ok := f.cache[f.currentMessageSequenceNumber] 88 if !ok { 89 return nil, 0 90 } 91 92 // Go doesn't support recursive lambdas 93 var appendMessage func(targetOffset uint32) bool 94 95 rawMessage := []byte{} 96 appendMessage = func(targetOffset uint32) bool { 97 for _, f := range frags { 98 if f.handshakeHeader.FragmentOffset == targetOffset { 99 fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength) 100 if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 { 101 if !appendMessage(fragmentEnd) { 102 return false 103 } 104 } 105 106 rawMessage = append(f.data, rawMessage...) 107 return true 108 } 109 } 110 return false 111 } 112 113 // Recursively collect up 114 if !appendMessage(0) { 115 return nil, 0 116 } 117 118 firstHeader := frags[0].handshakeHeader 119 firstHeader.FragmentOffset = 0 120 firstHeader.FragmentLength = firstHeader.Length 121 122 rawHeader, err := firstHeader.Marshal() 123 if err != nil { 124 return nil, 0 125 } 126 127 messageEpoch := frags[0].recordLayerHeader.Epoch 128 129 delete(f.cache, f.currentMessageSequenceNumber) 130 f.currentMessageSequenceNumber++ 131 return append(rawHeader, rawMessage...), messageEpoch 132 }