github.com/MetalBlockchain/metalgo@v1.11.9/utils/bloom/filter.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package bloom
     5  
     6  import (
     7  	"crypto/rand"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"math/bits"
    12  	"sync"
    13  )
    14  
    15  const (
    16  	minHashes  = 1
    17  	maxHashes  = 16 // Supports a false positive probability of 2^-16 when using optimal size values
    18  	minEntries = 1
    19  
    20  	bitsPerByte    = 8
    21  	bytesPerUint64 = 8
    22  	hashRotation   = 17
    23  )
    24  
    25  var (
    26  	errInvalidNumHashes = errors.New("invalid num hashes")
    27  	errTooFewHashes     = errors.New("too few hashes")
    28  	errTooManyHashes    = errors.New("too many hashes")
    29  	errTooFewEntries    = errors.New("too few entries")
    30  )
    31  
    32  type Filter struct {
    33  	// numBits is always equal to [bitsPerByte * len(entries)]
    34  	numBits uint64
    35  
    36  	lock      sync.RWMutex
    37  	hashSeeds []uint64
    38  	entries   []byte
    39  	count     int
    40  }
    41  
    42  // New creates a new Filter with the specified number of hashes and bytes for
    43  // entries. The returned bloom filter is safe for concurrent usage.
    44  func New(numHashes, numEntries int) (*Filter, error) {
    45  	if numEntries < minEntries {
    46  		return nil, errTooFewEntries
    47  	}
    48  
    49  	hashSeeds, err := newHashSeeds(numHashes)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	return &Filter{
    55  		numBits:   uint64(numEntries * bitsPerByte),
    56  		hashSeeds: hashSeeds,
    57  		entries:   make([]byte, numEntries),
    58  		count:     0,
    59  	}, nil
    60  }
    61  
    62  func (f *Filter) Add(hash uint64) {
    63  	f.lock.Lock()
    64  	defer f.lock.Unlock()
    65  
    66  	_ = 1 % f.numBits // hint to the compiler that numBits is not 0
    67  	for _, seed := range f.hashSeeds {
    68  		hash = bits.RotateLeft64(hash, hashRotation) ^ seed
    69  		index := hash % f.numBits
    70  		byteIndex := index / bitsPerByte
    71  		bitIndex := index % bitsPerByte
    72  		f.entries[byteIndex] |= 1 << bitIndex
    73  	}
    74  	f.count++
    75  }
    76  
    77  // Count returns the number of elements that have been added to the bloom
    78  // filter.
    79  func (f *Filter) Count() int {
    80  	f.lock.RLock()
    81  	defer f.lock.RUnlock()
    82  
    83  	return f.count
    84  }
    85  
    86  func (f *Filter) Contains(hash uint64) bool {
    87  	f.lock.RLock()
    88  	defer f.lock.RUnlock()
    89  
    90  	return contains(f.hashSeeds, f.entries, hash)
    91  }
    92  
    93  func (f *Filter) Marshal() []byte {
    94  	f.lock.RLock()
    95  	defer f.lock.RUnlock()
    96  
    97  	return marshal(f.hashSeeds, f.entries)
    98  }
    99  
   100  func newHashSeeds(count int) ([]uint64, error) {
   101  	switch {
   102  	case count < minHashes:
   103  		return nil, fmt.Errorf("%w: %d < %d", errTooFewHashes, count, minHashes)
   104  	case count > maxHashes:
   105  		return nil, fmt.Errorf("%w: %d > %d", errTooManyHashes, count, maxHashes)
   106  	}
   107  
   108  	bytes := make([]byte, count*bytesPerUint64)
   109  	if _, err := rand.Reader.Read(bytes); err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	seeds := make([]uint64, count)
   114  	for i := range seeds {
   115  		seeds[i] = binary.BigEndian.Uint64(bytes[i*bytesPerUint64:])
   116  	}
   117  	return seeds, nil
   118  }
   119  
   120  func contains(hashSeeds []uint64, entries []byte, hash uint64) bool {
   121  	var (
   122  		numBits          = bitsPerByte * uint64(len(entries))
   123  		_                = 1 % numBits // hint to the compiler that numBits is not 0
   124  		accumulator byte = 1
   125  	)
   126  	for seedIndex := 0; seedIndex < len(hashSeeds) && accumulator != 0; seedIndex++ {
   127  		hash = bits.RotateLeft64(hash, hashRotation) ^ hashSeeds[seedIndex]
   128  		index := hash % numBits
   129  		byteIndex := index / bitsPerByte
   130  		bitIndex := index % bitsPerByte
   131  		accumulator &= entries[byteIndex] >> bitIndex
   132  	}
   133  	return accumulator != 0
   134  }
   135  
   136  func marshal(hashSeeds []uint64, entries []byte) []byte {
   137  	numHashes := len(hashSeeds)
   138  	entriesOffset := 1 + numHashes*bytesPerUint64
   139  
   140  	bytes := make([]byte, entriesOffset+len(entries))
   141  	bytes[0] = byte(numHashes)
   142  	for i, seed := range hashSeeds {
   143  		binary.BigEndian.PutUint64(bytes[1+i*bytesPerUint64:], seed)
   144  	}
   145  	copy(bytes[entriesOffset:], entries)
   146  	return bytes
   147  }