github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/insert.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hnsw 13 14 import ( 15 "context" 16 "fmt" 17 "math" 18 "sync/atomic" 19 "time" 20 21 "github.com/pkg/errors" 22 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 23 "github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" 24 ) 25 26 func (h *hnsw) ValidateBeforeInsert(vector []float32) error { 27 dims := int(atomic.LoadInt32(&h.dims)) 28 29 // no vectors exist 30 if dims == 0 { 31 return nil 32 } 33 34 // check if vector length is the same as existing nodes 35 if dims != len(vector) { 36 return fmt.Errorf("new node has a vector with length %v. "+ 37 "Existing nodes have vectors with length %v", len(vector), dims) 38 } 39 40 return nil 41 } 42 43 func (h *hnsw) AddBatch(ctx context.Context, ids []uint64, vectors [][]float32) error { 44 if err := ctx.Err(); err != nil { 45 return err 46 } 47 if len(ids) != len(vectors) { 48 return errors.Errorf("ids and vectors sizes does not match") 49 } 50 if len(ids) == 0 { 51 return errors.Errorf("insertBatch called with empty lists") 52 } 53 h.trackDimensionsOnce.Do(func() { 54 atomic.StoreInt32(&h.dims, int32(len(vectors[0]))) 55 }) 56 levels := make([]int, len(ids)) 57 maxId := uint64(0) 58 for i, id := range ids { 59 if maxId < id { 60 maxId = id 61 } 62 levels[i] = int(math.Floor(-math.Log(h.randFunc()) * h.levelNormalizer)) 63 } 64 h.RLock() 65 if maxId >= uint64(len(h.nodes)) { 66 h.RUnlock() 67 h.Lock() 68 if maxId >= uint64(len(h.nodes)) { 69 err := h.growIndexToAccomodateNode(maxId, h.logger) 70 if err != nil { 71 h.Unlock() 72 return errors.Wrapf(err, "grow HNSW index to accommodate node %d", maxId) 73 } 74 } 75 h.Unlock() 76 } else { 77 h.RUnlock() 78 } 79 80 for i := range ids { 81 if err := ctx.Err(); err != nil { 82 return err 83 } 84 85 vector := vectors[i] 86 node := &vertex{ 87 id: ids[i], 88 level: levels[i], 89 } 90 globalBefore := time.Now() 91 if len(vector) == 0 { 92 return errors.Errorf("insert called with nil-vector") 93 } 94 95 h.metrics.InsertVector() 96 97 vector = h.normalizeVec(vector) 98 err := h.addOne(vector, node) 99 if err != nil { 100 return err 101 } 102 103 h.insertMetrics.total(globalBefore) 104 } 105 return nil 106 } 107 108 func (h *hnsw) addOne(vector []float32, node *vertex) error { 109 h.compressActionLock.RLock() 110 h.deleteVsInsertLock.RLock() 111 112 before := time.Now() 113 114 defer func() { 115 h.deleteVsInsertLock.RUnlock() 116 h.compressActionLock.RUnlock() 117 h.insertMetrics.updateGlobalEntrypoint(before) 118 }() 119 120 wasFirst := false 121 var firstInsertError error 122 h.initialInsertOnce.Do(func() { 123 if h.isEmpty() { 124 wasFirst = true 125 firstInsertError = h.insertInitialElement(node, vector) 126 } 127 }) 128 if wasFirst { 129 if firstInsertError != nil { 130 return firstInsertError 131 } 132 return nil 133 } 134 135 node.markAsMaintenance() 136 137 h.RLock() 138 // initially use the "global" entrypoint which is guaranteed to be on the 139 // currently highest layer 140 entryPointID := h.entryPointID 141 // initially use the level of the entrypoint which is the highest level of 142 // the h-graph in the first iteration 143 currentMaximumLayer := h.currentMaximumLayer 144 h.RUnlock() 145 146 targetLevel := node.level 147 node.connections = make([][]uint64, targetLevel+1) 148 149 for i := targetLevel; i >= 0; i-- { 150 capacity := h.maximumConnections 151 if i == 0 { 152 capacity = h.maximumConnectionsLayerZero 153 } 154 155 node.connections[i] = make([]uint64, 0, capacity) 156 } 157 158 if err := h.commitLog.AddNode(node); err != nil { 159 return err 160 } 161 162 nodeId := node.id 163 164 h.shardedNodeLocks.Lock(nodeId) 165 h.nodes[nodeId] = node 166 h.shardedNodeLocks.Unlock(nodeId) 167 168 if h.compressed.Load() { 169 h.compressor.Preload(node.id, vector) 170 } else { 171 h.cache.Preload(node.id, vector) 172 } 173 174 h.insertMetrics.prepareAndInsertNode(before) 175 before = time.Now() 176 177 var err error 178 var distancer compressionhelpers.CompressorDistancer 179 var returnFn compressionhelpers.ReturnDistancerFn 180 if h.compressed.Load() { 181 distancer, returnFn = h.compressor.NewDistancer(vector) 182 defer returnFn() 183 } 184 entryPointID, err = h.findBestEntrypointForNode(currentMaximumLayer, targetLevel, 185 entryPointID, vector, distancer) 186 if err != nil { 187 return errors.Wrap(err, "find best entrypoint") 188 } 189 190 h.insertMetrics.findEntrypoint(before) 191 before = time.Now() 192 193 // TODO: check findAndConnectNeighbors... 194 if err := h.findAndConnectNeighbors(node, entryPointID, vector, distancer, 195 targetLevel, currentMaximumLayer, helpers.NewAllowList()); err != nil { 196 return errors.Wrap(err, "find and connect neighbors") 197 } 198 199 h.insertMetrics.findAndConnectTotal(before) 200 before = time.Now() 201 202 node.unmarkAsMaintenance() 203 204 h.RLock() 205 if targetLevel > h.currentMaximumLayer { 206 h.RUnlock() 207 h.Lock() 208 // check again to avoid changes from RUnlock to Lock again 209 if targetLevel > h.currentMaximumLayer { 210 if err := h.commitLog.SetEntryPointWithMaxLayer(nodeId, targetLevel); err != nil { 211 h.Unlock() 212 return err 213 } 214 215 h.entryPointID = nodeId 216 h.currentMaximumLayer = targetLevel 217 } 218 h.Unlock() 219 } else { 220 h.RUnlock() 221 } 222 223 return nil 224 } 225 226 func (h *hnsw) Add(id uint64, vector []float32) error { 227 return h.AddBatch(context.TODO(), []uint64{id}, [][]float32{vector}) 228 } 229 230 func (h *hnsw) insertInitialElement(node *vertex, nodeVec []float32) error { 231 h.Lock() 232 defer h.Unlock() 233 234 if err := h.commitLog.SetEntryPointWithMaxLayer(node.id, 0); err != nil { 235 return err 236 } 237 238 h.entryPointID = node.id 239 h.currentMaximumLayer = 0 240 node.connections = [][]uint64{ 241 make([]uint64, 0, h.maximumConnectionsLayerZero), 242 } 243 node.level = 0 244 if err := h.commitLog.AddNode(node); err != nil { 245 return err 246 } 247 248 err := h.growIndexToAccomodateNode(node.id, h.logger) 249 if err != nil { 250 return errors.Wrapf(err, "grow HNSW index to accommodate node %d", node.id) 251 } 252 253 h.shardedNodeLocks.Lock(node.id) 254 h.nodes[node.id] = node 255 h.shardedNodeLocks.Unlock(node.id) 256 257 if h.compressed.Load() { 258 h.compressor.Preload(node.id, nodeVec) 259 } else { 260 h.cache.Preload(node.id, nodeVec) 261 } 262 263 // go h.insertHook(node.id, 0, node.connections) 264 return nil 265 }