github.com/weaviate/weaviate@v1.24.6/adapters/handlers/graphql/local/aggregate/resolver.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 // Package aggregate provides the local aggregate graphql endpoint for Weaviate 13 package aggregate 14 15 import ( 16 "context" 17 "fmt" 18 "strconv" 19 "strings" 20 21 "github.com/tailor-inc/graphql" 22 "github.com/tailor-inc/graphql/language/ast" 23 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters" 24 "github.com/weaviate/weaviate/entities/aggregation" 25 enterrors "github.com/weaviate/weaviate/entities/errors" 26 "github.com/weaviate/weaviate/entities/filters" 27 "github.com/weaviate/weaviate/entities/models" 28 "github.com/weaviate/weaviate/entities/schema" 29 "github.com/weaviate/weaviate/entities/searchparams" 30 ) 31 32 // GroupedByFieldName is a special graphQL field that appears alongside the 33 // to-be-aggregated props, but doesn't require any processing by the connectors 34 // itself, as it just displays meta info about the overall aggregation. 35 const GroupedByFieldName = "groupedBy" 36 37 // Resolver is a local interface that can be composed with other interfaces to 38 // form the overall GraphQL API main interface. All data-base connectors that 39 // want to support the Meta feature must implement this interface. 40 type Resolver interface { 41 Aggregate(ctx context.Context, principal *models.Principal, info *aggregation.Params) (interface{}, error) 42 } 43 44 // RequestsLog is a local abstraction on the RequestsLog that needs to be 45 // provided to the graphQL API in order to log Local.Get queries. 46 type RequestsLog interface { 47 Register(requestType string, identifier string) 48 } 49 50 func makeResolveClass(modulesProvider ModulesProvider, class *models.Class) graphql.FieldResolveFn { 51 return func(p graphql.ResolveParams) (interface{}, error) { 52 res, err := resolveAggregate(p, modulesProvider, class) 53 if err != nil { 54 return res, enterrors.NewErrGraphQLUser(err, "Aggregate", schema.ClassName(p.Info.FieldName).String()) 55 } 56 return res, nil 57 } 58 } 59 60 func resolveAggregate(p graphql.ResolveParams, modulesProvider ModulesProvider, class *models.Class) (interface{}, error) { 61 className := schema.ClassName(p.Info.FieldName) 62 source, ok := p.Source.(map[string]interface{}) 63 if !ok { 64 return nil, fmt.Errorf("expected source to be a map, but was %t", p.Source) 65 } 66 67 resolver, ok := source["Resolver"].(Resolver) 68 if !ok { 69 return nil, fmt.Errorf("expected source to contain a usable Resolver, but was %t", p.Source) 70 } 71 72 // There can only be exactly one ast.Field; it is the class name. 73 if len(p.Info.FieldASTs) != 1 { 74 panic("Only one Field expected here") 75 } 76 77 selections := p.Info.FieldASTs[0].SelectionSet 78 properties, includeMeta, err := extractProperties(selections) 79 if err != nil { 80 return nil, fmt.Errorf("could not extract properties for class '%s': %w", className, err) 81 } 82 83 groupBy, err := extractGroupBy(p.Args, p.Info.FieldName) 84 if err != nil { 85 return nil, fmt.Errorf("could not extract groupBy path: %w", err) 86 } 87 88 limit, err := extractLimit(p.Args) 89 if err != nil { 90 return nil, fmt.Errorf("could not extract limit: %w", err) 91 } 92 93 objectLimit, err := extractObjectLimit(p.Args) 94 if objectLimit != nil && *objectLimit <= 0 { 95 return nil, fmt.Errorf("objectLimit must be a positive integer") 96 } 97 if err != nil { 98 return nil, fmt.Errorf("could not extract objectLimit: %w", err) 99 } 100 101 filters, err := common_filters.ExtractFilters(p.Args, p.Info.FieldName) 102 if err != nil { 103 return nil, fmt.Errorf("could not extract filters: %w", err) 104 } 105 106 var nearVectorParams *searchparams.NearVector 107 if nearVector, ok := p.Args["nearVector"]; ok { 108 p, err := common_filters.ExtractNearVector(nearVector.(map[string]interface{})) 109 if err != nil { 110 return nil, fmt.Errorf("failed to extract nearVector params: %w", err) 111 } 112 nearVectorParams = &p 113 } 114 115 var nearObjectParams *searchparams.NearObject 116 if nearObject, ok := p.Args["nearObject"]; ok { 117 p, err := common_filters.ExtractNearObject(nearObject.(map[string]interface{})) 118 if err != nil { 119 return nil, fmt.Errorf("failed to extract nearObject params: %w", err) 120 } 121 nearObjectParams = &p 122 } 123 124 var moduleParams map[string]interface{} 125 if modulesProvider != nil { 126 extractedParams := modulesProvider.ExtractSearchParams(p.Args, class.Class) 127 if len(extractedParams) > 0 { 128 moduleParams = extractedParams 129 } 130 } 131 132 // Extract hybrid search params from the processed query 133 // Everything hybrid can go in another namespace AFTER modulesprovider is 134 // refactored 135 var hybridParams *searchparams.HybridSearch 136 if hybrid, ok := p.Args["hybrid"]; ok { 137 p, err := common_filters.ExtractHybridSearch(hybrid.(map[string]interface{}), false) 138 if err != nil { 139 return nil, fmt.Errorf("failed to extract hybrid params: %w", err) 140 } 141 hybridParams = p 142 } 143 144 var tenant string 145 if tk, ok := p.Args["tenant"]; ok { 146 tenant = tk.(string) 147 } 148 149 params := &aggregation.Params{ 150 Filters: filters, 151 ClassName: className, 152 Properties: properties, 153 GroupBy: groupBy, 154 IncludeMetaCount: includeMeta, 155 Limit: limit, 156 ObjectLimit: objectLimit, 157 NearVector: nearVectorParams, 158 NearObject: nearObjectParams, 159 ModuleParams: moduleParams, 160 Hybrid: hybridParams, 161 Tenant: tenant, 162 } 163 164 // we might support objectLimit without nearMedia filters later, e.g. with sort 165 if params.ObjectLimit != nil && !validateObjectLimitUsage(params) { 166 return nil, fmt.Errorf("objectLimit can only be used with a near<Media> or hybrid filter") 167 } 168 169 res, err := resolver.Aggregate(p.Context, principalFromContext(p.Context), params) 170 if err != nil { 171 return nil, err 172 } 173 174 switch parsed := res.(type) { 175 case *aggregation.Result: 176 return parsed.Groups, nil 177 default: 178 return res, nil 179 } 180 } 181 182 func extractProperties(selections *ast.SelectionSet) ([]aggregation.ParamProperty, bool, error) { 183 properties := []aggregation.ParamProperty{} 184 var includeMeta bool 185 186 for _, selection := range selections.Selections { 187 field := selection.(*ast.Field) 188 name := field.Name.Value 189 if name == GroupedByFieldName { 190 // in the graphQL API we show the "groupedBy" field alongside various 191 // properties, however, we don't have to include it here, as we don't 192 // won't to perform aggregations on it. 193 // If we didn't exclude it we'd run into errors down the line, because 194 // the connector would look for a "groupedBy" prop on the specific class 195 // which doesn't exist. 196 197 continue 198 } 199 200 if name == "meta" { 201 includeMeta = true 202 continue 203 } 204 205 if name == "__typename" { 206 continue 207 } 208 209 name = strings.ToLower(string(name[0:1])) + string(name[1:]) 210 property := aggregation.ParamProperty{Name: schema.PropertyName(name)} 211 aggregators, err := extractAggregators(field.SelectionSet) 212 if err != nil { 213 return nil, false, err 214 } 215 216 property.Aggregators = aggregators 217 properties = append(properties, property) 218 } 219 220 return properties, includeMeta, nil 221 } 222 223 func extractAggregators(selections *ast.SelectionSet) ([]aggregation.Aggregator, error) { 224 if selections == nil { 225 return nil, nil 226 } 227 analyses := []aggregation.Aggregator{} 228 for _, selection := range selections.Selections { 229 field := selection.(*ast.Field) 230 name := field.Name.Value 231 if name == "__typename" { 232 continue 233 } 234 property, err := aggregation.ParseAggregatorProp(name) 235 if err != nil { 236 return nil, err 237 } 238 239 if property.String() == aggregation.NewTopOccurrencesAggregator(nil).String() { 240 // a top occurrence, so we need to check if we have a limit argument 241 if overwrite := extractLimitFromArgs(field.Arguments); overwrite != nil { 242 property.Limit = overwrite 243 } 244 } 245 246 analyses = append(analyses, property) 247 } 248 249 return analyses, nil 250 } 251 252 func extractGroupBy(args map[string]interface{}, rootClass string) (*filters.Path, error) { 253 groupBy, ok := args["groupBy"] 254 if !ok { 255 // not set means the user is not interested in grouping (former Meta) 256 return nil, nil 257 } 258 259 pathSegments, ok := groupBy.([]interface{}) 260 if !ok { 261 return nil, fmt.Errorf("no groupBy must be a list, instead got: %#v", groupBy) 262 } 263 264 return filters.ParsePath(pathSegments, rootClass) 265 } 266 267 func principalFromContext(ctx context.Context) *models.Principal { 268 principal := ctx.Value("principal") 269 if principal == nil { 270 return nil 271 } 272 273 return principal.(*models.Principal) 274 } 275 276 func extractLimit(args map[string]interface{}) (*int, error) { 277 limit, ok := args["limit"] 278 if !ok { 279 // not set means the user is not interested and the UC should use a reasonable default 280 return nil, nil 281 } 282 283 limitInt, ok := limit.(int) 284 if !ok { 285 return nil, fmt.Errorf("limit must be an int, instead got: %#v", limit) 286 } 287 288 return &limitInt, nil 289 } 290 291 func extractObjectLimit(args map[string]interface{}) (*int, error) { 292 objectLimit, ok := args["objectLimit"] 293 if !ok { 294 return nil, nil 295 } 296 297 objectLimitInt, ok := objectLimit.(int) 298 if !ok { 299 return nil, fmt.Errorf("objectLimit must be an int, instead got: %#v", objectLimit) 300 } 301 302 return &objectLimitInt, nil 303 } 304 305 func extractLimitFromArgs(args []*ast.Argument) *int { 306 for _, arg := range args { 307 if arg.Name.Value != "limit" { 308 continue 309 } 310 311 v, ok := arg.Value.GetValue().(string) 312 if ok { 313 asInt, _ := strconv.Atoi(v) 314 return &asInt 315 } 316 } 317 318 return nil 319 } 320 321 func validateObjectLimitUsage(params *aggregation.Params) bool { 322 return params.NearObject != nil || 323 params.NearVector != nil || 324 len(params.ModuleParams) > 0 || 325 params.Hybrid != nil 326 }