github.com/quic-go/quic-go@v0.44.0/frame_sorter.go (about)

     1  package quic
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  
     7  	"github.com/quic-go/quic-go/internal/protocol"
     8  	list "github.com/quic-go/quic-go/internal/utils/linkedlist"
     9  )
    10  
    11  // byteInterval is an interval from one ByteCount to the other
    12  type byteInterval struct {
    13  	Start protocol.ByteCount
    14  	End   protocol.ByteCount
    15  }
    16  
    17  var byteIntervalElementPool sync.Pool
    18  
    19  func init() {
    20  	byteIntervalElementPool = *list.NewPool[byteInterval]()
    21  }
    22  
    23  type frameSorterEntry struct {
    24  	Data   []byte
    25  	DoneCb func()
    26  }
    27  
    28  type frameSorter struct {
    29  	queue   map[protocol.ByteCount]frameSorterEntry
    30  	readPos protocol.ByteCount
    31  	gaps    *list.List[byteInterval]
    32  }
    33  
    34  var errDuplicateStreamData = errors.New("duplicate stream data")
    35  
    36  func newFrameSorter() *frameSorter {
    37  	s := frameSorter{
    38  		gaps:  list.NewWithPool[byteInterval](&byteIntervalElementPool),
    39  		queue: make(map[protocol.ByteCount]frameSorterEntry),
    40  	}
    41  	s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount})
    42  	return &s
    43  }
    44  
    45  func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
    46  	err := s.push(data, offset, doneCb)
    47  	if err == errDuplicateStreamData {
    48  		if doneCb != nil {
    49  			doneCb()
    50  		}
    51  		return nil
    52  	}
    53  	return err
    54  }
    55  
