github.com/weaviate/weaviate@v1.24.6/adapters/handlers/graphql/local/aggregate/resolver.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  // Package aggregate provides the local aggregate graphql endpoint for Weaviate
    13  package aggregate
    14  
    15  import (
    16  	"context"
    17  	"fmt"
    18  	"strconv"
    19  	"strings"
    20  
    21  	"github.com/tailor-inc/graphql"
    22  	"github.com/tailor-inc/graphql/language/ast"
    23  	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
    24  	"github.com/weaviate/weaviate/entities/aggregation"
    25  	enterrors "github.com/weaviate/weaviate/entities/errors"
    26  	"github.com/weaviate/weaviate/entities/filters"
    27  	"github.com/weaviate/weaviate/entities/models"
    28  	"github.com/weaviate/weaviate/entities/schema"
    29  	"github.com/weaviate/weaviate/entities/searchparams"
    30  )
    31  
// GroupedByFieldName is a special GraphQL field that appears alongside the
// to-be-aggregated props. It does not require any processing by the
// connectors themselves, as it only displays meta info about the overall
// aggregation. extractProperties skips selections with this name so that no
// aggregation is requested for it.
const GroupedByFieldName = "groupedBy"
    36  
// Resolver is a local interface that can be composed with other interfaces to
// form the overall GraphQL API main interface. All database connectors that
// want to support the Aggregate feature (historically referred to as Meta)
// must implement this interface.
type Resolver interface {
	// Aggregate performs the aggregation described by info for the given
	// principal. resolveAggregate unwraps the Groups of an
	// *aggregation.Result and passes any other result type through unchanged.
	Aggregate(ctx context.Context, principal *models.Principal, info *aggregation.Params) (interface{}, error)
}
    43  
