github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/aggregator/grouper.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package aggregator 13 14 import ( 15 "context" 16 "fmt" 17 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/adapters/repos/db/docid" 20 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 21 "github.com/weaviate/weaviate/adapters/repos/db/lsmkv" 22 "github.com/weaviate/weaviate/entities/aggregation" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/entities/storobj" 25 "github.com/weaviate/weaviate/usecases/traverser/hybrid" 26 bolt "go.etcd.io/bbolt" 27 ) 28 29 // grouper is the component which identifies the top-n groups for a specific 30 // group-by parameter. It is used as part of the grouped aggregator, which then 31 // additionally performs an aggregation for each group. 
type grouper struct {
	*Aggregator
	// values collects, per distinct group-by value, the set of doc IDs that
	// carry that value.
	values    map[interface{}]map[uint64]struct{} // map[value][docID]struct, to keep docIds unique
	// topGroups holds the current top-n groups, ordered by descending count;
	// maintained incrementally by insertOrdered.
	topGroups []group
	// limit caps how many groups are kept in topGroups.
	limit int
}

// newGrouper returns a grouper for the given aggregator that will retain at
// most limit groups.
func newGrouper(a *Aggregator, limit int) *grouper {
	return &grouper{
		Aggregator: a,
		values:     map[interface{}]map[uint64]struct{}{},
		limit:      limit,
	}
}

// Do identifies the top-n groups for the configured group-by property. It
// rejects cross-ref paths (len > 1) and dispatches to either a full scan or a
// filtered scan depending on whether filters, a search vector, or a hybrid
// search are present.
func (g *grouper) Do(ctx context.Context) ([]group, error) {
	if len(g.params.GroupBy.Slice()) > 1 {
		return nil, fmt.Errorf("grouping by cross-refs not supported")
	}

	if g.params.Filters == nil && len(g.params.SearchVector) == 0 && g.params.Hybrid == nil {
		return g.groupAll(ctx)
	} else {
		return g.groupFiltered(ctx)
	}
}

// groupAll scans every object in the LSM object bucket, feeds each into the
// value/doc-ID map, then builds the ordered top-n groups.
func (g *grouper) groupAll(ctx context.Context) ([]group, error) {
	err := ScanAllLSM(g.store, func(prop *models.PropertySchema, docID uint64) (bool, error) {
		// the bool return requests continuation; ScanAllLSM has no abort, so
		// always true
		return true, g.addElementById(prop, docID)
	})
	if err != nil {
		return nil, errors.Wrap(err, "group all (unfiltered)")
	}

	return g.aggregateAndSelect()
}

// groupFiltered restricts grouping to the doc IDs matched by the filter /
// vector / hybrid search, reading only the group-by property for each object.
func (g *grouper) groupFiltered(ctx context.Context) ([]group, error) {
	ids, err := g.fetchDocIDs(ctx)
	if err != nil {
		return nil, err
	}

	if err := docid.ScanObjectsLSM(g.store, ids,
		func(prop *models.PropertySchema, docID uint64) (bool, error) {
			return true, g.addElementById(prop, docID)
		}, []string{g.params.GroupBy.Property.String()}); err != nil {
		return nil, err
	}

	return g.aggregateAndSelect()
}

// fetchDocIDs resolves the candidate doc IDs for filtered grouping: a vector
// search result if a search vector is set, a hybrid search result if hybrid
// params are set, otherwise the plain filter allow-list.
// NOTE(review): in the plain-filter branch allowList is dereferenced without a
// nil check — presumably buildAllowList never returns (nil, nil) when filters
// are set; confirm against its implementation.
func (g *grouper) fetchDocIDs(ctx context.Context) (ids []uint64, err error) {
	allowList, err := g.buildAllowList(ctx)
	if err != nil {
		return nil, err
	}

	if len(g.params.SearchVector) > 0 {
		ids, _, err = g.vectorSearch(allowList, g.params.SearchVector)
		if err != nil {
			return nil, fmt.Errorf("failed to perform vector search: %w", err)
		}
	} else if g.params.Hybrid != nil {
		ids, err = g.hybrid(ctx, allowList)
		if err != nil {
			return nil, fmt.Errorf("hybrid search: %w", err)
		}
	} else {
		ids = allowList.Slice()
	}

	return
}

// hybrid runs a hybrid (sparse BM25 + dense vector) search and returns the
// doc IDs of the fused result set, in the order produced by hybrid.Search.
func (g *grouper) hybrid(ctx context.Context, allowList helpers.AllowList) ([]uint64, error) {
	// sparseSearch performs the BM25 leg of the hybrid search.
	sparseSearch := func() ([]*storobj.Object, []float32, error) {
		kw, err := g.buildHybridKeywordRanking()
		if err != nil {
			return nil, nil, fmt.Errorf("build hybrid keyword ranking: %w", err)
		}

		// default the object limit if the caller did not set one; note this
		// mutates g.params as a side effect
		if g.params.ObjectLimit == nil {
			limit := hybrid.DefaultLimit
			g.params.ObjectLimit = &limit
		}

		sparse, dists, err := g.bm25Objects(ctx, kw)
		if err != nil {
			return nil, nil, fmt.Errorf("aggregate sparse search: %w", err)
		}

		return sparse, dists, nil
	}

	// denseSearch performs the vector leg, restricted to the allow-list.
	denseSearch := func(vec []float32) ([]*storobj.Object, []float32, error) {
		res, dists, err := g.objectVectorSearch(vec, allowList)
		if err != nil {
			return nil, nil, fmt.Errorf("aggregate grouped dense search: %w", err)
		}

		return res, dists, nil
	}

	res, err := hybrid.Search(ctx, &hybrid.Params{
		HybridSearch: g.params.Hybrid,
		Keyword:      nil,
		Class:        g.params.ClassName.String(),
	}, g.logger, sparseSearch, denseSearch, nil, nil, nil, nil)
	if err != nil {
		return nil, err
	}

	ids := make([]uint64, len(res))
	for i, r := range res {
		// NOTE(review): r.DocID is dereferenced without a nil check —
		// presumably hybrid.Search always populates it; confirm.
		ids[i] = *r.DocID
	}

	return ids, nil
}