// push does the actual insertion for Push.
//
// It returns errDuplicateStreamData in the cases it detects as pure
// duplicates: empty data, data lying entirely in an already-received range,
// or data that is not longer than the entry already queued at the same offset.
// Push turns that sentinel into success.
//
// The sorter keeps gaps (half-open ranges [Start, End) of offsets not yet
// received, in ascending order) and queues received data in s.queue keyed by
// start offset. Inserting a frame therefore has to:
//   - trim the frame to the part that falls into gaps, or replace queued
//     entries at the same offset that the frame fully covers,
//   - shrink, split or remove the gaps the frame fills, and
//   - drop queued entries that become covered when the frame spans several
//     gaps.
//
// On success doneCb is stored with the queued entry, unless the frame was cut
// and its remainder copied, in which case doneCb is called immediately.
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
	if len(data) == 0 {
		return errDuplicateStreamData
	}

	start := offset
	end := offset + protocol.ByteCount(len(data))

	// Everything below the start of the first gap has already been received.
	if end <= s.gaps.Front().Value.Start {
		return errDuplicateStreamData
	}

	// startGap: the gap with Start <= start <= End (startsInGap), or else the
	// first gap after start.
	// endGap: the gap with Start <= end < End (endsInGap), or else the gap
	// preceding the received data that end falls into.
	startGap, startsInGap := s.findStartGap(start)
	endGap, endsInGap := s.findEndGap(startGap, end)

	startGapEqualsEndGap := startGap == endGap

	// The frame ends at or before the start of startGap, i.e. it lies
	// entirely within data that was already received.
	if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
		(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
		return errDuplicateStreamData
	}

	startGapNext := startGap.Next()
	startGapEnd := startGap.Value.End // save it, in case startGap is modified
	endGapStart := endGap.Value.Start // save it, in case endGap is modified
	endGapEnd := endGap.Value.End     // save it, in case endGap is modified
	var adjustedStartGapEnd bool
	var wasCut bool

	// Walk the run of back-to-back entries queued from start onwards. Entries
	// the new frame extends past (or exactly up to, once at least one entry
	// was replaced) are removed and their DoneCb called. The first entry that
	// is not covered either makes the frame a duplicate (nothing replaced yet)
	// or cuts the frame so it ends where that entry begins.
	pos := start
	var hasReplacedAtLeastOne bool
	for {
		oldEntry, ok := s.queue[pos]
		if !ok {
			break
		}
		oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
		if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
			// The existing frame is shorter than the new frame. Replace it.
			delete(s.queue, pos)
			pos += oldEntryLen
			hasReplacedAtLeastOne = true
			if oldEntry.DoneCb != nil {
				oldEntry.DoneCb()
			}
		} else {
			if !hasReplacedAtLeastOne {
				return errDuplicateStreamData
			}
			// The existing frame is longer than the new frame.
			// Cut the new frame such that the end aligns with the start of the existing frame.
			data = data[:pos-start]
			end = pos
			wasCut = true
			break
		}
	}

	// The frame starts inside already-received data: drop that prefix.
	if !startsInGap && !hasReplacedAtLeastOne {
		// cut the frame, such that it starts at the start of the gap
		data = data[startGap.Value.Start-start:]
		start = startGap.Value.Start
		wasCut = true
	}
	// Update startGap. If the frame begins at or before the gap, the gap
	// shrinks from the front (or disappears). Otherwise the frame begins
	// inside the gap, so the gap now ends where the frame starts. A possible
	// remainder after the frame is re-added below.
	if start <= startGap.Value.Start {
		if end >= startGap.Value.End {
			// The frame covers the whole startGap. Delete the gap.
			s.gaps.Remove(startGap)
		} else {
			startGap.Value.Start = end
		}
	} else if !hasReplacedAtLeastOne {
		startGap.Value.End = start
		adjustedStartGapEnd = true
	}

	// The frame spans several gaps. Everything between startGap and endGap is
	// now covered by the frame: drop the queued data following startGap and
	// each intermediate gap, and remove the intermediate gaps themselves.
	if !startGapEqualsEndGap {
		s.deleteConsecutive(startGapEnd)
		var nextGap *list.Element[byteInterval]
		for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
			nextGap = gap.Next()
			s.deleteConsecutive(gap.Value.End)
			s.gaps.Remove(gap)
		}
	}

	// The frame ends inside already-received data after endGap: drop that suffix.
	if !endsInGap && start != endGapEnd && end > endGapEnd {
		// cut the frame, such that it ends at the end of the gap
		data = data[:endGapEnd-start]
		end = endGapEnd
		wasCut = true
	}
	// Update endGap. Remove it if the frame fills it up to its end. If the
	// frame landed in the middle of a single gap, split that gap. Otherwise
	// shrink endGap from the front.
	if end == endGapEnd {
		if !startGapEqualsEndGap {
			// The frame covers the whole endGap. Delete the gap.
			s.gaps.Remove(endGap)
		}
	} else {
		if startGapEqualsEndGap && adjustedStartGapEnd {
			// The frame split the existing gap into two.
			s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap)
		} else if !startGapEqualsEndGap {
			endGap.Value.Start = end
		}
	}

	// A cut frame whose remainder is small is copied, so the original buffer
	// can be released via doneCb right away.
	// NOTE(review): presumably this avoids keeping a large receive buffer
	// alive for a small slice of it — confirm.
	if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
		newData := make([]byte, len(data))
		copy(newData, data)
		data = newData
		if doneCb != nil {
			doneCb()
			doneCb = nil
		}
	}

	// NOTE(review): this check runs after gaps and queue were already updated.
	// On this path data is not queued and doneCb is not called.
	if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
		return errors.New("too many gaps in received data")
	}

	s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
	return nil
}
   179  
   180  func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
   181  	for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
   182  		if offset >= gap.Value.Start && offset <= gap.Value.End {
   183  			return gap, true
   184  		}
   185  		if offset < gap.Value.Start {
   186  			return gap, false
   187  		}
   188  	}
   189  	panic("no gap found")
   190  }
   191  
   192  func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
   193  	for gap := startGap; gap != nil; gap = gap.Next() {
   194  		if offset >= gap.Value.Start && offset < gap.Value.End {
   195  			return gap, true
   196  		}
   197  		if offset < gap.Value.Start {
   198  			return gap.Prev(), false
   199  		}
   200  	}
   201  	panic("no gap found")
   202  }
   203  
   204  // deleteConsecutive deletes consecutive frames from the queue, starting at pos
   205  func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
   206  	for {
   207  		oldEntry, ok := s.queue[pos]
   208  		if !ok {
   209  			break
   210  		}
   211  		oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
   212  		delete(s.queue, pos)
   213  		if oldEntry.DoneCb != nil {
   214  			oldEntry.DoneCb()
   215  		}
   216  		pos += oldEntryLen
   217  	}
   218  }
   219  
   220  func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
   221  	entry, ok := s.queue[s.readPos]
   222  	if !ok {
   223  		return s.readPos, nil, nil
   224  	}
   225  	delete(s.queue, s.readPos)
   226  	offset := s.readPos
   227  	s.readPos += protocol.ByteCount(len(entry.Data))
   228  	if s.gaps.Front().Value.End <= s.readPos {
   229  		panic("frame sorter BUG: read position higher than a gap")
   230  	}
   231  	return offset, entry.Data, entry.DoneCb
   232  }
   233  
   234  // HasMoreData says if there is any more data queued at *any* offset.
   235  func (s *frameSorter) HasMoreData() bool {
   236  	return len(s.queue) > 0
   237  }