github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/slidingwindow.go (about)

     1  package ss2022
     2  
     3  import "math/bits"
     4  
// swfBlockBits is the number of bits in one ring block. Ring blocks are of
// type uint, so this is 32 or 64 depending on the target platform.
const swfBlockBits = bits.UintSize
     6  
// SlidingWindowFilter maintains a sliding window of uint64 counters.
//
// The filter remembers the highest counter added so far and, for counters
// less than size behind it, whether each one has already been added.
//
// Fields:
//   - size is the window size in counters, as passed to NewSlidingWindowFilter.
//   - last is the highest counter added so far.
//   - ring is a bitmap with one bit per counter, stored in blocks of
//     swfBlockBits bits and indexed by counter modulo the ring's bit length.
//   - ringBlockIndexMask is len(ring)-1. The ring length is a power of two,
//     so masking a block index with it wraps the index into the ring.
type SlidingWindowFilter struct {
	size               uint64
	last               uint64
	ring               []uint
	ringBlockIndexMask uint64
}
    14  
    15  // NewSlidingWindowFilter returns a new sliding window filter with the given size.
    16  func NewSlidingWindowFilter(size uint64) *SlidingWindowFilter {
    17  	ringBits := uint64(1 << bits.Len64(size+swfBlockBits-1))
    18  	ringBlocks := ringBits / swfBlockBits
    19  	return &SlidingWindowFilter{
    20  		size:               size,
    21  		ring:               make([]uint, ringBlocks),
    22  		ringBlockIndexMask: ringBlocks - 1,
    23  	}
    24  }
    25  
    26  // Size returns the size of the sliding window.
    27  func (f *SlidingWindowFilter) Size() uint64 {
    28  	return f.size
    29  }
    30  
    31  // Reset resets the filter to its initial state.
    32  func (f *SlidingWindowFilter) Reset() {
    33  	f.last = 0
    34  	f.ring[0] = 0
    35  }
    36  
    37  func (*SlidingWindowFilter) unmaskedBlockIndex(counter uint64) uint64 {
    38  	return counter / swfBlockBits
    39  }
    40  
    41  func (f *SlidingWindowFilter) blockIndex(counter uint64) uint64 {
    42  	return counter / swfBlockBits & f.ringBlockIndexMask
    43  }
    44  
    45  func (*SlidingWindowFilter) bitIndex(counter uint64) uint64 {
    46  	return counter % swfBlockBits
    47  }
    48  
    49  // IsOk checks whether counter can be accepted by the sliding window filter.
    50  func (f *SlidingWindowFilter) IsOk(counter uint64) bool {
    51  	// Accept counter if it is ahead of window.
    52  	if counter > f.last {
    53  		return true
    54  	}
    55  
    56  	// Reject counter if it is behind window.
    57  	if f.last-counter >= f.size {
    58  		return false
    59  	}
    60  
    61  	// Within window, accept if not seen by window.
    62  	return f.ring[f.blockIndex(counter)]&(1<<f.bitIndex(counter)) == 0
    63  }
    64  
    65  // MustAdd adds counter to the sliding window without checking if the counter is valid.
    66  // Call IsOk beforehand to make sure the counter is valid.
    67  func (f *SlidingWindowFilter) MustAdd(counter uint64) {
    68  	blockIndex := f.unmaskedBlockIndex(counter)
    69  
    70  	// When counter is ahead of window, clear blocks ahead.
    71  	if counter > f.last {
    72  		lastBlockIndex := f.unmaskedBlockIndex(f.last)
    73  		clearBlockCount := min(int(blockIndex-lastBlockIndex), len(f.ring))
    74  
    75  		// Clear blocks ahead.
    76  		for range clearBlockCount {
    77  			lastBlockIndex = (lastBlockIndex + 1) & f.ringBlockIndexMask
    78  			f.ring[lastBlockIndex] = 0
    79  		}
    80  
    81  		f.last = counter
    82  	}
    83  
    84  	blockIndex &= f.ringBlockIndexMask
    85  	f.ring[blockIndex] |= 1 << f.bitIndex(counter)
    86  }
    87  
    88  // Add attempts to add counter to the sliding window and returns
    89  // whether the counter is successfully added to the sliding window.
    90  func (f *SlidingWindowFilter) Add(counter uint64) bool {
    91  	unmaskedBlockIndex := f.unmaskedBlockIndex(counter)
    92  	blockIndex := unmaskedBlockIndex & f.ringBlockIndexMask
    93  	bitIndex := f.bitIndex(counter)
    94  
    95  	switch {
    96  	case counter > f.last: // Ahead of window, clear blocks ahead.
    97  		lastBlockIndex := f.unmaskedBlockIndex(f.last)
    98  		clearBlockCount := min(int(unmaskedBlockIndex-lastBlockIndex), len(f.ring))
    99  
   100  		// Clear blocks ahead.
   101  		for range clearBlockCount {
   102  			lastBlockIndex = (lastBlockIndex + 1) & f.ringBlockIndexMask
   103  			f.ring[lastBlockIndex] = 0
   104  		}
   105  
   106  		f.last = counter
   107  
   108  	case f.last-counter >= f.size: // Behind window.
   109  		return false
   110  
   111  	case f.ring[blockIndex]&(1<<bitIndex) != 0: // Within window, already seen.
   112  		return false
   113  	}
   114  
   115  	f.ring[blockIndex] |= 1 << bitIndex
   116  	return true
   117  }