github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/search.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package db 13 14 import ( 15 "context" 16 "fmt" 17 "sort" 18 "strings" 19 "sync" 20 21 enterrors "github.com/weaviate/weaviate/entities/errors" 22 23 "github.com/pkg/errors" 24 "github.com/weaviate/weaviate/adapters/repos/db/refcache" 25 "github.com/weaviate/weaviate/entities/additional" 26 "github.com/weaviate/weaviate/entities/aggregation" 27 "github.com/weaviate/weaviate/entities/dto" 28 "github.com/weaviate/weaviate/entities/filters" 29 "github.com/weaviate/weaviate/entities/schema" 30 "github.com/weaviate/weaviate/entities/search" 31 "github.com/weaviate/weaviate/entities/searchparams" 32 "github.com/weaviate/weaviate/entities/storobj" 33 "github.com/weaviate/weaviate/usecases/objects" 34 "github.com/weaviate/weaviate/usecases/traverser" 35 ) 36 37 func (db *DB) Aggregate(ctx context.Context, 38 params aggregation.Params, 39 ) (*aggregation.Result, error) { 40 idx := db.GetIndex(params.ClassName) 41 if idx == nil { 42 return nil, fmt.Errorf("tried to browse non-existing index for %s", params.ClassName) 43 } 44 45 return idx.aggregate(ctx, params) 46 } 47 48 func (db *DB) GetQueryMaximumResults() int { 49 return int(db.config.QueryMaximumResults) 50 } 51 52 // SparseObjectSearch is used to perform an inverted index search on the db 53 // 54 // Earlier use cases required only []search.Result as a return value from the db, and the 55 // Class ClassSearch method fit this need. Later on, other use cases presented the need 56 // for the raw storage objects, such as hybrid search. 57 func (db *DB) SparseObjectSearch(ctx context.Context, params dto.GetParams) ([]*storobj.Object, []float32, error) { 58 idx := db.GetIndex(schema.ClassName(params.ClassName)) 59 if idx == nil { 60 return nil, nil, fmt.Errorf("tried to browse non-existing index for %s", params.ClassName) 61 } 62 63 if params.Pagination == nil { 64 return nil, nil, fmt.Errorf("invalid params, pagination object is nil") 65 } 66 67 totalLimit, err := db.getTotalLimit(params.Pagination, params.AdditionalProperties) 68 if err != nil { 69 return nil, nil, errors.Wrapf(err, "invalid pagination params") 70 } 71 72 // if this is reference search and tenant is given (as origin class is MT) 73 // but searched class is non-MT, then skip tenant to pass validation 74 tenant := params.Tenant 75 if !idx.partitioningEnabled && params.IsRefOrigin { 76 tenant = "" 77 } 78 79 res, scores, err := idx.objectSearch(ctx, totalLimit, 80 params.Filters, params.KeywordRanking, params.Sort, params.Cursor, 81 params.AdditionalProperties, params.ReplicationProperties, tenant, params.Pagination.Autocut) 82 if err != nil { 83 return nil, nil, errors.Wrapf(err, "object search at index %s", idx.ID()) 84 } 85 86 return res, scores, nil 87 } 88 89 func (db *DB) Search(ctx context.Context, params dto.GetParams) ([]search.Result, error) { 90 if params.Pagination == nil { 91 return nil, fmt.Errorf("invalid params, pagination object is nil") 92 } 93 94 res, scores, err := db.SparseObjectSearch(ctx, params) 95 if err != nil { 96 return nil, err 97 } 98 99 res, scores = db.getStoreObjectsWithScores(res, scores, params.Pagination) 100 return db.ResolveReferences(ctx, 101 storobj.SearchResultsWithScore(res, scores, params.AdditionalProperties, params.Tenant), 102 params.Properties, params.GroupBy, params.AdditionalProperties, params.Tenant) 103 } 104 105 func (db *DB) VectorSearch(ctx context.Context, 106 params dto.GetParams, 107 ) ([]search.Result, error) { 108 if params.SearchVector == nil { 109 return db.Search(ctx, params) 110 } 111 112 totalLimit, err := db.getTotalLimit(params.Pagination, params.AdditionalProperties) 113 if err != nil { 114 return nil, fmt.Errorf("invalid pagination params: %w", err) 115 } 116 117 idx := db.GetIndex(schema.ClassName(params.ClassName)) 118 if idx == nil { 119 return nil, fmt.Errorf("tried to browse non-existing index for %s", params.ClassName) 120 } 121 122 targetDist := extractDistanceFromParams(params) 123 res, dists, err := idx.objectVectorSearch(ctx, params.SearchVector, params.TargetVector, 124 targetDist, totalLimit, params.Filters, params.Sort, params.GroupBy, 125 params.AdditionalProperties, params.ReplicationProperties, params.Tenant) 126 if err != nil { 127 return nil, errors.Wrapf(err, "object vector search at index %s", idx.ID()) 128 } 129 130 if totalLimit < 0 { 131 params.Pagination.Limit = len(res) 132 } 133 134 return db.ResolveReferences(ctx, 135 storobj.SearchResultsWithDists(db.getStoreObjects(res, params.Pagination), 136 params.AdditionalProperties, db.getDists(dists, params.Pagination)), 137 params.Properties, params.GroupBy, params.AdditionalProperties, params.Tenant) 138 } 139 140 func extractDistanceFromParams(params dto.GetParams) float32 { 141 certainty := traverser.ExtractCertaintyFromParams(params) 142 if certainty != 0 { 143 return float32(additional.CertaintyToDist(certainty)) 144 } 145 146 dist, _ := traverser.ExtractDistanceFromParams(params) 147 return float32(dist) 148 } 149 150 // DenseObjectSearch is used to perform a vector search on the db 151 // 152 // Earlier use cases required only []search.Result as a return value from the db, and the 153 // Class VectorSearch method fit this need. Later on, other use cases presented the need 154 // for the raw storage objects, such as hybrid search. 155 func (db *DB) DenseObjectSearch(ctx context.Context, class string, vector []float32, 156 targetVector string, offset int, limit int, filters *filters.LocalFilter, 157 addl additional.Properties, tenant string, 158 ) ([]*storobj.Object, []float32, error) { 159 totalLimit := offset + limit 160 161 index := db.GetIndex(schema.ClassName(class)) 162 if index == nil { 163 return nil, nil, fmt.Errorf("tried to browse non-existing index for %s", class) 164 } 165 166 // TODO: groupBy think of this 167 objs, dist, err := index.objectVectorSearch(ctx, vector, targetVector, 0, 168 totalLimit, filters, nil, nil, addl, nil, tenant) 169 if err != nil { 170 return nil, nil, fmt.Errorf("search index %s: %w", index.ID(), err) 171 } 172 173 return objs, dist, nil 174 } 175 176 func (db *DB) CrossClassVectorSearch(ctx context.Context, vector []float32, targetVector string, offset, limit int, 177 filters *filters.LocalFilter, 178 ) ([]search.Result, error) { 179 var found search.Results 180 181 wg := &sync.WaitGroup{} 182 mutex := &sync.Mutex{} 183 var searchErrors []error 184 totalLimit := offset + limit 185 186 db.indexLock.RLock() 187 for _, index := range db.indices { 188 wg.Add(1) 189 index := index 190 f := func() { 191 defer wg.Done() 192 193 objs, dist, err := index.objectVectorSearch(ctx, vector, targetVector, 194 0, totalLimit, filters, nil, nil, 195 additional.Properties{}, nil, "") 196 if err != nil { 197 mutex.Lock() 198 searchErrors = append(searchErrors, errors.Wrapf(err, "search index %s", index.ID())) 199 mutex.Unlock() 200 } 201 202 mutex.Lock() 203 found = append(found, storobj.SearchResultsWithDists(objs, additional.Properties{}, dist)...) 204 mutex.Unlock() 205 } 206 enterrors.GoWrapper(f, index.logger) 207 } 208 db.indexLock.RUnlock() 209 210 wg.Wait() 211 212 if len(searchErrors) > 0 { 213 var msg strings.Builder 214 for i, err := range searchErrors { 215 if i != 0 { 216 msg.WriteString(", ") 217 } 218 msg.WriteString(err.Error()) 219 } 220 return nil, errors.New(msg.String()) 221 } 222 223 sort.Slice(found, func(i, j int) bool { 224 return found[i].Dist < found[j].Dist 225 }) 226 227 // not enriching by refs, as a vector search result cannot provide 228 // SelectProperties 229 return db.getSearchResults(found, offset, limit), nil 230 } 231 232 // Query a specific class 233 func (db *DB) Query(ctx context.Context, q *objects.QueryInput) (search.Results, *objects.Error) { 234 totalLimit := q.Offset + q.Limit 235 if totalLimit == 0 { 236 return nil, nil 237 } 238 if len(q.Sort) > 0 { 239 scheme := db.schemaGetter.GetSchemaSkipAuth() 240 if err := filters.ValidateSort(scheme, schema.ClassName(q.Class), q.Sort); err != nil { 241 return nil, &objects.Error{Msg: "sorting", Code: objects.StatusBadRequest, Err: err} 242 } 243 } 244 idx := db.GetIndex(schema.ClassName(q.Class)) 245 if idx == nil { 246 return nil, &objects.Error{Msg: "class not found " + q.Class, Code: objects.StatusNotFound} 247 } 248 if q.Cursor != nil { 249 if err := filters.ValidateCursor(schema.ClassName(q.Class), q.Cursor, q.Offset, q.Filters, q.Sort); err != nil { 250 return nil, &objects.Error{Msg: "cursor api: invalid 'after' parameter", Code: objects.StatusBadRequest, Err: err} 251 } 252 } 253 res, _, err := idx.objectSearch(ctx, totalLimit, q.Filters, 254 nil, q.Sort, q.Cursor, q.Additional, nil, q.Tenant, 0) 255 if err != nil { 256 switch err.(type) { 257 case objects.ErrMultiTenancy: 258 return nil, &objects.Error{Msg: "search index " + idx.ID(), Code: objects.StatusUnprocessableEntity, Err: err} 259 default: 260 return nil, &objects.Error{Msg: "search index " + idx.ID(), Code: objects.StatusInternalServerError, Err: err} 261 } 262 } 263 return db.getSearchResults(storobj.SearchResults(res, q.Additional, ""), q.Offset, q.Limit), nil 264 } 265 266 // ObjectSearch search each index. 267 // Deprecated by Query which searches a specific index 268 func (db *DB) ObjectSearch(ctx context.Context, offset, limit int, 269 filters *filters.LocalFilter, sort []filters.Sort, 270 additional additional.Properties, tenant string, 271 ) (search.Results, error) { 272 return db.objectSearch(ctx, offset, limit, filters, sort, additional, tenant) 273 } 274 275 func (db *DB) objectSearch(ctx context.Context, offset, limit int, 276 filters *filters.LocalFilter, sort []filters.Sort, 277 additional additional.Properties, tenant string, 278 ) (search.Results, error) { 279 var found []*storobj.Object 280 281 if err := db.validateSort(sort); err != nil { 282 return nil, errors.Wrap(err, "search") 283 } 284 285 totalLimit := offset + limit 286 // TODO: Search in parallel, rather than sequentially or this will be 287 // painfully slow on large schemas 288 // wrapped in func to unlock mutex within defer 289 if err := func() error { 290 db.indexLock.RLock() 291 defer db.indexLock.RUnlock() 292 293 for _, index := range db.indices { 294 // TODO support all additional props 295 res, _, err := index.objectSearch(ctx, totalLimit, 296 filters, nil, sort, nil, additional, nil, tenant, 0) 297 if err != nil { 298 // Multi tenancy specific errors 299 if errors.As(err, &objects.ErrMultiTenancy{}) { 300 // validation failed (either MT class without tenant or non-MT class with tenant) 301 if strings.Contains(err.Error(), "has multi-tenancy enabled, but request was without tenant") || 302 strings.Contains(err.Error(), "has multi-tenancy disabled, but request was with tenant") { 303 continue 304 } 305 // tenant not added to class 306 if strings.Contains(err.Error(), "no tenant found with key") { 307 continue 308 } 309 // tenant does belong to this class 310 if errors.As(err, &errTenantNotFound) { 311 continue // tenant does belong to this class 312 } 313 } 314 return errors.Wrapf(err, "search index %s", index.ID()) 315 } 316 317 found = append(found, res...) 318 if len(found) >= totalLimit { 319 // we are done 320 break 321 } 322 } 323 return nil 324 }(); err != nil { 325 return nil, err 326 } 327 328 return db.getSearchResults(storobj.SearchResults(found, additional, tenant), offset, limit), nil 329 } 330 331 // ResolveReferences takes a list of search results and enriches them 332 // with any referenced objects 333 func (db *DB) ResolveReferences(ctx context.Context, objs search.Results, 334 props search.SelectProperties, groupBy *searchparams.GroupBy, 335 addl additional.Properties, tenant string, 336 ) (search.Results, error) { 337 if addl.NoProps { 338 // If we have no props, there also can't be refs among them, so we can skip 339 // the refcache resolver 340 return objs, nil 341 } 342 343 if groupBy != nil { 344 res, err := refcache.NewResolverWithGroup(refcache.NewCacherWithGroup(db, db.logger, tenant)). 345 Do(ctx, objs, props, addl) 346 if err != nil { 347 return nil, fmt.Errorf("resolve cross-refs: %w", err) 348 } 349 return res, nil 350 } 351 352 res, err := refcache.NewResolver(refcache.NewCacher(db, db.logger, tenant)). 353 Do(ctx, objs, props, addl) 354 if err != nil { 355 return nil, fmt.Errorf("resolve cross-refs: %w", err) 356 } 357 358 return res, nil 359 } 360 361 func (db *DB) validateSort(sort []filters.Sort) error { 362 if len(sort) > 0 { 363 var errorMsgs []string 364 // needs to happen before the index lock as they might deadlock each other 365 schema := db.schemaGetter.GetSchemaSkipAuth() 366 db.indexLock.RLock() 367 for _, index := range db.indices { 368 err := filters.ValidateSort(schema, 369 index.Config.ClassName, sort) 370 if err != nil { 371 errorMsg := errors.Wrapf(err, "search index %s", index.ID()).Error() 372 errorMsgs = append(errorMsgs, errorMsg) 373 } 374 } 375 db.indexLock.RUnlock() 376 if len(errorMsgs) > 0 { 377 return errors.Errorf("%s", strings.Join(errorMsgs, ", ")) 378 } 379 } 380 return nil 381 } 382 383 func (db *DB) getTotalLimit(pagination *filters.Pagination, addl additional.Properties) (int, error) { 384 if pagination.Limit == filters.LimitFlagSearchByDist { 385 return filters.LimitFlagSearchByDist, nil 386 } 387 388 totalLimit := pagination.Offset + db.getLimit(pagination.Limit) 389 if totalLimit == 0 { 390 return 0, fmt.Errorf("invalid default limit: %v", db.getLimit(pagination.Limit)) 391 } 392 if !addl.ReferenceQuery && totalLimit > int(db.config.QueryMaximumResults) { 393 return 0, errors.New("query maximum results exceeded") 394 } 395 return totalLimit, nil 396 } 397 398 func (db *DB) getSearchResults(found search.Results, paramOffset, paramLimit int) search.Results { 399 offset, limit := db.getOffsetLimit(len(found), paramOffset, paramLimit) 400 if offset == 0 && limit == 0 { 401 return nil 402 } 403 return found[offset:limit] 404 } 405 406 func (db *DB) getStoreObjects(res []*storobj.Object, pagination *filters.Pagination) []*storobj.Object { 407 offset, limit := db.getOffsetLimit(len(res), pagination.Offset, pagination.Limit) 408 if offset == 0 && limit == 0 { 409 return nil 410 } 411 return res[offset:limit] 412 } 413 414 func (db *DB) getStoreObjectsWithScores(res []*storobj.Object, scores []float32, pagination *filters.Pagination) ([]*storobj.Object, []float32) { 415 offset, limit := db.getOffsetLimit(len(res), pagination.Offset, pagination.Limit) 416 if offset == 0 && limit == 0 { 417 return nil, nil 418 } 419 return res[offset:limit], scores[offset:limit] 420 } 421 422 func (db *DB) getDists(dists []float32, pagination *filters.Pagination) []float32 { 423 offset, limit := db.getOffsetLimit(len(dists), pagination.Offset, pagination.Limit) 424 if offset == 0 && limit == 0 { 425 return nil 426 } 427 return dists[offset:limit] 428 } 429 430 func (db *DB) getOffsetLimit(arraySize int, offset, limit int) (int, int) { 431 totalLimit := offset + db.getLimit(limit) 432 if arraySize > totalLimit { 433 return offset, totalLimit 434 } else if arraySize > offset { 435 return offset, arraySize 436 } 437 return 0, 0 438 } 439 440 func (db *DB) getLimit(limit int) int { 441 if limit == filters.LimitFlagNotSet { 442 return int(db.config.QueryLimit) 443 } 444 return limit 445 }