github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/sorter/objects_sorter.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package sorter 13 14 import ( 15 "github.com/weaviate/weaviate/entities/filters" 16 "github.com/weaviate/weaviate/entities/schema" 17 "github.com/weaviate/weaviate/entities/storobj" 18 ) 19 20 type Sorter interface { 21 Sort(objects []*storobj.Object, distances []float32, 22 limit int, sort []filters.Sort) ([]*storobj.Object, []float32, error) 23 } 24 25 type objectsSorter struct { 26 schema schema.Schema 27 } 28 29 func NewObjectsSorter(schema schema.Schema) *objectsSorter { 30 return &objectsSorter{schema} 31 } 32 33 func (s objectsSorter) Sort(objects []*storobj.Object, 34 scores []float32, limit int, sort []filters.Sort, 35 ) ([]*storobj.Object, []float32, error) { 36 count := len(objects) 37 if count == 0 { 38 return objects, scores, nil 39 } 40 41 limit = validateLimit(limit, count) 42 propNames, orders, err := extractPropNamesAndOrders(sort) 43 if err != nil { 44 return nil, nil, err 45 } 46 47 class := s.schema.GetClass(objects[0].Class()) 48 dataTypesHelper := newDataTypesHelper(class) 49 valueExtractor := newComparableValueExtractor(dataTypesHelper) 50 comparator := newComparator(dataTypesHelper, propNames, orders) 51 creator := newComparableCreator(valueExtractor, propNames) 52 53 return newObjectsSorterHelper(comparator, creator, limit). 54 sort(objects, scores) 55 } 56 57 type objectsSorterHelper struct { 58 comparator *comparator 59 creator *comparableCreator 60 limit int 61 } 62 63 func newObjectsSorterHelper(comparator *comparator, creator *comparableCreator, limit int) *objectsSorterHelper { 64 return &objectsSorterHelper{comparator, creator, limit} 65 } 66 67 func (h *objectsSorterHelper) sort(objects []*storobj.Object, distances []float32) ([]*storobj.Object, []float32, error) { 68 withDistances := len(distances) > 0 69 count := len(objects) 70 sorter := newDefaultSorter(h.comparator, count) 71 72 for i := range objects { 73 payload := objectDistancePayload{o: objects[i]} 74 if withDistances { 75 payload.d = distances[i] 76 } 77 comparable := h.creator.createFromObjectWithPayload(objects[i], payload) 78 sorter.addComparable(comparable) 79 } 80 81 slice := h.limit 82 if slice == 0 { 83 slice = count 84 } 85 86 sorted := sorter.getSorted() 87 consume := func(i int, _ uint64, payload interface{}) bool { 88 if i >= slice { 89 return true 90 } 91 p := payload.(objectDistancePayload) 92 objects[i] = p.o 93 if withDistances { 94 distances[i] = p.d 95 } 96 return false 97 } 98 h.creator.extractPayloads(sorted, consume) 99 100 if withDistances { 101 return objects[:slice], distances[:slice], nil 102 } 103 return objects[:slice], distances, nil 104 } 105 106 type objectDistancePayload struct { 107 o *storobj.Object 108 d float32 109 }