github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/group_merger.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package db 13 14 import ( 15 "fmt" 16 "sort" 17 18 "github.com/weaviate/weaviate/entities/additional" 19 "github.com/weaviate/weaviate/entities/searchparams" 20 "github.com/weaviate/weaviate/entities/storobj" 21 ) 22 23 type groupMerger struct { 24 objects []*storobj.Object 25 dists []float32 26 groupBy *searchparams.GroupBy 27 } 28 29 func newGroupMerger(objects []*storobj.Object, dists []float32, 30 groupBy *searchparams.GroupBy, 31 ) *groupMerger { 32 return &groupMerger{objects, dists, groupBy} 33 } 34 35 func (gm *groupMerger) Do() ([]*storobj.Object, []float32, error) { 36 groups := map[string][]*additional.Group{} 37 objects := map[string][]int{} 38 39 for i, obj := range gm.objects { 40 g, ok := obj.AdditionalProperties()["group"] 41 if !ok { 42 return nil, nil, fmt.Errorf("group not found for object: %v", obj.ID()) 43 } 44 group, ok := g.(*additional.Group) 45 if !ok { 46 return nil, nil, fmt.Errorf("wrong group type for object: %v", obj.ID()) 47 } 48 groups[group.GroupedBy.Value] = append(groups[group.GroupedBy.Value], group) 49 objects[group.GroupedBy.Value] = append(objects[group.GroupedBy.Value], i) 50 } 51 52 getMinDistance := func(groups []*additional.Group) float32 { 53 min := groups[0].MinDistance 54 for i := range groups { 55 if groups[i].MinDistance < min { 56 min = groups[i].MinDistance 57 } 58 } 59 return min 60 } 61 62 type groupMinDistance struct { 63 value string 64 distance float32 65 } 66 67 groupDistances := []groupMinDistance{} 68 for val, group := range groups { 69 groupDistances = append(groupDistances, groupMinDistance{ 70 value: val, distance: getMinDistance(group), 71 }) 72 } 73 74 sort.Slice(groupDistances, func(i, j int) bool { 75 return groupDistances[i].distance < groupDistances[j].distance 76 }) 77 78 desiredLength := len(groups) 79 if desiredLength > gm.groupBy.Groups { 80 desiredLength = gm.groupBy.Groups 81 } 82 83 objs := make([]*storobj.Object, desiredLength) 84 dists := make([]float32, desiredLength) 85 for i, groupDistance := range groupDistances[:desiredLength] { 86 val := groupDistance.value 87 group := groups[groupDistance.value] 88 count := 0 89 hits := []map[string]interface{}{} 90 for _, g := range group { 91 count += g.Count 92 hits = append(hits, g.Hits...) 93 } 94 95 sort.Slice(hits, func(i, j int) bool { 96 return hits[i]["_additional"].(*additional.GroupHitAdditional).Distance < 97 hits[j]["_additional"].(*additional.GroupHitAdditional).Distance 98 }) 99 100 if len(hits) > gm.groupBy.ObjectsPerGroup { 101 hits = hits[:gm.groupBy.ObjectsPerGroup] 102 count = len(hits) 103 } 104 105 indx := objects[val][0] 106 obj, dist := gm.objects[indx], gm.dists[indx] 107 obj.AdditionalProperties()["group"] = &additional.Group{ 108 ID: i, 109 GroupedBy: &additional.GroupedBy{ 110 Value: val, 111 Path: []string{gm.groupBy.Property}, 112 }, 113 Count: count, 114 Hits: hits, 115 MaxDistance: hits[0]["_additional"].(*additional.GroupHitAdditional).Distance, 116 MinDistance: hits[len(hits)-1]["_additional"].(*additional.GroupHitAdditional).Distance, 117 } 118 objs[i], dists[i] = obj, dist 119 } 120 121 return objs, dists, nil 122 }