github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/batch.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package db
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  
    18  	"github.com/go-openapi/strfmt"
    19  	"github.com/pkg/errors"
    20  	"github.com/weaviate/weaviate/entities/additional"
    21  	"github.com/weaviate/weaviate/entities/schema"
    22  	"github.com/weaviate/weaviate/entities/storobj"
    23  	"github.com/weaviate/weaviate/usecases/objects"
    24  )
    25  
    26  type batchQueue struct {
    27  	objects       []*storobj.Object
    28  	originalIndex []int
    29  }
    30  
    31  func (db *DB) BatchPutObjects(ctx context.Context, objs objects.BatchObjects,
    32  	repl *additional.ReplicationProperties,
    33  ) (objects.BatchObjects, error) {
    34  	objectByClass := make(map[string]batchQueue)
    35  	indexByClass := make(map[string]*Index)
    36  
    37  	if err := db.memMonitor.CheckAlloc(estimateBatchMemory(objs)); err != nil {
    38  		return nil, fmt.Errorf("cannot process batch: %w", err)
    39  	}
    40  
    41  	for _, item := range objs {
    42  		if item.Err != nil {
    43  			// item has a validation error or another reason to ignore
    44  			continue
    45  		}
    46  		queue := objectByClass[item.Object.Class]
    47  		queue.objects = append(queue.objects, storobj.FromObject(item.Object, item.Object.Vector, item.Object.Vectors))
    48  		queue.originalIndex = append(queue.originalIndex, item.OriginalIndex)
    49  		objectByClass[item.Object.Class] = queue
    50  	}
    51  
    52  	// wrapped by func to acquire and safely release indexLock only for duration of loop
    53  	func() {
    54  		db.indexLock.RLock()
    55  		defer db.indexLock.RUnlock()
    56  
    57  		for class, queue := range objectByClass {
    58  			index, ok := db.indices[indexID(schema.ClassName(class))]
    59  			if !ok {
    60  				msg := fmt.Sprintf("could not find index for class %v. It might have been deleted in the meantime", class)
    61  				db.logger.Warn(msg)
    62  				for _, origIdx := range queue.originalIndex {
    63  					if origIdx >= len(objs) {
    64  						db.logger.Errorf(
    65  							"batch add queue index out of bounds. len(objs) == %d, queue.originalIndex == %d",
    66  							len(objs), origIdx)
    67  						break
    68  					}
    69  					objs[origIdx].Err = fmt.Errorf(msg)
    70  				}
    71  				continue
    72  			}
    73  			index.dropIndex.RLock()
    74  			indexByClass[class] = index
    75  		}
    76  	}()
    77  
    78  	// safely release remaining locks (in case of panic)
    79  	defer func() {
    80  		for _, index := range indexByClass {
    81  			if index != nil {
    82  				index.dropIndex.RUnlock()
    83  			}
    84  		}
    85  	}()
    86  
    87  	for class, index := range indexByClass {
    88  		queue := objectByClass[class]
    89  		errs := index.putObjectBatch(ctx, queue.objects, repl)
    90  		// remove index from map to skip releasing its lock in defer
    91  		indexByClass[class] = nil
    92  		index.dropIndex.RUnlock()
    93  		for i, err := range errs {
    94  			if err != nil {
    95  				objs[queue.originalIndex[i]].Err = err
    96  			}
    97  		}
    98  	}
    99  
   100  	return objs, nil
   101  }
   102  
   103  func (db *DB) AddBatchReferences(ctx context.Context, references objects.BatchReferences,
   104  	repl *additional.ReplicationProperties,
   105  ) (objects.BatchReferences, error) {
   106  	refByClass := make(map[schema.ClassName]objects.BatchReferences)
   107  	indexByClass := make(map[schema.ClassName]*Index)
   108  
   109  	for _, item := range references {
   110  		if item.Err != nil {
   111  			// item has a validation error or another reason to ignore
   112  			continue
   113  		}
   114  		refByClass[item.From.Class] = append(refByClass[item.From.Class], item)
   115  	}
   116  
   117  	// wrapped by func to acquire and safely release indexLock only for duration of loop
   118  	func() {
   119  		db.indexLock.RLock()
   120  		defer db.indexLock.RUnlock()
   121  
   122  		for class, queue := range refByClass {
   123  			index, ok := db.indices[indexID(class)]
   124  			if !ok {
   125  				for _, item := range queue {
   126  					references[item.OriginalIndex].Err = fmt.Errorf("could not find index for class %v. It might have been deleted in the meantime", class)
   127  				}
   128  				continue
   129  			}
   130  			index.dropIndex.RLock()
   131  			indexByClass[class] = index
   132  		}
   133  	}()
   134  
   135  	// safely release remaining locks (in case of panic)
   136  	defer func() {
   137  		for _, index := range indexByClass {
   138  			if index != nil {
   139  				index.dropIndex.RUnlock()
   140  			}
   141  		}
   142  	}()
   143  
   144  	for class, index := range indexByClass {
   145  		queue := refByClass[class]
   146  		errs := index.AddReferencesBatch(ctx, queue, repl)
   147  		// remove index from map to skip releasing its lock in defer
   148  		indexByClass[class] = nil
   149  		index.dropIndex.RUnlock()
   150  		for i, err := range errs {
   151  			if err != nil {
   152  				references[queue[i].OriginalIndex].Err = err
   153  			}
   154  		}
   155  	}
   156  
   157  	return references, nil
   158  }
   159  
   160  func (db *DB) BatchDeleteObjects(ctx context.Context, params objects.BatchDeleteParams,
   161  	repl *additional.ReplicationProperties, tenant string,
   162  ) (objects.BatchDeleteResult, error) {
   163  	// get index for a given class
   164  	className := params.ClassName
   165  	idx := db.GetIndex(className)
   166  	if idx == nil {
   167  		return objects.BatchDeleteResult{}, errors.Errorf("cannot find index for class %v", className)
   168  	}
   169  
   170  	// find all DocIDs in all shards that match the filter
   171  	shardDocIDs, err := idx.findUUIDs(ctx, params.Filters, tenant)
   172  	if err != nil {
   173  		return objects.BatchDeleteResult{}, errors.Wrapf(err, "cannot find objects")
   174  	}
   175  	// prepare to be deleted list of DocIDs from all shards
   176  	toDelete := map[string][]strfmt.UUID{}
   177  	limit := db.config.QueryMaximumResults
   178  
   179  	matches := int64(0)
   180  	for shardName, docIDs := range shardDocIDs {
   181  		docIDsLength := int64(len(docIDs))
   182  		if matches <= limit {
   183  			if matches+docIDsLength <= limit {
   184  				toDelete[shardName] = docIDs
   185  			} else {
   186  				toDelete[shardName] = docIDs[:limit-matches]
   187  			}
   188  		}
   189  		matches += docIDsLength
   190  	}
   191  	// delete the DocIDs in given shards
   192  	deletedObjects, err := idx.batchDeleteObjects(ctx, toDelete, params.DryRun, repl)
   193  	if err != nil {
   194  		return objects.BatchDeleteResult{}, errors.Wrapf(err, "cannot delete objects")
   195  	}
   196  
   197  	result := objects.BatchDeleteResult{
   198  		Matches: matches,
   199  		Limit:   db.config.QueryMaximumResults,
   200  		DryRun:  params.DryRun,
   201  		Objects: deletedObjects,
   202  	}
   203  	return result, nil
   204  }
   205  
   206  func estimateBatchMemory(objs objects.BatchObjects) int64 {
   207  	var sum int64
   208  	for _, item := range objs {
   209  		// Note: This is very much oversimplified. It assumes that we always need
   210  		// the footprint of the full vector and it assumes a fixed overhead of 30B
   211  		// per vector. In reality this depends on the HNSW settings - and possibly
   212  		// in the future we might have completely different index types.
   213  		//
   214  		// However, in the meantime this should be a fairly reasonable estimate, as
   215  		// it's not meant to fail exactly on the last available byte, but rather
   216  		// prevent OOM crashes. Given the fuzziness and async style of the
   217  		// memtrackinga somewhat decent estimate should be good enough.
   218  		sum += int64(len(item.Object.Vector)*4 + 30)
   219  	}
   220  
   221  	return sum
   222  }