github.com/fiatjaf/generic-ristretto@v0.0.1/sketch.go (about)

     1  /*
     2   * Copyright 2019 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // This package includes multiple probabilistic data structures needed for
    18  // admission/eviction metadata. Most are Counting Bloom Filter variations, but
    19  // a caching-specific feature that is also required is a "freshness" mechanism,
    20  // which basically serves as a "lifetime" process. This freshness mechanism
    21  // was described in the original TinyLFU paper [1], but other mechanisms may
    22  // be better suited for certain data distributions.
    23  //
    24  // [1]: https://arxiv.org/abs/1512.00727
    25  package ristretto
    26  
    27  import (
    28  	"fmt"
    29  	"math/rand"
    30  	"time"
    31  )
    32  
    33  // cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily
    34  // based on Damian Gryski's CM4 [1].
    35  //
    36  // [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go
    37  type cmSketch struct {
    38  	rows [cmDepth]cmRow
    39  	seed [cmDepth]uint64
    40  	mask uint64
    41  }
    42  
    43  const (
    44  	// cmDepth is the number of counter copies to store (think of it as rows).
    45  	cmDepth = 4
    46  )
    47  
    48  func newCmSketch(numCounters int64) *cmSketch {
    49  	if numCounters == 0 {
    50  		panic("cmSketch: bad numCounters")
    51  	}
    52  	// Get the next power of 2 for better cache performance.
    53  	numCounters = next2Power(numCounters)
    54  	sketch := &cmSketch{mask: uint64(numCounters - 1)}
    55  	// Initialize rows of counters and seeds.
    56  	// Cryptographic precision not needed
    57  	source := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec
    58  	for i := 0; i < cmDepth; i++ {
    59  		sketch.seed[i] = source.Uint64()
    60  		sketch.rows[i] = newCmRow(numCounters)
    61  	}
    62  	return sketch
    63  }
    64  
    65  // Increment increments the count(ers) for the specified key.
    66  func (s *cmSketch) Increment(hashed uint64) {
    67  	for i := range s.rows {
    68  		s.rows[i].increment((hashed ^ s.seed[i]) & s.mask)
    69  	}
    70  }
    71  
    72  // Estimate returns the value of the specified key.
    73  func (s *cmSketch) Estimate(hashed uint64) int64 {
    74  	min := byte(255)
    75  	for i := range s.rows {
    76  		val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask)
    77  		if val < min {
    78  			min = val
    79  		}
    80  	}
    81  	return int64(min)
    82  }
    83  
    84  // Reset halves all counter values.
    85  func (s *cmSketch) Reset() {
    86  	for _, r := range s.rows {
    87  		r.reset()
    88  	}
    89  }
    90  
    91  // Clear zeroes all counters.
    92  func (s *cmSketch) Clear() {
    93  	for _, r := range s.rows {
    94  		r.clear()
    95  	}
    96  }
    97  
    98  // cmRow is a row of bytes, with each byte holding two counters.
    99  type cmRow []byte
   100  
   101  func newCmRow(numCounters int64) cmRow {
   102  	return make(cmRow, numCounters/2)
   103  }
   104  
   105  func (r cmRow) get(n uint64) byte {
   106  	return byte(r[n/2]>>((n&1)*4)) & 0x0f
   107  }
   108  
   109  func (r cmRow) increment(n uint64) {
   110  	// Index of the counter.
   111  	i := n / 2
   112  	// Shift distance (even 0, odd 4).
   113  	s := (n & 1) * 4
   114  	// Counter value.
   115  	v := (r[i] >> s) & 0x0f
   116  	// Only increment if not max value (overflow wrap is bad for LFU).
   117  	if v < 15 {
   118  		r[i] += 1 << s
   119  	}
   120  }
   121  
   122  func (r cmRow) reset() {
   123  	// Halve each counter.
   124  	for i := range r {
   125  		r[i] = (r[i] >> 1) & 0x77
   126  	}
   127  }
   128  
   129  func (r cmRow) clear() {
   130  	// zero each counter.
   131  	for i := range r {
   132  		r[i] = 0
   133  	}
   134  }
   135  
   136  func (r cmRow) string() string {
   137  	s := ""
   138  	for i := uint64(0); i < uint64(len(r)*2); i++ {
   139  		s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f)
   140  	}
   141  	s = s[:len(s)-1]
   142  	return s
   143  }
   144  
   145  // next2Power rounds x up to the next power of 2, if it's not already one.
   146  func next2Power(x int64) int64 {
   147  	x--
   148  	x |= x >> 1
   149  	x |= x >> 2
   150  	x |= x >> 4
   151  	x |= x >> 8
   152  	x |= x >> 16
   153  	x |= x >> 32
   154  	x++
   155  	return x
   156  }