github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/sorter/lsm_sorter.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package sorter
    13  
    14  import (
    15  	"context"
    16  	"encoding/binary"
    17  	"fmt"
    18  
    19  	"github.com/pkg/errors"
    20  	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
    21  	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
    22  	"github.com/weaviate/weaviate/entities/filters"
    23  	"github.com/weaviate/weaviate/entities/schema"
    24  	"github.com/weaviate/weaviate/entities/storobj"
    25  )
    26  
    27  type LSMSorter interface {
    28  	Sort(ctx context.Context, limit int, sort []filters.Sort) ([]uint64, error)
    29  	SortDocIDs(ctx context.Context, limit int, sort []filters.Sort, ids helpers.AllowList) ([]uint64, error)
    30  	SortDocIDsAndDists(ctx context.Context, limit int, sort []filters.Sort,
    31  		ids []uint64, dists []float32) ([]uint64, []float32, error)
    32  }
    33  
    34  type lsmSorter struct {
    35  	bucket          *lsmkv.Bucket
    36  	dataTypesHelper *dataTypesHelper
    37  	valueExtractor  *comparableValueExtractor
    38  }
    39  
    40  func NewLSMSorter(store *lsmkv.Store, sch schema.Schema, className schema.ClassName) (LSMSorter, error) {
    41  	bucket := store.Bucket(helpers.ObjectsBucketLSM)
    42  	if bucket == nil {
    43  		return nil, fmt.Errorf("lsm sorter - bucket %s for class %s not found", helpers.ObjectsBucketLSM, className)
    44  	}
    45  	class := sch.GetClass(schema.ClassName(className))
    46  	if class == nil {
    47  		return nil, fmt.Errorf("lsm sorter - class %s not found", className)
    48  	}
    49  	dataTypesHelper := newDataTypesHelper(class)
    50  	comparableValuesExtractor := newComparableValueExtractor(dataTypesHelper)
    51  
    52  	return &lsmSorter{bucket, dataTypesHelper, comparableValuesExtractor}, nil
    53  }
    54  
    55  func (s *lsmSorter) Sort(ctx context.Context, limit int, sort []filters.Sort) ([]uint64, error) {
    56  	helper, err := s.createHelper(sort, validateLimit(limit, s.bucket.Count()))
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	return helper.getSorted(ctx)
    61  }
    62  
    63  func (s *lsmSorter) SortDocIDs(ctx context.Context, limit int, sort []filters.Sort, ids helpers.AllowList) ([]uint64, error) {
    64  	helper, err := s.createHelper(sort, validateLimit(limit, ids.Len()))
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	return helper.getSortedDocIDs(ctx, ids)
    69  }
    70  
    71  func (s *lsmSorter) SortDocIDsAndDists(ctx context.Context, limit int, sort []filters.Sort,
    72  	ids []uint64, dists []float32,
    73  ) ([]uint64, []float32, error) {
    74  	helper, err := s.createHelper(sort, validateLimit(limit, len(ids)))
    75  	if err != nil {
    76  		return nil, nil, err
    77  	}
    78  	return helper.getSortedDocIDsAndDistances(ctx, ids, dists)
    79  }
    80  
    81  func (s *lsmSorter) createHelper(sort []filters.Sort, limit int) (*lsmSorterHelper, error) {
    82  	propNames, orders, err := extractPropNamesAndOrders(sort)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	comparator := newComparator(s.dataTypesHelper, propNames, orders)
    88  	creator := newComparableCreator(s.valueExtractor, propNames)
    89  	return newLsmSorterHelper(s.bucket, comparator, creator, limit), nil
    90  }
    91  
    92  type lsmSorterHelper struct {
    93  	bucket     *lsmkv.Bucket
    94  	comparator *comparator
    95  	creator    *comparableCreator
    96  	limit      int
    97  }
    98  
    99  func newLsmSorterHelper(bucket *lsmkv.Bucket, comparator *comparator,
   100  	creator *comparableCreator, limit int,
   101  ) *lsmSorterHelper {
   102  	return &lsmSorterHelper{bucket, comparator, creator, limit}
   103  }
   104  
   105  func (h *lsmSorterHelper) getSorted(ctx context.Context) ([]uint64, error) {
   106  	cursor := h.bucket.Cursor()
   107  	defer cursor.Close()
   108  
   109  	sorter := newInsertSorter(h.comparator, h.limit)
   110  
   111  	for k, objData := cursor.First(); k != nil; k, objData = cursor.Next() {
   112  		docID, err := storobj.DocIDFromBinary(objData)
   113  		if err != nil {
   114  			return nil, errors.Wrapf(err, "lsm sorter - could not get doc id")
   115  		}
   116  		comparable := h.creator.createFromBytes(docID, objData)
   117  		sorter.addComparable(comparable)
   118  	}
   119  
   120  	return h.creator.extractDocIDs(sorter.getSorted()), nil
   121  }
   122  
   123  func (h *lsmSorterHelper) getSortedDocIDs(ctx context.Context, docIDs helpers.AllowList) ([]uint64, error) {
   124  	sorter := newInsertSorter(h.comparator, h.limit)
   125  	docIDBytes := make([]byte, 8)
   126  	it := docIDs.Iterator()
   127  
   128  	for docID, ok := it.Next(); ok; docID, ok = it.Next() {
   129  		binary.LittleEndian.PutUint64(docIDBytes, docID)
   130  		objData, err := h.bucket.GetBySecondary(0, docIDBytes)
   131  		if err != nil {
   132  			return nil, errors.Wrapf(err, "lsm sorter - could not get obj by doc id %d", docID)
   133  		}
   134  		if objData == nil {
   135  			continue
   136  		}
   137  
   138  		comparable := h.creator.createFromBytes(docID, objData)
   139  		sorter.addComparable(comparable)
   140  	}
   141  
   142  	return h.creator.extractDocIDs(sorter.getSorted()), nil
   143  }
   144  
   145  func (h *lsmSorterHelper) getSortedDocIDsAndDistances(ctx context.Context, docIDs []uint64,
   146  	distances []float32,
   147  ) ([]uint64, []float32, error) {
   148  	sorter := newInsertSorter(h.comparator, h.limit)
   149  	docIDBytes := make([]byte, 8)
   150  
   151  	for i, docID := range docIDs {
   152  		binary.LittleEndian.PutUint64(docIDBytes, docID)
   153  		objData, err := h.bucket.GetBySecondary(0, docIDBytes)
   154  		if err != nil {
   155  			return nil, nil, errors.Wrapf(err, "lsm sorter - could not get obj by doc id %d", docID)
   156  		}
   157  		if objData == nil {
   158  			continue
   159  		}
   160  
   161  		comparable := h.creator.createFromBytesWithPayload(docID, objData, distances[i])
   162  		sorter.addComparable(comparable)
   163  	}
   164  
   165  	sorted := sorter.getSorted()
   166  	sortedDistances := make([]float32, len(sorted))
   167  	consume := func(i int, _ uint64, payload interface{}) bool {
   168  		sortedDistances[i] = payload.(float32)
   169  		return false
   170  	}
   171  	h.creator.extractPayloads(sorted, consume)
   172  
   173  	return h.creator.extractDocIDs(sorted), sortedDistances, nil
   174  }