// addElementById extracts the group-by property from the scanned object's
// schema and records docID under each of the property's values. Multi-valued
// properties (string/number/bool arrays, generic arrays, and cross-refs via
// their beacons) contribute one entry per element; anything else is recorded
// as a single scalar value. A nil schema or a missing property is silently
// skipped.
func (g *grouper) addElementById(s *models.PropertySchema, docID uint64) error {
	if s == nil {
		return nil
	}

	item, ok := (*s).(map[string]interface{})[g.params.GroupBy.Property.String()]
	if !ok {
		return nil
	}

	switch val := item.(type) {
	case []string:
		for i := range val {
			g.addItem(val[i], docID)
		}
	case []float64:
		for i := range val {
			g.addItem(val[i], docID)
		}
	case []bool:
		for i := range val {
			g.addItem(val[i], docID)
		}
	case []interface{}:
		for i := range val {
			g.addItem(val[i], docID)
		}
	case models.MultipleRef:
		// group cross-refs by their beacon
		for i := range val {
			g.addItem(val[i].Beacon, docID)
		}
	default:
		g.addItem(val, docID)
	}

	return nil
}

// addItem records docID in the set of doc IDs for the given group value,
// creating the set on first use. The inner map keeps doc IDs unique per value.
func (g *grouper) addItem(item interface{}, docID uint64) {
	idsMap, ok := g.values[item]
	if !ok {
		idsMap = map[uint64]struct{}{}
	}
	idsMap[docID] = struct{}{}
	g.values[item] = idsMap
}

// aggregateAndSelect turns the collected value→docIDs map into groups and
// inserts each into the size-limited, count-ordered topGroups list, which it
// returns. Iteration order over g.values is random, but insertOrdered keeps
// the result ordered by count.
func (g *grouper) aggregateAndSelect() ([]group, error) {
	for value, idsMap := range g.values {
		count := len(idsMap)
		ids := make([]uint64, count)

		i := 0
		for id := range idsMap {
			ids[i] = id
			i++
		}

		g.insertOrdered(group{
			res: aggregation.Group{
				GroupedBy: &aggregation.GroupedBy{
					Path:  g.params.GroupBy.Slice(),
					Value: value,
				},
				Count: count,
			},
			docIDs: ids,
		})
	}

	return g.topGroups, nil
}

// insertOrdered inserts elem into topGroups, keeping the list sorted by
// descending count and capped at g.limit entries. An element equal to or
// larger than an existing one is inserted before it; an element smaller than
// all existing entries is appended only if there is still room.
func (g *grouper) insertOrdered(elem group) {
	if len(g.topGroups) == 0 {
		g.topGroups = []group{elem}
		return
	}

	added := false
	for i, existing := range g.topGroups {
		if existing.res.Count > elem.res.Count {
			continue
		}

		// we have found the first one that's smaller so we must insert before i
		g.topGroups = append(
			g.topGroups[:i], append(
				[]group{elem},
				g.topGroups[i:]...,
			)...,
		)

		added = true
		break
	}

	// if the insertion pushed us over the limit, drop the smallest (last)
	if len(g.topGroups) > g.limit {
		g.topGroups = g.topGroups[:len(g.topGroups)-1]
	}

	if !added && len(g.topGroups) < g.limit {
		g.topGroups = append(g.topGroups, elem)
	}
}

// ScanAll iterates over every row in the object buckets
// TODO: where should this live?
263 func ScanAll(tx *bolt.Tx, scan docid.ObjectScanFn) error { 264 b := tx.Bucket(helpers.ObjectsBucket) 265 if b == nil { 266 return fmt.Errorf("objects bucket not found") 267 } 268 269 b.ForEach(func(_, v []byte) error { 270 elem, err := storobj.FromBinary(v) 271 if err != nil { 272 return errors.Wrapf(err, "unmarshal data object") 273 } 274 275 // scanAll has no abort, so we can ignore the first arg 276 properties := elem.Properties() 277 _, err = scan(&properties, elem.DocID) 278 return err 279 }) 280 281 return nil 282 } 283 284 // ScanAllLSM iterates over every row in the object buckets 285 func ScanAllLSM(store *lsmkv.Store, scan docid.ObjectScanFn) error { 286 b := store.Bucket(helpers.ObjectsBucketLSM) 287 if b == nil { 288 return fmt.Errorf("objects bucket not found") 289 } 290 291 c := b.Cursor() 292 defer c.Close() 293 294 for k, v := c.First(); k != nil; k, v = c.Next() { 295 elem, err := storobj.FromBinary(v) 296 if err != nil { 297 return errors.Wrapf(err, "unmarshal data object") 298 } 299 300 // scanAll has no abort, so we can ignore the first arg 301 properties := elem.Properties() 302 _, err = scan(&properties, elem.DocID) 303 if err != nil { 304 return err 305 } 306 } 307 308 return nil 309 }