//                           _       _
// __      _____  __ _(_)__ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package cache

import (
	"context"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	enterrors "github.com/weaviate/weaviate/entities/errors"

	"github.com/sirupsen/logrus"
	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
)

// shardedLockCache is a grow-only, in-memory vector cache indexed by node id.
// Per-id slots are guarded by sharded read/write locks (shardedLocks), so
// concurrent operations on ids that hash to different shards do not contend.
// On a cache miss, the vector is loaded through vectorForID and stored.
// The count of cached vectors and maxSize are maintained with atomics.
type shardedLockCache[T float32 | byte | uint64] struct {
	shardedLocks *common.ShardedRWLocks // per-id (sharded) RW locks guarding cache slots
	cache        [][]T                  // slot per id; nil slot == not cached
	vectorForID  common.VectorForID[T]  // fallback loader used on cache miss
	// normalizeOnRead records whether vectors are normalized when loaded
	// (only wired up for the float32 variant; see NewShardedFloat32LockCache).
	normalizeOnRead bool
	maxSize         int64     // capacity threshold; read/written atomically
	count           int64     // number of non-nil cached vectors; read/written atomically
	cancel          chan bool // signals the deletion-watcher goroutine to stop
	logger          logrus.FieldLogger
	// deletionInterval is how often the watcher checks whether the cache is
	// full; 0 disables the watcher entirely (see watchForDeletion / Drop).
	deletionInterval time.Duration

	// The maintenanceLock makes sure that only one maintenance operation, such
	// as growing the cache or clearing the cache happens at the same time.
	maintenanceLock sync.RWMutex
}

const (
	// InitialSize is the number of slots allocated for a fresh cache.
	InitialSize = 1000
	// MinimumIndexGrowthDelta is the headroom added beyond the requested
	// node id when the cache grows (see Grow).
	MinimumIndexGrowthDelta = 2000
	// indexGrowthRate is not referenced in this file — presumably used by a
	// sibling cache implementation; verify before removing.
	indexGrowthRate = 1.25
)

// NewShardedFloat32LockCache builds a float32 vector cache. The supplied
// vecForID loader is wrapped so that, when normalizeOnRead is true, vectors
// fetched on a cache miss are normalized before being stored and returned.
// A background deletion watcher is started unless deletionInterval is 0.
func NewShardedFloat32LockCache(vecForID common.VectorForID[float32], maxSize int,
	logger logrus.FieldLogger, normalizeOnRead bool, deletionInterval time.Duration,
) Cache[float32] {
	vc := &shardedLockCache[float32]{
		vectorForID: func(ctx context.Context, id uint64) ([]float32, error) {
			vec, err := vecForID(ctx, id)
			if err != nil {
				return nil, err
			}
			if normalizeOnRead {
				vec = distancer.Normalize(vec)
			}
			return vec, nil
		},
		cache:            make([][]float32, InitialSize),
		normalizeOnRead:  normalizeOnRead,
		count:            0,
		maxSize:          int64(maxSize),
		cancel:           make(chan bool),
		logger:           logger,
		shardedLocks:     common.NewDefaultShardedRWLocks(),
		maintenanceLock:  sync.RWMutex{},
		deletionInterval: deletionInterval,
	}

	vc.watchForDeletion()
	return vc
}

// NewShardedByteLockCache builds a byte vector cache (no normalization).
// A background deletion watcher is started unless deletionInterval is 0.
func NewShardedByteLockCache(vecForID common.VectorForID[byte], maxSize int,
	logger logrus.FieldLogger, deletionInterval time.Duration,
) Cache[byte] {
	vc := &shardedLockCache[byte]{
		vectorForID:      vecForID,
		cache:            make([][]byte, InitialSize),
		normalizeOnRead:  false,
		count:            0,
		maxSize:          int64(maxSize),
		cancel:           make(chan bool),
		logger:           logger,
		shardedLocks:     common.NewDefaultShardedRWLocks(),
		maintenanceLock:  sync.RWMutex{},
		deletionInterval: deletionInterval,
	}

	vc.watchForDeletion()
	return vc
}

