github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/sorter/objects_sorter.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package sorter
    13  
    14  import (
    15  	"github.com/weaviate/weaviate/entities/filters"
    16  	"github.com/weaviate/weaviate/entities/schema"
    17  	"github.com/weaviate/weaviate/entities/storobj"
    18  )
    19  
    20  type Sorter interface {
    21  	Sort(objects []*storobj.Object, distances []float32,
    22  		limit int, sort []filters.Sort) ([]*storobj.Object, []float32, error)
    23  }
    24  
    25  type objectsSorter struct {
    26  	schema schema.Schema
    27  }
    28  
    29  func NewObjectsSorter(schema schema.Schema) *objectsSorter {
    30  	return &objectsSorter{schema}
    31  }
    32  
    33  func (s objectsSorter) Sort(objects []*storobj.Object,
    34  	scores []float32, limit int, sort []filters.Sort,
    35  ) ([]*storobj.Object, []float32, error) {
    36  	count := len(objects)
    37  	if count == 0 {
    38  		return objects, scores, nil
    39  	}
    40  
    41  	limit = validateLimit(limit, count)
    42  	propNames, orders, err := extractPropNamesAndOrders(sort)
    43  	if err != nil {
    44  		return nil, nil, err
    45  	}
    46  
    47  	class := s.schema.GetClass(objects[0].Class())
    48  	dataTypesHelper := newDataTypesHelper(class)
    49  	valueExtractor := newComparableValueExtractor(dataTypesHelper)
    50  	comparator := newComparator(dataTypesHelper, propNames, orders)
    51  	creator := newComparableCreator(valueExtractor, propNames)
    52  
    53  	return newObjectsSorterHelper(comparator, creator, limit).
    54  		sort(objects, scores)
    55  }
    56  
    57  type objectsSorterHelper struct {
    58  	comparator *comparator
    59  	creator    *comparableCreator
    60  	limit      int
    61  }
    62  
    63  func newObjectsSorterHelper(comparator *comparator, creator *comparableCreator, limit int) *objectsSorterHelper {
    64  	return &objectsSorterHelper{comparator, creator, limit}
    65  }
    66  
    67  func (h *objectsSorterHelper) sort(objects []*storobj.Object, distances []float32) ([]*storobj.Object, []float32, error) {
    68  	withDistances := len(distances) > 0
    69  	count := len(objects)
    70  	sorter := newDefaultSorter(h.comparator, count)
    71  
    72  	for i := range objects {
    73  		payload := objectDistancePayload{o: objects[i]}
    74  		if withDistances {
    75  			payload.d = distances[i]
    76  		}
    77  		comparable := h.creator.createFromObjectWithPayload(objects[i], payload)
    78  		sorter.addComparable(comparable)
    79  	}
    80  
    81  	slice := h.limit
    82  	if slice == 0 {
    83  		slice = count
    84  	}
    85  
    86  	sorted := sorter.getSorted()
    87  	consume := func(i int, _ uint64, payload interface{}) bool {
    88  		if i >= slice {
    89  			return true
    90  		}
    91  		p := payload.(objectDistancePayload)
    92  		objects[i] = p.o
    93  		if withDistances {
    94  			distances[i] = p.d
    95  		}
    96  		return false
    97  	}
    98  	h.creator.extractPayloads(sorted, consume)
    99  
   100  	if withDistances {
   101  		return objects[:slice], distances[:slice], nil
   102  	}
   103  	return objects[:slice], distances, nil
   104  }
   105  
// objectDistancePayload is the payload attached to each comparable entry
// while sorting, so that an object and its distance stay together and can be
// written back in sorted order.
type objectDistancePayload struct {
	// o is the object being sorted.
	o *storobj.Object
	// d is the object's distance; it stays zero when no distances were given.
	d float32
}