github.com/MetalBlockchain/metalgo@v1.11.9/utils/bloom/filter.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package bloom 5 6 import ( 7 "crypto/rand" 8 "encoding/binary" 9 "errors" 10 "fmt" 11 "math/bits" 12 "sync" 13 ) 14 15 const ( 16 minHashes = 1 17 maxHashes = 16 // Supports a false positive probability of 2^-16 when using optimal size values 18 minEntries = 1 19 20 bitsPerByte = 8 21 bytesPerUint64 = 8 22 hashRotation = 17 23 ) 24 25 var ( 26 errInvalidNumHashes = errors.New("invalid num hashes") 27 errTooFewHashes = errors.New("too few hashes") 28 errTooManyHashes = errors.New("too many hashes") 29 errTooFewEntries = errors.New("too few entries") 30 ) 31 32 type Filter struct { 33 // numBits is always equal to [bitsPerByte * len(entries)] 34 numBits uint64 35 36 lock sync.RWMutex 37 hashSeeds []uint64 38 entries []byte 39 count int 40 } 41 42 // New creates a new Filter with the specified number of hashes and bytes for 43 // entries. The returned bloom filter is safe for concurrent usage. 44 func New(numHashes, numEntries int) (*Filter, error) { 45 if numEntries < minEntries { 46 return nil, errTooFewEntries 47 } 48 49 hashSeeds, err := newHashSeeds(numHashes) 50 if err != nil { 51 return nil, err 52 } 53 54 return &Filter{ 55 numBits: uint64(numEntries * bitsPerByte), 56 hashSeeds: hashSeeds, 57 entries: make([]byte, numEntries), 58 count: 0, 59 }, nil 60 } 61 62 func (f *Filter) Add(hash uint64) { 63 f.lock.Lock() 64 defer f.lock.Unlock() 65 66 _ = 1 % f.numBits // hint to the compiler that numBits is not 0 67 for _, seed := range f.hashSeeds { 68 hash = bits.RotateLeft64(hash, hashRotation) ^ seed 69 index := hash % f.numBits 70 byteIndex := index / bitsPerByte 71 bitIndex := index % bitsPerByte 72 f.entries[byteIndex] |= 1 << bitIndex 73 } 74 f.count++ 75 } 76 77 // Count returns the number of elements that have been added to the bloom 78 // filter. 79 func (f *Filter) Count() int { 80 f.lock.RLock() 81 defer f.lock.RUnlock() 82 83 return f.count 84 } 85 86 func (f *Filter) Contains(hash uint64) bool { 87 f.lock.RLock() 88 defer f.lock.RUnlock() 89 90 return contains(f.hashSeeds, f.entries, hash) 91 } 92 93 func (f *Filter) Marshal() []byte { 94 f.lock.RLock() 95 defer f.lock.RUnlock() 96 97 return marshal(f.hashSeeds, f.entries) 98 } 99 100 func newHashSeeds(count int) ([]uint64, error) { 101 switch { 102 case count < minHashes: 103 return nil, fmt.Errorf("%w: %d < %d", errTooFewHashes, count, minHashes) 104 case count > maxHashes: 105 return nil, fmt.Errorf("%w: %d > %d", errTooManyHashes, count, maxHashes) 106 } 107 108 bytes := make([]byte, count*bytesPerUint64) 109 if _, err := rand.Reader.Read(bytes); err != nil { 110 return nil, err 111 } 112 113 seeds := make([]uint64, count) 114 for i := range seeds { 115 seeds[i] = binary.BigEndian.Uint64(bytes[i*bytesPerUint64:]) 116 } 117 return seeds, nil 118 } 119 120 func contains(hashSeeds []uint64, entries []byte, hash uint64) bool { 121 var ( 122 numBits = bitsPerByte * uint64(len(entries)) 123 _ = 1 % numBits // hint to the compiler that numBits is not 0 124 accumulator byte = 1 125 ) 126 for seedIndex := 0; seedIndex < len(hashSeeds) && accumulator != 0; seedIndex++ { 127 hash = bits.RotateLeft64(hash, hashRotation) ^ hashSeeds[seedIndex] 128 index := hash % numBits 129 byteIndex := index / bitsPerByte 130 bitIndex := index % bitsPerByte 131 accumulator &= entries[byteIndex] >> bitIndex 132 } 133 return accumulator != 0 134 } 135 136 func marshal(hashSeeds []uint64, entries []byte) []byte { 137 numHashes := len(hashSeeds) 138 entriesOffset := 1 + numHashes*bytesPerUint64 139 140 bytes := make([]byte, entriesOffset+len(entries)) 141 bytes[0] = byte(numHashes) 142 for i, seed := range hashSeeds { 143 binary.BigEndian.PutUint64(bytes[1+i*bytesPerUint64:], seed) 144 } 145 copy(bytes[entriesOffset:], entries) 146 return bytes 147 }