github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/insert.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hnsw
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"math"
    18  	"sync/atomic"
    19  	"time"
    20  
    21  	"github.com/pkg/errors"
    22  	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    24  )
    25  
    26  func (h *hnsw) ValidateBeforeInsert(vector []float32) error {
    27  	dims := int(atomic.LoadInt32(&h.dims))
    28  
    29  	// no vectors exist
    30  	if dims == 0 {
    31  		return nil
    32  	}
    33  
    34  	// check if vector length is the same as existing nodes
    35  	if dims != len(vector) {
    36  		return fmt.Errorf("new node has a vector with length %v. "+
    37  			"Existing nodes have vectors with length %v", len(vector), dims)
    38  	}
    39  
    40  	return nil
    41  }
    42  
    43  func (h *hnsw) AddBatch(ctx context.Context, ids []uint64, vectors [][]float32) error {
    44  	if err := ctx.Err(); err != nil {
    45  		return err
    46  	}
    47  	if len(ids) != len(vectors) {
    48  		return errors.Errorf("ids and vectors sizes does not match")
    49  	}
    50  	if len(ids) == 0 {
    51  		return errors.Errorf("insertBatch called with empty lists")
    52  	}
    53  	h.trackDimensionsOnce.Do(func() {
    54  		atomic.StoreInt32(&h.dims, int32(len(vectors[0])))
    55  	})
    56  	levels := make([]int, len(ids))
    57  	maxId := uint64(0)
    58  	for i, id := range ids {
    59  		if maxId < id {
    60  			maxId = id
    61  		}
    62  		levels[i] = int(math.Floor(-math.Log(h.randFunc()) * h.levelNormalizer))
    63  	}
    64  	h.RLock()
    65  	if maxId >= uint64(len(h.nodes)) {
    66  		h.RUnlock()
    67  		h.Lock()
    68  		if maxId >= uint64(len(h.nodes)) {
    69  			err := h.growIndexToAccomodateNode(maxId, h.logger)
    70  			if err != nil {
    71  				h.Unlock()
    72  				return errors.Wrapf(err, "grow HNSW index to accommodate node %d", maxId)
    73  			}
    74  		}
    75  		h.Unlock()
    76  	} else {
    77  		h.RUnlock()
    78  	}
    79  
    80  	for i := range ids {
    81  		if err := ctx.Err(); err != nil {
    82  			return err
    83  		}
    84  
    85  		vector := vectors[i]
    86  		node := &vertex{
    87  			id:    ids[i],
    88  			level: levels[i],
    89  		}
    90  		globalBefore := time.Now()
    91  		if len(vector) == 0 {
    92  			return errors.Errorf("insert called with nil-vector")
    93  		}
    94  
    95  		h.metrics.InsertVector()
    96  
    97  		vector = h.normalizeVec(vector)
    98  		err := h.addOne(vector, node)
    99  		if err != nil {
   100  			return err
   101  		}
   102  
   103  		h.insertMetrics.total(globalBefore)
   104  	}
   105  	return nil
   106  }
   107  
   108  func (h *hnsw) addOne(vector []float32, node *vertex) error {
   109  	h.compressActionLock.RLock()
   110  	h.deleteVsInsertLock.RLock()
   111  
   112  	before := time.Now()
   113  
   114  	defer func() {
   115  		h.deleteVsInsertLock.RUnlock()
   116  		h.compressActionLock.RUnlock()
   117  		h.insertMetrics.updateGlobalEntrypoint(before)
   118  	}()
   119  
   120  	wasFirst := false
   121  	var firstInsertError error
   122  	h.initialInsertOnce.Do(func() {
   123  		if h.isEmpty() {
   124  			wasFirst = true
   125  			firstInsertError = h.insertInitialElement(node, vector)
   126  		}
   127  	})
   128  	if wasFirst {
   129  		if firstInsertError != nil {
   130  			return firstInsertError
   131  		}
   132  		return nil
   133  	}
   134  
   135  	node.markAsMaintenance()
   136  
   137  	h.RLock()
   138  	// initially use the "global" entrypoint which is guaranteed to be on the
   139  	// currently highest layer
   140  	entryPointID := h.entryPointID
   141  	// initially use the level of the entrypoint which is the highest level of
   142  	// the h-graph in the first iteration
   143  	currentMaximumLayer := h.currentMaximumLayer
   144  	h.RUnlock()
   145  
   146  	targetLevel := node.level
   147  	node.connections = make([][]uint64, targetLevel+1)
   148  
   149  	for i := targetLevel; i >= 0; i-- {
   150  		capacity := h.maximumConnections
   151  		if i == 0 {
   152  			capacity = h.maximumConnectionsLayerZero
   153  		}
   154  
   155  		node.connections[i] = make([]uint64, 0, capacity)
   156  	}
   157  
   158  	if err := h.commitLog.AddNode(node); err != nil {
   159  		return err
   160  	}
   161  
   162  	nodeId := node.id
   163  
   164  	h.shardedNodeLocks.Lock(nodeId)
   165  	h.nodes[nodeId] = node
   166  	h.shardedNodeLocks.Unlock(nodeId)
   167  
   168  	if h.compressed.Load() {
   169  		h.compressor.Preload(node.id, vector)
   170  	} else {
   171  		h.cache.Preload(node.id, vector)
   172  	}
   173  
   174  	h.insertMetrics.prepareAndInsertNode(before)
   175  	before = time.Now()
   176  
   177  	var err error
   178  	var distancer compressionhelpers.CompressorDistancer
   179  	var returnFn compressionhelpers.ReturnDistancerFn
   180  	if h.compressed.Load() {
   181  		distancer, returnFn = h.compressor.NewDistancer(vector)
   182  		defer returnFn()
   183  	}
   184  	entryPointID, err = h.findBestEntrypointForNode(currentMaximumLayer, targetLevel,
   185  		entryPointID, vector, distancer)
   186  	if err != nil {
   187  		return errors.Wrap(err, "find best entrypoint")
   188  	}
   189  
   190  	h.insertMetrics.findEntrypoint(before)
   191  	before = time.Now()
   192  
   193  	// TODO: check findAndConnectNeighbors...
   194  	if err := h.findAndConnectNeighbors(node, entryPointID, vector, distancer,
   195  		targetLevel, currentMaximumLayer, helpers.NewAllowList()); err != nil {
   196  		return errors.Wrap(err, "find and connect neighbors")
   197  	}
   198  
   199  	h.insertMetrics.findAndConnectTotal(before)
   200  	before = time.Now()
   201  
   202  	node.unmarkAsMaintenance()
   203  
   204  	h.RLock()
   205  	if targetLevel > h.currentMaximumLayer {
   206  		h.RUnlock()
   207  		h.Lock()
   208  		// check again to avoid changes from RUnlock to Lock again
   209  		if targetLevel > h.currentMaximumLayer {
   210  			if err := h.commitLog.SetEntryPointWithMaxLayer(nodeId, targetLevel); err != nil {
   211  				h.Unlock()
   212  				return err
   213  			}
   214  
   215  			h.entryPointID = nodeId
   216  			h.currentMaximumLayer = targetLevel
   217  		}
   218  		h.Unlock()
   219  	} else {
   220  		h.RUnlock()
   221  	}
   222  
   223  	return nil
   224  }
   225  
   226  func (h *hnsw) Add(id uint64, vector []float32) error {
   227  	return h.AddBatch(context.TODO(), []uint64{id}, [][]float32{vector})
   228  }
   229  
   230  func (h *hnsw) insertInitialElement(node *vertex, nodeVec []float32) error {
   231  	h.Lock()
   232  	defer h.Unlock()
   233  
   234  	if err := h.commitLog.SetEntryPointWithMaxLayer(node.id, 0); err != nil {
   235  		return err
   236  	}
   237  
   238  	h.entryPointID = node.id
   239  	h.currentMaximumLayer = 0
   240  	node.connections = [][]uint64{
   241  		make([]uint64, 0, h.maximumConnectionsLayerZero),
   242  	}
   243  	node.level = 0
   244  	if err := h.commitLog.AddNode(node); err != nil {
   245  		return err
   246  	}
   247  
   248  	err := h.growIndexToAccomodateNode(node.id, h.logger)
   249  	if err != nil {
   250  		return errors.Wrapf(err, "grow HNSW index to accommodate node %d", node.id)
   251  	}
   252  
   253  	h.shardedNodeLocks.Lock(node.id)
   254  	h.nodes[node.id] = node
   255  	h.shardedNodeLocks.Unlock(node.id)
   256  
   257  	if h.compressed.Load() {
   258  		h.compressor.Preload(node.id, nodeVec)
   259  	} else {
   260  		h.cache.Preload(node.id, nodeVec)
   261  	}
   262  
   263  	// go h.insertHook(node.id, 0, node.connections)
   264  	return nil
   265  }