// NewShardedUInt64LockCache builds a uint64 vector cache (no normalization).
// A background deletion watcher is started unless deletionInterval is 0.
func NewShardedUInt64LockCache(vecForID common.VectorForID[uint64], maxSize int,
	logger logrus.FieldLogger, deletionInterval time.Duration,
) Cache[uint64] {
	vc := &shardedLockCache[uint64]{
		vectorForID:      vecForID,
		cache:            make([][]uint64, InitialSize),
		normalizeOnRead:  false,
		count:            0,
		maxSize:          int64(maxSize),
		cancel:           make(chan bool),
		logger:           logger,
		shardedLocks:     common.NewDefaultShardedRWLocks(),
		maintenanceLock:  sync.RWMutex{},
		deletionInterval: deletionInterval,
	}

	vc.watchForDeletion()
	return vc
}

// All returns the underlying cache slice directly, without taking any lock.
// NOTE(review): callers get a live view that may be swapped out by Grow or
// mutated concurrently — presumably only used where that is acceptable;
// verify at call sites.
func (s *shardedLockCache[T]) All() [][]T {
	return s.cache
}

// Get returns the vector for id, loading it via vectorForID on a miss.
// The slot read is protected by the id's shard RLock.
// NOTE(review): unlike Delete, there is no bounds check here — an id beyond
// the current cache length panics; callers appear responsible for calling
// Grow first. Confirm against callers.
func (s *shardedLockCache[T]) Get(ctx context.Context, id uint64) ([]T, error) {
	s.shardedLocks.RLock(id)
	vec := s.cache[id]
	s.shardedLocks.RUnlock(id)

	if vec != nil {
		return vec, nil
	}

	return s.handleCacheMiss(ctx, id)
}

// Delete removes id's vector from the cache and decrements the count.
// Out-of-range or already-empty slots are a no-op.
func (s *shardedLockCache[T]) Delete(ctx context.Context, id uint64) {
	s.shardedLocks.Lock(id)
	defer s.shardedLocks.Unlock(id)

	if int(id) >= len(s.cache) || s.cache[id] == nil {
		return
	}

	s.cache[id] = nil
	atomic.AddInt64(&s.count, -1)
}

// handleCacheMiss loads the vector from the underlying source and stores it
// in the cache under the id's shard lock.
// NOTE(review): count is incremented unconditionally, even if a concurrent
// Get already populated the same slot — the counter can drift high under
// racing misses for one id; it is only used as a fullness heuristic
// (see replaceIfFull).
func (s *shardedLockCache[T]) handleCacheMiss(ctx context.Context, id uint64) ([]T, error) {
	vec, err := s.vectorForID(ctx, id)
	if err != nil {
		return nil, err
	}

	atomic.AddInt64(&s.count, 1)
	s.shardedLocks.Lock(id)
	s.cache[id] = vec
	s.shardedLocks.Unlock(id)

	return vec, nil
}

// MultiGet fetches several ids, returning results and per-id errors in
// positions parallel to ids. Misses are resolved one at a time through
// handleCacheMiss.
func (s *shardedLockCache[T]) MultiGet(ctx context.Context, ids []uint64) ([][]T, []error) {
	out := make([][]T, len(ids))
	errs := make([]error, len(ids))

	for i, id := range ids {
		s.shardedLocks.RLock(id)
		vec := s.cache[id]
		s.shardedLocks.RUnlock(id)

		if vec == nil {
			vecFromDisk, err := s.handleCacheMiss(ctx, id)
			errs[i] = err
			vec = vecFromDisk
		}

		out[i] = vec
	}

	return out, errs
}

// prefetchFunc issues a CPU cache prefetch hint for the given address.
// The default is a no-op; an amd64-specific implementation overrides it.
var prefetchFunc func(in uintptr) = func(in uintptr) {
	// do nothing on default arch
	// this function will be overridden for amd64
}

// Prefetch hints the CPU to pull id's slot header into cache ahead of use.
func (s *shardedLockCache[T]) Prefetch(id uint64) {
	s.shardedLocks.RLock(id)
	defer s.shardedLocks.RUnlock(id)

	prefetchFunc(uintptr(unsafe.Pointer(&s.cache[id])))
}

// Preload stores a vector for id without consulting the underlying source,
// e.g. when the caller already has the vector at insert time.
// NOTE(review): count is incremented even if the slot was already occupied —
// callers appear expected to preload each id at most once; confirm.
func (s *shardedLockCache[T]) Preload(id uint64, vec []T) {
	s.shardedLocks.Lock(id)
	defer s.shardedLocks.Unlock(id)

	atomic.AddInt64(&s.count, 1)
	s.cache[id] = vec
}

