github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/bloom/bloom.go (about)

     1  // Package bloom provides a simple bloom filter implementation.
     2  package bloom
     3  
import (
	"fmt"
	"hash"
	"io"
	"math"
	"math/bits"
	_ "unsafe"

	"github.com/fluhus/gostuff/bnry"
	"github.com/spaolacci/murmur3"
)
    14  
// fastrand has no body in this package: the go:linkname directive below binds
// it to the runtime's fastrand function, which is why "unsafe" is
// blank-imported above. New uses its result as the initial hash seed.
// NOTE(review): this relies on a runtime-internal symbol; confirm it is still
// accepted by the Go toolchain version this module targets.
//
//go:linkname fastrand runtime.fastrand
func fastrand() uint32
    17  
// Filter is a single bloom filter.
//
// Use New or NewOptimal to create a ready-to-use filter, or Decode to load
// one that was previously written with Encode.
type Filter struct {
	b    []byte        // Filter data: the bit array, packed 8 bits per byte.
	h    []hash.Hash64 // Hash functions, one per hash (k in total); rebuilt by SetSeed.
	seed uint32        // Seed from which the hash functions in h were derived.
}
    24  
    25  // NHash returns the number of hash functions this filter uses.
    26  func (f *Filter) NHash() int {
    27  	return len(f.h)
    28  }
    29  
    30  // NBits returns the number of bits this filter uses.
    31  func (f *Filter) NBits() int {
    32  	return 8 * len(f.b)
    33  }
    34  
    35  // NElements returns an approximation of the number of elements added to the
    36  // filter.
    37  func (f *Filter) NElements() int {
    38  	m := float64(f.NBits())
    39  	k := float64(f.NHash())
    40  	x := 0.0 // Number of bits that are 1.
    41  	for _, bt := range f.b {
    42  		for bt > 0 {
    43  			if bt&1 > 0 {
    44  				x++
    45  			}
    46  			bt >>= 1
    47  		}
    48  	}
    49  	return int(math.Round(-m / k * math.Log(1-x/m)))
    50  }
    51  
    52  // Has checks if all k hash values of v were encountered.
    53  // Makes at most k hash calculations.
    54  func (f *Filter) Has(v []byte) bool {
    55  	for i := range f.h {
    56  		f.h[i].Reset()
    57  		f.h[i].Write(v)
    58  		hash := int(f.h[i].Sum64() % uint64(len(f.b)*8))
    59  		if getBit(f.b, hash) == 0 {
    60  			return false
    61  		}
    62  	}
    63  	return true
    64  }
    65  
    66  // Add adds v to the filter, and returns the value of Has(v) before adding.
    67  // After calling Add, Has(v) will always be true. Makes k calls to hash.
    68  func (f *Filter) Add(v []byte) bool {
    69  	has := true
    70  	for i := range f.h {
    71  		f.h[i].Reset()
    72  		f.h[i].Write(v)
    73  		hash := int(f.h[i].Sum64() % uint64(len(f.b)*8))
    74  		if getBit(f.b, hash) == 0 {
    75  			has = false
    76  			setBit(f.b, hash, 1)
    77  		}
    78  	}
    79  	return has
    80  }
    81  
    82  // AddFilter merges other into f. After merging, f is equivalent to have been added
    83  // all the elements of other.
    84  func (f *Filter) AddFilter(other *Filter) {
    85  	// Make sure the two filters are compatible.
    86  	if f.NBits() != other.NBits() {
    87  		panic(fmt.Sprintf("mismatching number of bits: this has %v, other has %v",
    88  			f.NBits(), other.NBits()))
    89  	}
    90  	if f.NHash() != other.NHash() {
    91  		panic(fmt.Sprintf("mismatching number of hashes: this has %v, other has %v",
    92  			f.NHash(), other.NHash()))
    93  	}
    94  	if f.Seed() != other.Seed() {
    95  		panic(fmt.Sprintf("mismatching seeds: this has %v, other has %v",
    96  			f.Seed(), other.Seed()))
    97  	}
    98  
    99  	// Merge.
   100  	for i := range f.b {
   101  		f.b[i] |= other.b[i]
   102  	}
   103  }
   104  
   105  // Seed returns the hash seed of this filter.
   106  // A new filter starts with a random seed.
   107  func (f *Filter) Seed() uint32 {
   108  	return f.seed
   109  }
   110  
   111  // SetSeed sets the hash seed of this filter.
   112  // The filter must be empty.
   113  func (f *Filter) SetSeed(seed uint32) {
   114  	for _, b := range f.b {
   115  		if b != 0 {
   116  			panic("cannot change seed after elements were added")
   117  		}
   118  	}
   119  	f.seed = seed
   120  	h := murmur3.New32WithSeed(seed)
   121  	for i := range f.h {
   122  		h.Write([]byte{1})
   123  		f.h[i] = murmur3.New64WithSeed(h.Sum32())
   124  	}
   125  }
   126  
   127  // Encode writes this filter to the stream. Can be reproduced later with Decode.
   128  func (f *Filter) Encode(w io.Writer) error {
   129  	// Order is k, seed, bytes.
   130  	return bnry.Write(w, uint64(len(f.h)), f.seed, f.b)
   131  }
   132  
   133  // Decode reads an encoded filter from the stream and sets this filter's state
   134  // to match it. Destroys the previously existing state of this filter.
   135  func (f *Filter) Decode(r io.ByteReader) error {
   136  	var k uint64
   137  	var seed uint32
   138  	var b []byte
   139  	if err := bnry.Read(r, &k, &seed, &b); err != nil {
   140  		return err
   141  	}
   142  	f.h = make([]hash.Hash64, k)
   143  	f.SetSeed(uint32(seed))
   144  	f.b = b
   145  
   146  	return nil
   147  }
   148  
   149  // New creates a new bloom filter with the given parameters. Number of
   150  // bits is rounded up to the nearest multiple of 8.
   151  //
   152  // See NewOptimal for an alternative way to decide on the parameters.
   153  func New(bits int, k int) *Filter {
   154  	if bits < 1 {
   155  		panic(fmt.Sprintf("number of bits should be at least 1, got %v", bits))
   156  	}
   157  	if k < 1 {
   158  		panic(fmt.Sprintf("k should be at least 1, got %v", k))
   159  	}
   160  
   161  	result := &Filter{
   162  		b: make([]byte, ((bits-1)/8)+1),
   163  		h: make([]hash.Hash64, k),
   164  	}
   165  	result.SetSeed(fastrand())
   166  	return result
   167  }
   168  
   169  // NewOptimal creates a new bloom filter, with parameters optimal for the
   170  // expected number of elements (n) and the required false-positive rate (p).
   171  //
   172  // The calculation is taken from:
   173  // https://en.wikipedia.org/wiki/Bloom_filter#Optimal_number_of_hash_functions
   174  func NewOptimal(n int, p float64) *Filter {
   175  	m := math.Round(-float64(n) * math.Log(p) / math.Ln2 / math.Ln2)
   176  	k := math.Round(-math.Log2(p))
   177  	return New(int(m), int(k))
   178  }
   179  
   180  // Returns the value of the n'th bit in a byte slice.
   181  func getBit(b []byte, n int) int {
   182  	return int(b[n/8] >> (n % 8) & 1)
   183  }
   184  
   185  // Sets the value of the n'th bit in a byte slice.
   186  func setBit(b []byte, n, v int) {
   187  	if v == 0 {
   188  		b[n/8] &= ^(byte(1) << (n % 8))
   189  	} else if v == 1 {
   190  		b[n/8] |= byte(1) << (n % 8)
   191  	} else {
   192  		panic(fmt.Sprintf("bad value: %v, expected 0 or 1", v))
   193  	}
   194  }