github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/search.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package db
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"sort"
    18  	"strings"
    19  	"sync"
    20  
    21  	enterrors "github.com/weaviate/weaviate/entities/errors"
    22  
    23  	"github.com/pkg/errors"
    24  	"github.com/weaviate/weaviate/adapters/repos/db/refcache"
    25  	"github.com/weaviate/weaviate/entities/additional"
    26  	"github.com/weaviate/weaviate/entities/aggregation"
    27  	"github.com/weaviate/weaviate/entities/dto"
    28  	"github.com/weaviate/weaviate/entities/filters"
    29  	"github.com/weaviate/weaviate/entities/schema"
    30  	"github.com/weaviate/weaviate/entities/search"
    31  	"github.com/weaviate/weaviate/entities/searchparams"
    32  	"github.com/weaviate/weaviate/entities/storobj"
    33  	"github.com/weaviate/weaviate/usecases/objects"
    34  	"github.com/weaviate/weaviate/usecases/traverser"
    35  )
    36  
    37  func (db *DB) Aggregate(ctx context.Context,
    38  	params aggregation.Params,
    39  ) (*aggregation.Result, error) {
    40  	idx := db.GetIndex(params.ClassName)
    41  	if idx == nil {
    42  		return nil, fmt.Errorf("tried to browse non-existing index for %s", params.ClassName)
    43  	}
    44  
    45  	return idx.aggregate(ctx, params)
    46  }
    47  
    48  func (db *DB) GetQueryMaximumResults() int {
    49  	return int(db.config.QueryMaximumResults)
    50  }
    51  
    52  // SparseObjectSearch is used to perform an inverted index search on the db
    53  //
    54  // Earlier use cases required only []search.Result as a return value from the db, and the
    55  // Class ClassSearch method fit this need. Later on, other use cases presented the need
    56  // for the raw storage objects, such as hybrid search.
    57  func (db *DB) SparseObjectSearch(ctx context.Context, params dto.GetParams) ([]*storobj.Object, []float32, error) {
    58  	idx := db.GetIndex(schema.ClassName(params.ClassName))
    59  	if idx == nil {
    60  		return nil, nil, fmt.Errorf("tried to browse non-existing index for %s", params.ClassName)
    61  	}
    62  
    63  	if params.Pagination == nil {
    64  		return nil, nil, fmt.Errorf("invalid params, pagination object is nil")
    65  	}
    66  
    67  	totalLimit, err := db.getTotalLimit(params.Pagination, params.AdditionalProperties)
    68  	if err != nil {
    69  		return nil, nil, errors.Wrapf(err, "invalid pagination params")
    70  	}
    71  
    72  	// if this is reference search and tenant is given (as origin class is MT)
    73  	// but searched class is non-MT, then skip tenant to pass validation
    74  	tenant := params.Tenant
    75  	if !idx.partitioningEnabled && params.IsRefOrigin {
    76  		tenant = ""
    77  	}
    78  
    79  	res, scores, err := idx.objectSearch(ctx, totalLimit,
    80  		params.Filters, params.KeywordRanking, params.Sort, params.Cursor,
    81  		params.AdditionalProperties, params.ReplicationProperties, tenant, params.Pagination.Autocut)
    82  	if err != nil {
    83  		return nil, nil, errors.Wrapf(err, "object search at index %s", idx.ID())
    84  	}
    85  
    86  	return res, scores, nil
    87  }
    88  
    89  func (db *DB) Search(ctx context.Context, params dto.GetParams) ([]search.Result, error) {
    90  	if params.Pagination == nil {
    91  		return nil, fmt.Errorf("invalid params, pagination object is nil")
    92  	}
    93  
    94  	res, scores, err := db.SparseObjectSearch(ctx, params)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	res, scores = db.getStoreObjectsWithScores(res, scores, params.Pagination)
   100  	return db.ResolveReferences(ctx,
   101  		storobj.SearchResultsWithScore(res, scores, params.AdditionalProperties, params.Tenant),
   102  		params.Properties, params.GroupBy, params.AdditionalProperties, params.Tenant)
   103  }
   104  
   105  func (db *DB) VectorSearch(ctx context.Context,
   106  	params dto.GetParams,
   107  ) ([]search.Result, error) {
   108  	if params.SearchVector == nil {
   109  		return db.Search(ctx, params)
   110  	}
   111  
   112  	totalLimit, err := db.getTotalLimit(params.Pagination, params.AdditionalProperties)
   113  	if err != nil {
   114  		return nil, fmt.Errorf("invalid pagination params: %w", err)
   115  	}
   116  
   117  	idx := db.GetIndex(schema.ClassName(params.ClassName))
   118  	if idx == nil {
   119  		return nil, fmt.Errorf("tried to browse non-existing index for %s", params.ClassName)
   120  	}
   121  
   122  	targetDist := extractDistanceFromParams(params)
   123  	res, dists, err := idx.objectVectorSearch(ctx, params.SearchVector, params.TargetVector,
   124  		targetDist, totalLimit, params.Filters, params.Sort, params.GroupBy,
   125  		params.AdditionalProperties, params.ReplicationProperties, params.Tenant)
   126  	if err != nil {
   127  		return nil, errors.Wrapf(err, "object vector search at index %s", idx.ID())
   128  	}
   129  
   130  	if totalLimit < 0 {
   131  		params.Pagination.Limit = len(res)
   132  	}
   133  
   134  	return db.ResolveReferences(ctx,
   135  		storobj.SearchResultsWithDists(db.getStoreObjects(res, params.Pagination),
   136  			params.AdditionalProperties, db.getDists(dists, params.Pagination)),
   137  		params.Properties, params.GroupBy, params.AdditionalProperties, params.Tenant)
   138  }
   139  
   140  func extractDistanceFromParams(params dto.GetParams) float32 {
   141  	certainty := traverser.ExtractCertaintyFromParams(params)
   142  	if certainty != 0 {
   143  		return float32(additional.CertaintyToDist(certainty))
   144  	}
   145  
   146  	dist, _ := traverser.ExtractDistanceFromParams(params)
   147  	return float32(dist)
   148  }
   149  
   150  // DenseObjectSearch is used to perform a vector search on the db
   151  //
   152  // Earlier use cases required only []search.Result as a return value from the db, and the
   153  // Class VectorSearch method fit this need. Later on, other use cases presented the need
   154  // for the raw storage objects, such as hybrid search.
   155  func (db *DB) DenseObjectSearch(ctx context.Context, class string, vector []float32,
   156  	targetVector string, offset int, limit int, filters *filters.LocalFilter,
   157  	addl additional.Properties, tenant string,
   158  ) ([]*storobj.Object, []float32, error) {
   159  	totalLimit := offset + limit
   160  
   161  	index := db.GetIndex(schema.ClassName(class))
   162  	if index == nil {
   163  		return nil, nil, fmt.Errorf("tried to browse non-existing index for %s", class)
   164  	}
   165  
   166  	// TODO: groupBy think of this
   167  	objs, dist, err := index.objectVectorSearch(ctx, vector, targetVector, 0,
   168  		totalLimit, filters, nil, nil, addl, nil, tenant)
   169  	if err != nil {
   170  		return nil, nil, fmt.Errorf("search index %s: %w", index.ID(), err)
   171  	}
   172  
   173  	return objs, dist, nil
   174  }
   175  
   176  func (db *DB) CrossClassVectorSearch(ctx context.Context, vector []float32, targetVector string, offset, limit int,
   177  	filters *filters.LocalFilter,
   178  ) ([]search.Result, error) {
   179  	var found search.Results
   180  
   181  	wg := &sync.WaitGroup{}
   182  	mutex := &sync.Mutex{}
   183  	var searchErrors []error
   184  	totalLimit := offset + limit
   185  
   186  	db.indexLock.RLock()
   187  	for _, index := range db.indices {
   188  		wg.Add(1)
   189  		index := index
   190  		f := func() {
   191  			defer wg.Done()
   192  
   193  			objs, dist, err := index.objectVectorSearch(ctx, vector, targetVector,
   194  				0, totalLimit, filters, nil, nil,
   195  				additional.Properties{}, nil, "")
   196  			if err != nil {
   197  				mutex.Lock()
   198  				searchErrors = append(searchErrors, errors.Wrapf(err, "search index %s", index.ID()))
   199  				mutex.Unlock()
   200  			}
   201  
   202  			mutex.Lock()
   203  			found = append(found, storobj.SearchResultsWithDists(objs, additional.Properties{}, dist)...)
   204  			mutex.Unlock()
   205  		}
   206  		enterrors.GoWrapper(f, index.logger)
   207  	}
   208  	db.indexLock.RUnlock()
   209  
   210  	wg.Wait()
   211  
   212  	if len(searchErrors) > 0 {
   213  		var msg strings.Builder
   214  		for i, err := range searchErrors {
   215  			if i != 0 {
   216  				msg.WriteString(", ")
   217  			}
   218  			msg.WriteString(err.Error())
   219  		}
   220  		return nil, errors.New(msg.String())
   221  	}
   222  
   223  	sort.Slice(found, func(i, j int) bool {
   224  		return found[i].Dist < found[j].Dist
   225  	})
   226  
   227  	// not enriching by refs, as a vector search result cannot provide
   228  	// SelectProperties
   229  	return db.getSearchResults(found, offset, limit), nil
   230  }
   231  
   232  // Query a specific class
   233  func (db *DB) Query(ctx context.Context, q *objects.QueryInput) (search.Results, *objects.Error) {
   234  	totalLimit := q.Offset + q.Limit
   235  	if totalLimit == 0 {
   236  		return nil, nil
   237  	}
   238  	if len(q.Sort) > 0 {
   239  		scheme := db.schemaGetter.GetSchemaSkipAuth()
   240  		if err := filters.ValidateSort(scheme, schema.ClassName(q.Class), q.Sort); err != nil {
   241  			return nil, &objects.Error{Msg: "sorting", Code: objects.StatusBadRequest, Err: err}
   242  		}
   243  	}
   244  	idx := db.GetIndex(schema.ClassName(q.Class))
   245  	if idx == nil {
   246  		return nil, &objects.Error{Msg: "class not found " + q.Class, Code: objects.StatusNotFound}
   247  	}
   248  	if q.Cursor != nil {
   249  		if err := filters.ValidateCursor(schema.ClassName(q.Class), q.Cursor, q.Offset, q.Filters, q.Sort); err != nil {
   250  			return nil, &objects.Error{Msg: "cursor api: invalid 'after' parameter", Code: objects.StatusBadRequest, Err: err}
   251  		}
   252  	}
   253  	res, _, err := idx.objectSearch(ctx, totalLimit, q.Filters,
   254  		nil, q.Sort, q.Cursor, q.Additional, nil, q.Tenant, 0)
   255  	if err != nil {
   256  		switch err.(type) {
   257  		case objects.ErrMultiTenancy:
   258  			return nil, &objects.Error{Msg: "search index " + idx.ID(), Code: objects.StatusUnprocessableEntity, Err: err}
   259  		default:
   260  			return nil, &objects.Error{Msg: "search index " + idx.ID(), Code: objects.StatusInternalServerError, Err: err}
   261  		}
   262  	}
   263  	return db.getSearchResults(storobj.SearchResults(res, q.Additional, ""), q.Offset, q.Limit), nil
   264  }
   265  
   266  // ObjectSearch search each index.
   267  // Deprecated by Query which searches a specific index
   268  func (db *DB) ObjectSearch(ctx context.Context, offset, limit int,
   269  	filters *filters.LocalFilter, sort []filters.Sort,
   270  	additional additional.Properties, tenant string,
   271  ) (search.Results, error) {
   272  	return db.objectSearch(ctx, offset, limit, filters, sort, additional, tenant)
   273  }
   274  
   275  func (db *DB) objectSearch(ctx context.Context, offset, limit int,
   276  	filters *filters.LocalFilter, sort []filters.Sort,
   277  	additional additional.Properties, tenant string,
   278  ) (search.Results, error) {
   279  	var found []*storobj.Object
   280  
   281  	if err := db.validateSort(sort); err != nil {
   282  		return nil, errors.Wrap(err, "search")
   283  	}
   284  
   285  	totalLimit := offset + limit
   286  	// TODO: Search in parallel, rather than sequentially or this will be
   287  	// painfully slow on large schemas
   288  	// wrapped in func to unlock mutex within defer
   289  	if err := func() error {
   290  		db.indexLock.RLock()
   291  		defer db.indexLock.RUnlock()
   292  
   293  		for _, index := range db.indices {
   294  			// TODO support all additional props
   295  			res, _, err := index.objectSearch(ctx, totalLimit,
   296  				filters, nil, sort, nil, additional, nil, tenant, 0)
   297  			if err != nil {
   298  				// Multi tenancy specific errors
   299  				if errors.As(err, &objects.ErrMultiTenancy{}) {
   300  					// validation failed (either MT class without tenant or non-MT class with tenant)
   301  					if strings.Contains(err.Error(), "has multi-tenancy enabled, but request was without tenant") ||
   302  						strings.Contains(err.Error(), "has multi-tenancy disabled, but request was with tenant") {
   303  						continue
   304  					}
   305  					// tenant not added to class
   306  					if strings.Contains(err.Error(), "no tenant found with key") {
   307  						continue
   308  					}
   309  					// tenant does belong to this class
   310  					if errors.As(err, &errTenantNotFound) {
   311  						continue // tenant does belong to this class
   312  					}
   313  				}
   314  				return errors.Wrapf(err, "search index %s", index.ID())
   315  			}
   316  
   317  			found = append(found, res...)
   318  			if len(found) >= totalLimit {
   319  				// we are done
   320  				break
   321  			}
   322  		}
   323  		return nil
   324  	}(); err != nil {
   325  		return nil, err
   326  	}
   327  
   328  	return db.getSearchResults(storobj.SearchResults(found, additional, tenant), offset, limit), nil
   329  }
   330  
   331  // ResolveReferences takes a list of search results and enriches them
   332  // with any referenced objects
   333  func (db *DB) ResolveReferences(ctx context.Context, objs search.Results,
   334  	props search.SelectProperties, groupBy *searchparams.GroupBy,
   335  	addl additional.Properties, tenant string,
   336  ) (search.Results, error) {
   337  	if addl.NoProps {
   338  		// If we have no props, there also can't be refs among them, so we can skip
   339  		// the refcache resolver
   340  		return objs, nil
   341  	}
   342  
   343  	if groupBy != nil {
   344  		res, err := refcache.NewResolverWithGroup(refcache.NewCacherWithGroup(db, db.logger, tenant)).
   345  			Do(ctx, objs, props, addl)
   346  		if err != nil {
   347  			return nil, fmt.Errorf("resolve cross-refs: %w", err)
   348  		}
   349  		return res, nil
   350  	}
   351  
   352  	res, err := refcache.NewResolver(refcache.NewCacher(db, db.logger, tenant)).
   353  		Do(ctx, objs, props, addl)
   354  	if err != nil {
   355  		return nil, fmt.Errorf("resolve cross-refs: %w", err)
   356  	}
   357  
   358  	return res, nil
   359  }
   360  
   361  func (db *DB) validateSort(sort []filters.Sort) error {
   362  	if len(sort) > 0 {
   363  		var errorMsgs []string
   364  		// needs to happen before the index lock as they might deadlock each other
   365  		schema := db.schemaGetter.GetSchemaSkipAuth()
   366  		db.indexLock.RLock()
   367  		for _, index := range db.indices {
   368  			err := filters.ValidateSort(schema,
   369  				index.Config.ClassName, sort)
   370  			if err != nil {
   371  				errorMsg := errors.Wrapf(err, "search index %s", index.ID()).Error()
   372  				errorMsgs = append(errorMsgs, errorMsg)
   373  			}
   374  		}
   375  		db.indexLock.RUnlock()
   376  		if len(errorMsgs) > 0 {
   377  			return errors.Errorf("%s", strings.Join(errorMsgs, ", "))
   378  		}
   379  	}
   380  	return nil
   381  }
   382  
   383  func (db *DB) getTotalLimit(pagination *filters.Pagination, addl additional.Properties) (int, error) {
   384  	if pagination.Limit == filters.LimitFlagSearchByDist {
   385  		return filters.LimitFlagSearchByDist, nil
   386  	}
   387  
   388  	totalLimit := pagination.Offset + db.getLimit(pagination.Limit)
   389  	if totalLimit == 0 {
   390  		return 0, fmt.Errorf("invalid default limit: %v", db.getLimit(pagination.Limit))
   391  	}
   392  	if !addl.ReferenceQuery && totalLimit > int(db.config.QueryMaximumResults) {
   393  		return 0, errors.New("query maximum results exceeded")
   394  	}
   395  	return totalLimit, nil
   396  }
   397  
   398  func (db *DB) getSearchResults(found search.Results, paramOffset, paramLimit int) search.Results {
   399  	offset, limit := db.getOffsetLimit(len(found), paramOffset, paramLimit)
   400  	if offset == 0 && limit == 0 {
   401  		return nil
   402  	}
   403  	return found[offset:limit]
   404  }
   405  
   406  func (db *DB) getStoreObjects(res []*storobj.Object, pagination *filters.Pagination) []*storobj.Object {
   407  	offset, limit := db.getOffsetLimit(len(res), pagination.Offset, pagination.Limit)
   408  	if offset == 0 && limit == 0 {
   409  		return nil
   410  	}
   411  	return res[offset:limit]
   412  }
   413  
   414  func (db *DB) getStoreObjectsWithScores(res []*storobj.Object, scores []float32, pagination *filters.Pagination) ([]*storobj.Object, []float32) {
   415  	offset, limit := db.getOffsetLimit(len(res), pagination.Offset, pagination.Limit)
   416  	if offset == 0 && limit == 0 {
   417  		return nil, nil
   418  	}
   419  	return res[offset:limit], scores[offset:limit]
   420  }
   421  
   422  func (db *DB) getDists(dists []float32, pagination *filters.Pagination) []float32 {
   423  	offset, limit := db.getOffsetLimit(len(dists), pagination.Offset, pagination.Limit)
   424  	if offset == 0 && limit == 0 {
   425  		return nil
   426  	}
   427  	return dists[offset:limit]
   428  }
   429  
   430  func (db *DB) getOffsetLimit(arraySize int, offset, limit int) (int, int) {
   431  	totalLimit := offset + db.getLimit(limit)
   432  	if arraySize > totalLimit {
   433  		return offset, totalLimit
   434  	} else if arraySize > offset {
   435  		return offset, arraySize
   436  	}
   437  	return 0, 0
   438  }
   439  
   440  func (db *DB) getLimit(limit int) int {
   441  	if limit == filters.LimitFlagNotSet {
   442  		return int(db.config.QueryLimit)
   443  	}
   444  	return limit
   445  }