// Grow ensures the cache has a slot for node, using double-checked locking:
// a cheap RLock check first, then a recheck under the exclusive maintenance
// lock before reallocating. All shard locks are held during the swap so no
// reader observes the old backing slice mid-copy.
func (s *shardedLockCache[T]) Grow(node uint64) {
	s.maintenanceLock.RLock()
	if node < uint64(len(s.cache)) {
		s.maintenanceLock.RUnlock()
		return
	}
	s.maintenanceLock.RUnlock()

	s.maintenanceLock.Lock()
	defer s.maintenanceLock.Unlock()

	// make sure cache still needs growing
	// (it could have grown while waiting for maintenance lock)
	if node < uint64(len(s.cache)) {
		return
	}

	s.shardedLocks.LockAll()
	defer s.shardedLocks.UnlockAll()

	newSize := node + MinimumIndexGrowthDelta
	newCache := make([][]T, newSize)
	copy(newCache, s.cache)
	s.cache = newCache
}

// Len returns the current number of slots (capacity in ids), not the number
// of cached vectors — see CountVectors for the latter.
func (s *shardedLockCache[T]) Len() int32 {
	s.maintenanceLock.RLock()
	defer s.maintenanceLock.RUnlock()

	return int32(len(s.cache))
}

// CountVectors returns the approximate number of vectors currently cached.
func (s *shardedLockCache[T]) CountVectors() int64 {
	return atomic.LoadInt64(&s.count)
}

// Drop empties the cache and, if a deletion watcher was started, signals it
// to stop. The send on cancel blocks until the watcher receives it.
func (s *shardedLockCache[T]) Drop() {
	s.deleteAllVectors()
	if s.deletionInterval != 0 {
		s.cancel <- true
	}
}

// deleteAllVectors nils every slot and resets the count to zero, holding all
// shard locks for the duration. The backing slice keeps its current length.
func (s *shardedLockCache[T]) deleteAllVectors() {
	s.shardedLocks.LockAll()
	defer s.shardedLocks.UnlockAll()

	s.logger.WithField("action", "hnsw_delete_vector_cache").
		Debug("deleting full vector cache")
	for i := range s.cache {
		s.cache[i] = nil
	}

	atomic.StoreInt64(&s.count, 0)
}

// watchForDeletion starts a background goroutine that periodically clears
// the cache once it is full (see replaceIfFull). Disabled entirely when
// deletionInterval is 0; stopped via the cancel channel (see Drop).
func (s *shardedLockCache[T]) watchForDeletion() {
	if s.deletionInterval != 0 {
		f := func() {
			t := time.NewTicker(s.deletionInterval)
			defer t.Stop()
			for {
				select {
				case <-s.cancel:
					return
				case <-t.C:
					s.replaceIfFull()
				}
			}
		}
		enterrors.GoWrapper(f, s.logger)
	}
}

// replaceIfFull wipes the whole cache when the vector count has reached
// maxSize — a blunt eviction policy: everything is dropped and repopulated
// on demand via cache misses.
func (s *shardedLockCache[T]) replaceIfFull() {
	if atomic.LoadInt64(&s.count) >= atomic.LoadInt64(&s.maxSize) {
		s.deleteAllVectors()
	}
}

// UpdateMaxSize atomically replaces the fullness threshold.
func (s *shardedLockCache[T]) UpdateMaxSize(size int64) {
	atomic.StoreInt64(&s.maxSize, size)
}

// CopyMaxSize atomically reads the current fullness threshold.
func (s *shardedLockCache[T]) CopyMaxSize() int64 {
	sizeCopy := atomic.LoadInt64(&s.maxSize)
	return sizeCopy
}

// noopCache can be helpful in debugging situations, where we want to
// explicitly pass through each vectorForID call to the underlying vectorForID
// function without caching in between.
type noopCache struct {
	vectorForID common.VectorForID[float32]
}

// NewNoopCache returns a pass-through cache that never stores anything.
// NOTE(review): maxSize and logger are accepted only for signature parity
// with the real constructors and are deliberately ignored.
func NewNoopCache(vecForID common.VectorForID[float32], maxSize int,
	logger logrus.FieldLogger,
) *noopCache {
	return &noopCache{vectorForID: vecForID}
}