github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/aggregator/aggregator.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package aggregator
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/sirupsen/logrus"
    20  	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
    21  	"github.com/weaviate/weaviate/adapters/repos/db/inverted"
    22  	"github.com/weaviate/weaviate/adapters/repos/db/inverted/stopwords"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
    24  	"github.com/weaviate/weaviate/adapters/repos/db/roaringset"
    25  	"github.com/weaviate/weaviate/entities/aggregation"
    26  	"github.com/weaviate/weaviate/entities/schema"
    27  	schemaUC "github.com/weaviate/weaviate/usecases/schema"
    28  )
    29  
    30  type vectorIndex interface {
    31  	SearchByVectorDistance(vector []float32, targetDistance float32, maxLimit int64,
    32  		allowList helpers.AllowList) ([]uint64, []float32, error)
    33  	SearchByVector(vector []float32, k int, allowList helpers.AllowList) ([]uint64, []float32, error)
    34  }
    35  
    36  type Aggregator struct {
    37  	logger                 logrus.FieldLogger
    38  	store                  *lsmkv.Store
    39  	params                 aggregation.Params
    40  	getSchema              schemaUC.SchemaGetter
    41  	classSearcher          inverted.ClassSearcher // to support ref-filters
    42  	vectorIndex            vectorIndex
    43  	stopwords              stopwords.StopwordDetector
    44  	shardVersion           uint16
    45  	propLenTracker         *inverted.JsonPropertyLengthTracker
    46  	isFallbackToSearchable inverted.IsFallbackToSearchable
    47  	tenant                 string
    48  	nestedCrossRefLimit    int64
    49  	bitmapFactory          *roaringset.BitmapFactory
    50  }
    51  
    52  func New(store *lsmkv.Store, params aggregation.Params,
    53  	getSchema schemaUC.SchemaGetter, classSearcher inverted.ClassSearcher,
    54  	stopwords stopwords.StopwordDetector, shardVersion uint16,
    55  	vectorIndex vectorIndex, logger logrus.FieldLogger,
    56  	propLenTracker *inverted.JsonPropertyLengthTracker,
    57  	isFallbackToSearchable inverted.IsFallbackToSearchable,
    58  	tenant string, nestedCrossRefLimit int64,
    59  	bitmapFactory *roaringset.BitmapFactory,
    60  ) *Aggregator {
    61  	return &Aggregator{
    62  		logger:                 logger,
    63  		store:                  store,
    64  		params:                 params,
    65  		getSchema:              getSchema,
    66  		classSearcher:          classSearcher,
    67  		stopwords:              stopwords,
    68  		shardVersion:           shardVersion,
    69  		vectorIndex:            vectorIndex,
    70  		propLenTracker:         propLenTracker,
    71  		isFallbackToSearchable: isFallbackToSearchable,
    72  		tenant:                 tenant,
    73  		nestedCrossRefLimit:    nestedCrossRefLimit,
    74  		bitmapFactory:          bitmapFactory,
    75  	}
    76  }
    77  
    78  func (a *Aggregator) GetPropertyLengthTracker() *inverted.JsonPropertyLengthTracker {
    79  	return a.propLenTracker
    80  }
    81  
    82  func (a *Aggregator) Do(ctx context.Context) (*aggregation.Result, error) {
    83  	if a.params.GroupBy != nil {
    84  		return newGroupedAggregator(a).Do(ctx)
    85  	}
    86  
    87  	if a.params.Filters != nil || len(a.params.SearchVector) > 0 || a.params.Hybrid != nil {
    88  		return newFilteredAggregator(a).Do(ctx)
    89  	}
    90  
    91  	return newUnfilteredAggregator(a).Do(ctx)
    92  }
    93  
    94  func (a *Aggregator) aggTypeOfProperty(
    95  	name schema.PropertyName,
    96  ) (aggregation.PropertyType, schema.DataType, error) {
    97  	s := a.getSchema.GetSchemaSkipAuth()
    98  	schemaProp, err := s.GetProperty(a.params.ClassName, name)
    99  	if err != nil {
   100  		return "", "", errors.Wrapf(err, "property %s", name)
   101  	}
   102  
   103  	if schema.IsRefDataType(schemaProp.DataType) {
   104  		return aggregation.PropertyTypeReference, schema.DataTypeCRef, nil
   105  	}
   106  
   107  	dt := schema.DataType(schemaProp.DataType[0])
   108  	switch dt {
   109  	case schema.DataTypeInt, schema.DataTypeNumber, schema.DataTypeIntArray,
   110  		schema.DataTypeNumberArray:
   111  		return aggregation.PropertyTypeNumerical, dt, nil
   112  	case schema.DataTypeBoolean, schema.DataTypeBooleanArray:
   113  		return aggregation.PropertyTypeBoolean, dt, nil
   114  	case schema.DataTypeText, schema.DataTypeTextArray:
   115  		return aggregation.PropertyTypeText, dt, nil
   116  	case schema.DataTypeDate, schema.DataTypeDateArray:
   117  		return aggregation.PropertyTypeDate, dt, nil
   118  	case schema.DataTypeGeoCoordinates:
   119  		return "", "", fmt.Errorf("dataType geoCoordinates can't be aggregated")
   120  	case schema.DataTypePhoneNumber:
   121  		return "", "", fmt.Errorf("dataType phoneNumber can't be aggregated")
   122  	default:
   123  		return "", "", fmt.Errorf("unrecoginzed dataType %v", schemaProp.DataType[0])
   124  	}
   125  }