github.com/weaviate/weaviate@v1.24.6/usecases/traverser/grouper/grouper.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package grouper 13 14 import ( 15 "fmt" 16 17 "github.com/sirupsen/logrus" 18 "github.com/weaviate/weaviate/entities/search" 19 "github.com/weaviate/weaviate/usecases/vectorizer" 20 ) 21 22 // Grouper groups or merges search results by how related they are 23 type Grouper struct { 24 logger logrus.FieldLogger 25 } 26 27 // NewGrouper creates a Grouper UC from the specified configuration 28 func New(logger logrus.FieldLogger) *Grouper { 29 return &Grouper{logger: logger} 30 } 31 32 // Group using the applied strategy and force 33 func (g *Grouper) Group(in []search.Result, strategy string, 34 force float32, 35 ) ([]search.Result, error) { 36 groups := groups{logger: g.logger} 37 38 for _, current := range in { 39 pos, ok := groups.hasMatch(current.Vector, force) 40 if !ok { 41 groups.new(current) 42 } else { 43 groups.Elements[pos].add(current) 44 } 45 } 46 47 return groups.flatten(strategy) 48 } 49 50 type group struct { 51 Elements []search.Result `json:"elements"` 52 } 53 54 func (g *group) add(item search.Result) { 55 g.Elements = append(g.Elements, item) 56 } 57 58 func (g group) matches(vector []float32, force float32) bool { 59 // iterate over all group Elements and consider it a match if any matches 60 for _, elem := range g.Elements { 61 dist, err := vectorizer.NormalizedDistance(vector, elem.Vector) 62 if err != nil { 63 // TODO: log error 64 // we don't expect to ever see this error, so we don't need to handle it 65 // explicitly, however, let's still log it in case that the above 66 // assumption is wrong 67 continue 68 } 69 70 if dist < force { 71 return true 72 } 73 } 74 75 return false 76 } 77 78 type groups struct { 79 Elements []group `json:"elements"` 80 logger logrus.FieldLogger 81 } 82 83 func (gs groups) hasMatch(vector []float32, force float32) (int, bool) { 84 for pos, group := range gs.Elements { 85 if group.matches(vector, force) { 86 return pos, true 87 } 88 } 89 return -1, false 90 } 91 92 func (gs *groups) new(item search.Result) { 93 gs.Elements = append(gs.Elements, group{Elements: []search.Result{item}}) 94 } 95 96 func (gs groups) flatten(strategy string) (out []search.Result, err error) { 97 gs.logger.WithField("object", "grouping_before_flatten"). 98 WithField("strategy", strategy). 99 WithField("groups", gs.Elements). 100 Debug("group before flattening") 101 102 switch strategy { 103 case "closest": 104 out, err = gs.flattenClosest() 105 case "merge": 106 out, err = gs.flattenMerge() 107 default: 108 return nil, fmt.Errorf("unrecognized grouping strategy '%s'", strategy) 109 } 110 if err != nil { 111 return 112 } 113 114 gs.logger.WithField("object", "grouping_after_flatten"). 115 WithField("strategy", strategy). 116 WithField("groups", gs.Elements). 117 Debug("group after flattening") 118 119 return out, nil 120 } 121 122 func (gs groups) flattenClosest() ([]search.Result, error) { 123 out := make([]search.Result, len(gs.Elements)) 124 for i, group := range gs.Elements { 125 out[i] = group.Elements[0] // hard-code "closest" strategy for now 126 } 127 128 return out, nil 129 } 130 131 func (gs groups) flattenMerge() ([]search.Result, error) { 132 out := make([]search.Result, len(gs.Elements)) 133 for i, group := range gs.Elements { 134 merged, err := group.flattenMerge() 135 if err != nil { 136 return nil, fmt.Errorf("group %d: %v", i, err) 137 } 138 139 out[i] = merged 140 } 141 142 return out, nil 143 }