// RequestsLog is a local abstraction on the RequestsLog that needs to be
// provided to the GraphQL API in order to log Local.Get queries.
//
// NOTE(review): this comment mentions Local.Get although the package handles
// Aggregate queries, and nothing in the visible part of this file uses the
// interface — confirm whether it is still needed here.
type RequestsLog interface {
	// Register takes the request type and an identifier; presumably it
	// records one request of that type — verify against the implementation.
	Register(requestType string, identifier string)
}
    49  
    50  func makeResolveClass(modulesProvider ModulesProvider, class *models.Class) graphql.FieldResolveFn {
    51  	return func(p graphql.ResolveParams) (interface{}, error) {
    52  		res, err := resolveAggregate(p, modulesProvider, class)
    53  		if err != nil {
    54  			return res, enterrors.NewErrGraphQLUser(err, "Aggregate", schema.ClassName(p.Info.FieldName).String())
    55  		}
    56  		return res, nil
    57  	}
    58  }
    59  
    60  func resolveAggregate(p graphql.ResolveParams, modulesProvider ModulesProvider, class *models.Class) (interface{}, error) {
    61  	className := schema.ClassName(p.Info.FieldName)
    62  	source, ok := p.Source.(map[string]interface{})
    63  	if !ok {
    64  		return nil, fmt.Errorf("expected source to be a map, but was %t", p.Source)
    65  	}
    66  
    67  	resolver, ok := source["Resolver"].(Resolver)
    68  	if !ok {
    69  		return nil, fmt.Errorf("expected source to contain a usable Resolver, but was %t", p.Source)
    70  	}
    71  
    72  	// There can only be exactly one ast.Field; it is the class name.
    73  	if len(p.Info.FieldASTs) != 1 {
    74  		panic("Only one Field expected here")
    75  	}
    76  
    77  	selections := p.Info.FieldASTs[0].SelectionSet
    78  	properties, includeMeta, err := extractProperties(selections)
    79  	if err != nil {
    80  		return nil, fmt.Errorf("could not extract properties for class '%s': %w", className, err)
    81  	}
    82  
    83  	groupBy, err := extractGroupBy(p.Args, p.Info.FieldName)
    84  	if err != nil {
    85  		return nil, fmt.Errorf("could not extract groupBy path: %w", err)
    86  	}
    87  
    88  	limit, err := extractLimit(p.Args)
    89  	if err != nil {
    90  		return nil, fmt.Errorf("could not extract limit: %w", err)
    91  	}
    92  
    93  	objectLimit, err := extractObjectLimit(p.Args)
    94  	if objectLimit != nil && *objectLimit <= 0 {
    95  		return nil, fmt.Errorf("objectLimit must be a positive integer")
    96  	}
    97  	if err != nil {
    98  		return nil, fmt.Errorf("could not extract objectLimit: %w", err)
    99  	}
   100  
   101  	filters, err := common_filters.ExtractFilters(p.Args, p.Info.FieldName)
   102  	if err != nil {
   103  		return nil, fmt.Errorf("could not extract filters: %w", err)
   104  	}
   105  
   106  	var nearVectorParams *searchparams.NearVector
   107  	if nearVector, ok := p.Args["nearVector"]; ok {
   108  		p, err := common_filters.ExtractNearVector(nearVector.(map[string]interface{}))
   109  		if err != nil {
   110  			return nil, fmt.Errorf("failed to extract nearVector params: %w", err)
   111  		}
   112  		nearVectorParams = &p
   113  	}
   114  
   115  	var nearObjectParams *searchparams.NearObject
   116  	if nearObject, ok := p.Args["nearObject"]; ok {
   117  		p, err := common_filters.ExtractNearObject(nearObject.(map[string]interface{}))
   118  		if err != nil {
   119  			return nil, fmt.Errorf("failed to extract nearObject params: %w", err)
   120  		}
   121  		nearObjectParams = &p
   122  	}
   123  
   124  	var moduleParams map[string]interface{}
   125  	if modulesProvider != nil {
   126  		extractedParams := modulesProvider.ExtractSearchParams(p.Args, class.Class)
   127  		if len(extractedParams) > 0 {
   128  			moduleParams = extractedParams
   129  		}
   130  	}
   131  
   132  	// Extract hybrid search params from the processed query
   133  	// Everything hybrid can go in another namespace AFTER modulesprovider is
   134  	// refactored
   135  	var hybridParams *searchparams.HybridSearch
   136  	if hybrid, ok := p.Args["hybrid"]; ok {
   137  		p, err := common_filters.ExtractHybridSearch(hybrid.(map[string]interface{}), false)
   138  		if err != nil {
   139  			return nil, fmt.Errorf("failed to extract hybrid params: %w", err)
   140  		}
   141  		hybridParams = p
   142  	}
   143  
   144  	var tenant string
   145  	if tk, ok := p.Args["tenant"]; ok {
   146  		tenant = tk.(string)
   147  	}
   148  
   149  	params := &aggregation.Params{
   150  		Filters:          filters,
   151  		ClassName:        className,
   152  		Properties:       properties,
   153  		GroupBy:          groupBy,
   154  		IncludeMetaCount: includeMeta,
   155  		Limit:            limit,
   156  		ObjectLimit:      objectLimit,
   157  		NearVector:       nearVectorParams,
   158  		NearObject:       nearObjectParams,
   159  		ModuleParams:     moduleParams,
   160  		Hybrid:           hybridParams,
   161  		Tenant:           tenant,
   162  	}
   163  
   164  	// we might support objectLimit without nearMedia filters later, e.g. with sort
   165  	if params.ObjectLimit != nil && !validateObjectLimitUsage(params) {
   166  		return nil, fmt.Errorf("objectLimit can only be used with a near<Media> or hybrid filter")
   167  	}
   168  
   169  	res, err := resolver.Aggregate(p.Context, principalFromContext(p.Context), params)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	switch parsed := res.(type) {
   175  	case *aggregation.Result:
   176  		return parsed.Groups, nil
   177  	default:
   178  		return res, nil
   179  	}
   180  }
   181  
   182  func extractProperties(selections *ast.SelectionSet) ([]aggregation.ParamProperty, bool, error) {
   183  	properties := []aggregation.ParamProperty{}
   184  	var includeMeta bool
   185  
   186  	for _, selection := range selections.Selections {
   187  		field := selection.(*ast.Field)
   188  		name := field.Name.Value
   189  		if name == GroupedByFieldName {
   190  			// in the graphQL API we show the "groupedBy" field alongside various
   191  			// properties, however, we don't have to include it here, as we don't
   192  			// won't to perform aggregations on it.
   193  			// If we didn't exclude it we'd run into errors down the line, because
   194  			// the connector would look for a "groupedBy" prop on the specific class
   195  			// which doesn't exist.
   196  
   197  			continue
   198  		}
   199  
   200  		if name == "meta" {
   201  			includeMeta = true
   202  			continue
   203  		}
   204  
   205  		if name == "__typename" {
   206  			continue
   207  		}
   208  
   209  		name = strings.ToLower(string(name[0:1])) + string(name[1:])
   210  		property := aggregation.ParamProperty{Name: schema.PropertyName(name)}
   211  		aggregators, err := extractAggregators(field.SelectionSet)
   212  		if err != nil {
   213  			return nil, false, err
   214  		}
   215  
   216  		property.Aggregators = aggregators
   217  		properties = append(properties, property)
   218  	}
   219  
   220  	return properties, includeMeta, nil
   221  }
   222  
   223  func extractAggregators(selections *ast.SelectionSet) ([]aggregation.Aggregator, error) {
   224  	if selections == nil {
   225  		return nil, nil
   226  	}
   227  	analyses := []aggregation.Aggregator{}
   228  	for _, selection := range selections.Selections {
   229  		field := selection.(*ast.Field)
   230  		name := field.Name.Value
   231  		if name == "__typename" {
   232  			continue
   233  		}
   234  		property, err := aggregation.ParseAggregatorProp(name)
   235  		if err != nil {
   236  			return nil, err
   237  		}
   238  
   239  		if property.String() == aggregation.NewTopOccurrencesAggregator(nil).String() {
   240  			// a top occurrence, so we need to check if we have a limit argument
   241  			if overwrite := extractLimitFromArgs(field.Arguments); overwrite != nil {
   242  				property.Limit = overwrite
   243  			}
   244  		}
   245  
   246  		analyses = append(analyses, property)
   247  	}
   248  
   249  	return analyses, nil
   250  }
   251  
   252  func extractGroupBy(args map[string]interface{}, rootClass string) (*filters.Path, error) {
   253  	groupBy, ok := args["groupBy"]
   254  	if !ok {
   255  		// not set means the user is not interested in grouping (former Meta)
   256  		return nil, nil
   257  	}
   258  
   259  	pathSegments, ok := groupBy.([]interface{})
   260  	if !ok {
   261  		return nil, fmt.Errorf("no groupBy must be a list, instead got: %#v", groupBy)
   262  	}
   263  
   264  	return filters.ParsePath(pathSegments, rootClass)
   265  }
   266  
   267  func principalFromContext(ctx context.Context) *models.Principal {
   268  	principal := ctx.Value("principal")
   269  	if principal == nil {
   270  		return nil
   271  	}
   272  
   273  	return principal.(*models.Principal)
   274  }
   275  
   276  func extractLimit(args map[string]interface{}) (*int, error) {
   277  	limit, ok := args["limit"]
   278  	if !ok {
   279  		// not set means the user is not interested and the UC should use a reasonable default
   280  		return nil, nil
   281  	}
   282  
   283  	limitInt, ok := limit.(int)
   284  	if !ok {
   285  		return nil, fmt.Errorf("limit must be an int, instead got: %#v", limit)
   286  	}
   287  
   288  	return &limitInt, nil
   289  }
   290  
   291  func extractObjectLimit(args map[string]interface{}) (*int, error) {
   292  	objectLimit, ok := args["objectLimit"]
   293  	if !ok {
   294  		return nil, nil
   295  	}
   296  
   297  	objectLimitInt, ok := objectLimit.(int)
   298  	if !ok {
   299  		return nil, fmt.Errorf("objectLimit must be an int, instead got: %#v", objectLimit)
   300  	}
   301  
   302  	return &objectLimitInt, nil
   303  }
   304  
   305  func extractLimitFromArgs(args []*ast.Argument) *int {
   306  	for _, arg := range args {
   307  		if arg.Name.Value != "limit" {
   308  			continue
   309  		}
   310  
   311  		v, ok := arg.Value.GetValue().(string)
   312  		if ok {
   313  			asInt, _ := strconv.Atoi(v)
   314  			return &asInt
   315  		}
   316  	}
   317  
   318  	return nil
   319  }
   320  
   321  func validateObjectLimitUsage(params *aggregation.Params) bool {
   322  	return params.NearObject != nil ||
   323  		params.NearVector != nil ||
   324  		len(params.ModuleParams) > 0 ||
   325  		params.Hybrid != nil
   326  }