github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/aggregator/grouper.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package aggregator
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/weaviate/weaviate/adapters/repos/db/docid"
    20  	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
    21  	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
    22  	"github.com/weaviate/weaviate/entities/aggregation"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/entities/storobj"
    25  	"github.com/weaviate/weaviate/usecases/traverser/hybrid"
    26  	bolt "go.etcd.io/bbolt"
    27  )
    28  
    29  // grouper is the component which identifies the top-n groups for a specific
    30  // group-by parameter. It is used as part of the grouped aggregator, which then
    31  // additionally performs an aggregation for each group.
    32  type grouper struct {
    33  	*Aggregator
    34  	values    map[interface{}]map[uint64]struct{} // map[value][docID]struct, to keep docIds unique
    35  	topGroups []group
    36  	limit     int
    37  }
    38  
    39  func newGrouper(a *Aggregator, limit int) *grouper {
    40  	return &grouper{
    41  		Aggregator: a,
    42  		values:     map[interface{}]map[uint64]struct{}{},
    43  		limit:      limit,
    44  	}
    45  }
    46  
    47  func (g *grouper) Do(ctx context.Context) ([]group, error) {
    48  	if len(g.params.GroupBy.Slice()) > 1 {
    49  		return nil, fmt.Errorf("grouping by cross-refs not supported")
    50  	}
    51  
    52  	if g.params.Filters == nil && len(g.params.SearchVector) == 0 && g.params.Hybrid == nil {
    53  		return g.groupAll(ctx)
    54  	} else {
    55  		return g.groupFiltered(ctx)
    56  	}
    57  }
    58  
    59  func (g *grouper) groupAll(ctx context.Context) ([]group, error) {
    60  	err := ScanAllLSM(g.store, func(prop *models.PropertySchema, docID uint64) (bool, error) {
    61  		return true, g.addElementById(prop, docID)
    62  	})
    63  	if err != nil {
    64  		return nil, errors.Wrap(err, "group all (unfiltered)")
    65  	}
    66  
    67  	return g.aggregateAndSelect()
    68  }
    69  
    70  func (g *grouper) groupFiltered(ctx context.Context) ([]group, error) {
    71  	ids, err := g.fetchDocIDs(ctx)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	if err := docid.ScanObjectsLSM(g.store, ids,
    77  		func(prop *models.PropertySchema, docID uint64) (bool, error) {
    78  			return true, g.addElementById(prop, docID)
    79  		}, []string{g.params.GroupBy.Property.String()}); err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	return g.aggregateAndSelect()
    84  }
    85  
    86  func (g *grouper) fetchDocIDs(ctx context.Context) (ids []uint64, err error) {
    87  	allowList, err := g.buildAllowList(ctx)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	if len(g.params.SearchVector) > 0 {
    93  		ids, _, err = g.vectorSearch(allowList, g.params.SearchVector)
    94  		if err != nil {
    95  			return nil, fmt.Errorf("failed to perform vector search: %w", err)
    96  		}
    97  	} else if g.params.Hybrid != nil {
    98  		ids, err = g.hybrid(ctx, allowList)
    99  		if err != nil {
   100  			return nil, fmt.Errorf("hybrid search: %w", err)
   101  		}
   102  	} else {
   103  		ids = allowList.Slice()
   104  	}
   105  
   106  	return
   107  }
   108  
   109  func (g *grouper) hybrid(ctx context.Context, allowList helpers.AllowList) ([]uint64, error) {
   110  	sparseSearch := func() ([]*storobj.Object, []float32, error) {
   111  		kw, err := g.buildHybridKeywordRanking()
   112  		if err != nil {
   113  			return nil, nil, fmt.Errorf("build hybrid keyword ranking: %w", err)
   114  		}
   115  
   116  		if g.params.ObjectLimit == nil {
   117  			limit := hybrid.DefaultLimit
   118  			g.params.ObjectLimit = &limit
   119  		}
   120  
   121  		sparse, dists, err := g.bm25Objects(ctx, kw)
   122  		if err != nil {
   123  			return nil, nil, fmt.Errorf("aggregate sparse search: %w", err)
   124  		}
   125  
   126  		return sparse, dists, nil
   127  	}
   128  
   129  	denseSearch := func(vec []float32) ([]*storobj.Object, []float32, error) {
   130  		res, dists, err := g.objectVectorSearch(vec, allowList)
   131  		if err != nil {
   132  			return nil, nil, fmt.Errorf("aggregate grouped dense search: %w", err)
   133  		}
   134  
   135  		return res, dists, nil
   136  	}
   137  
   138  	res, err := hybrid.Search(ctx, &hybrid.Params{
   139  		HybridSearch: g.params.Hybrid,
   140  		Keyword:      nil,
   141  		Class:        g.params.ClassName.String(),
   142  	}, g.logger, sparseSearch, denseSearch, nil, nil, nil, nil)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	ids := make([]uint64, len(res))
   148  	for i, r := range res {
   149  		ids[i] = *r.DocID
   150  	}
   151  
   152  	return ids, nil
   153  }
   154  
   155  func (g *grouper) addElementById(s *models.PropertySchema, docID uint64) error {
   156  	if s == nil {
   157  		return nil
   158  	}
   159  
   160  	item, ok := (*s).(map[string]interface{})[g.params.GroupBy.Property.String()]
   161  	if !ok {
   162  		return nil
   163  	}
   164  
   165  	switch val := item.(type) {
   166  	case []string:
   167  		for i := range val {
   168  			g.addItem(val[i], docID)
   169  		}
   170  	case []float64:
   171  		for i := range val {
   172  			g.addItem(val[i], docID)
   173  		}
   174  	case []bool:
   175  		for i := range val {
   176  			g.addItem(val[i], docID)
   177  		}
   178  	case []interface{}:
   179  		for i := range val {
   180  			g.addItem(val[i], docID)
   181  		}
   182  	case models.MultipleRef:
   183  		for i := range val {
   184  			g.addItem(val[i].Beacon, docID)
   185  		}
   186  	default:
   187  		g.addItem(val, docID)
   188  	}
   189  
   190  	return nil
   191  }
   192  
   193  func (g *grouper) addItem(item interface{}, docID uint64) {
   194  	idsMap, ok := g.values[item]
   195  	if !ok {
   196  		idsMap = map[uint64]struct{}{}
   197  	}
   198  	idsMap[docID] = struct{}{}
   199  	g.values[item] = idsMap
   200  }
   201  
   202  func (g *grouper) aggregateAndSelect() ([]group, error) {
   203  	for value, idsMap := range g.values {
   204  		count := len(idsMap)
   205  		ids := make([]uint64, count)
   206  
   207  		i := 0
   208  		for id := range idsMap {
   209  			ids[i] = id
   210  			i++
   211  		}
   212  
   213  		g.insertOrdered(group{
   214  			res: aggregation.Group{
   215  				GroupedBy: &aggregation.GroupedBy{
   216  					Path:  g.params.GroupBy.Slice(),
   217  					Value: value,
   218  				},
   219  				Count: count,
   220  			},
   221  			docIDs: ids,
   222  		})
   223  	}
   224  
   225  	return g.topGroups, nil
   226  }
   227  
   228  func (g *grouper) insertOrdered(elem group) {
   229  	if len(g.topGroups) == 0 {
   230  		g.topGroups = []group{elem}
   231  		return
   232  	}
   233  
   234  	added := false
   235  	for i, existing := range g.topGroups {
   236  		if existing.res.Count > elem.res.Count {
   237  			continue
   238  		}
   239  
   240  		// we have found the first one that's smaller so we must insert before i
   241  		g.topGroups = append(
   242  			g.topGroups[:i], append(
   243  				[]group{elem},
   244  				g.topGroups[i:]...,
   245  			)...,
   246  		)
   247  
   248  		added = true
   249  		break
   250  	}
   251  
   252  	if len(g.topGroups) > g.limit {
   253  		g.topGroups = g.topGroups[:len(g.topGroups)-1]
   254  	}
   255  
   256  	if !added && len(g.topGroups) < g.limit {
   257  		g.topGroups = append(g.topGroups, elem)
   258  	}
   259  }
   260  
   261  // ScanAll iterates over every row in the object buckets
   262  // TODO: where should this live?
   263  func ScanAll(tx *bolt.Tx, scan docid.ObjectScanFn) error {
   264  	b := tx.Bucket(helpers.ObjectsBucket)
   265  	if b == nil {
   266  		return fmt.Errorf("objects bucket not found")
   267  	}
   268  
   269  	b.ForEach(func(_, v []byte) error {
   270  		elem, err := storobj.FromBinary(v)
   271  		if err != nil {
   272  			return errors.Wrapf(err, "unmarshal data object")
   273  		}
   274  
   275  		// scanAll has no abort, so we can ignore the first arg
   276  		properties := elem.Properties()
   277  		_, err = scan(&properties, elem.DocID)
   278  		return err
   279  	})
   280  
   281  	return nil
   282  }
   283  
   284  // ScanAllLSM iterates over every row in the object buckets
   285  func ScanAllLSM(store *lsmkv.Store, scan docid.ObjectScanFn) error {
   286  	b := store.Bucket(helpers.ObjectsBucketLSM)
   287  	if b == nil {
   288  		return fmt.Errorf("objects bucket not found")
   289  	}
   290  
   291  	c := b.Cursor()
   292  	defer c.Close()
   293  
   294  	for k, v := c.First(); k != nil; k, v = c.Next() {
   295  		elem, err := storobj.FromBinary(v)
   296  		if err != nil {
   297  			return errors.Wrapf(err, "unmarshal data object")
   298  		}
   299  
   300  		// scanAll has no abort, so we can ignore the first arg
   301  		properties := elem.Properties()
   302  		_, err = scan(&properties, elem.DocID)
   303  		if err != nil {
   304  			return err
   305  		}
   306  	}
   307  
   308  	return nil
   309  }