github.com/weaviate/weaviate@v1.24.6/adapters/handlers/graphql/local/get/class_builder_fields.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package get 13 14 import ( 15 "context" 16 "fmt" 17 "regexp" 18 "strings" 19 20 "github.com/tailor-inc/graphql" 21 "github.com/tailor-inc/graphql/language/ast" 22 "github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions" 23 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters" 24 "github.com/weaviate/weaviate/entities/additional" 25 "github.com/weaviate/weaviate/entities/dto" 26 enterrors "github.com/weaviate/weaviate/entities/errors" 27 "github.com/weaviate/weaviate/entities/filters" 28 "github.com/weaviate/weaviate/entities/models" 29 "github.com/weaviate/weaviate/entities/modulecapabilities" 30 "github.com/weaviate/weaviate/entities/schema" 31 "github.com/weaviate/weaviate/entities/search" 32 "github.com/weaviate/weaviate/entities/searchparams" 33 ) 34 35 func (b *classBuilder) primitiveField(propertyType schema.PropertyDataType, 36 property *models.Property, className string, 37 ) *graphql.Field { 38 switch propertyType.AsPrimitive() { 39 case schema.DataTypeText: 40 return &graphql.Field{ 41 Description: property.Description, 42 Name: property.Name, 43 Type: graphql.String, 44 } 45 case schema.DataTypeInt: 46 return &graphql.Field{ 47 Description: property.Description, 48 Name: property.Name, 49 Type: graphql.Int, 50 } 51 case schema.DataTypeNumber: 52 return &graphql.Field{ 53 Description: property.Description, 54 Name: property.Name, 55 Type: graphql.Float, 56 } 57 case schema.DataTypeBoolean: 58 return &graphql.Field{ 59 Description: property.Description, 60 Name: property.Name, 61 Type: graphql.Boolean, 62 } 63 case schema.DataTypeDate: 64 return &graphql.Field{ 65 Description: property.Description, 66 Name: property.Name, 67 Type: graphql.String, // String since no graphql date datatype exists 68 } 69 case schema.DataTypeGeoCoordinates: 70 obj := newGeoCoordinatesObject(className, property.Name) 71 72 return &graphql.Field{ 73 Description: property.Description, 74 Name: property.Name, 75 Type: obj, 76 Resolve: resolveGeoCoordinates, 77 } 78 case schema.DataTypePhoneNumber: 79 obj := newPhoneNumberObject(className, property.Name) 80 81 return &graphql.Field{ 82 Description: property.Description, 83 Name: property.Name, 84 Type: obj, 85 Resolve: resolvePhoneNumber, 86 } 87 case schema.DataTypeBlob: 88 return &graphql.Field{ 89 Description: property.Description, 90 Name: property.Name, 91 Type: graphql.String, 92 } 93 case schema.DataTypeTextArray: 94 return &graphql.Field{ 95 Description: property.Description, 96 Name: property.Name, 97 Type: graphql.NewList(graphql.String), 98 } 99 case schema.DataTypeIntArray: 100 return &graphql.Field{ 101 Description: property.Description, 102 Name: property.Name, 103 Type: graphql.NewList(graphql.Int), 104 } 105 case schema.DataTypeNumberArray: 106 return &graphql.Field{ 107 Description: property.Description, 108 Name: property.Name, 109 Type: graphql.NewList(graphql.Float), 110 } 111 case schema.DataTypeBooleanArray: 112 return &graphql.Field{ 113 Description: property.Description, 114 Name: property.Name, 115 Type: graphql.NewList(graphql.Boolean), 116 } 117 case schema.DataTypeDateArray: 118 return &graphql.Field{ 119 Description: property.Description, 120 Name: property.Name, 121 Type: graphql.NewList(graphql.String), // String since no graphql date datatype exists 122 } 123 case schema.DataTypeUUIDArray: 124 return &graphql.Field{ 125 Description: property.Description, 126 Name: property.Name, 127 Type: graphql.NewList(graphql.String), // Always return UUID as string representation to the user 128 } 129 case schema.DataTypeUUID: 130 return &graphql.Field{ 131 Description: property.Description, 132 Name: property.Name, 133 Type: graphql.String, // Always return UUID as string representation to the user 134 } 135 default: 136 panic(fmt.Sprintf("buildGetClass: unknown primitive type for %s.%s; %s", 137 className, property.Name, propertyType.AsPrimitive())) 138 } 139 } 140 141 func newGeoCoordinatesObject(className string, propertyName string) *graphql.Object { 142 return graphql.NewObject(graphql.ObjectConfig{ 143 Description: "GeoCoordinates as latitude and longitude in decimal form", 144 Name: fmt.Sprintf("%s%sGeoCoordinatesObj", className, propertyName), 145 Fields: graphql.Fields{ 146 "latitude": &graphql.Field{ 147 Name: "Latitude", 148 Description: "The Latitude of the point in decimal form.", 149 Type: graphql.Float, 150 }, 151 "longitude": &graphql.Field{ 152 Name: "Longitude", 153 Description: "The Longitude of the point in decimal form.", 154 Type: graphql.Float, 155 }, 156 }, 157 }) 158 } 159 160 func newPhoneNumberObject(className string, propertyName string) *graphql.Object { 161 return graphql.NewObject(graphql.ObjectConfig{ 162 Description: "PhoneNumber in various parsed formats", 163 Name: fmt.Sprintf("%s%sPhoneNumberObj", className, propertyName), 164 Fields: graphql.Fields{ 165 "input": &graphql.Field{ 166 Name: "Input", 167 Description: "The raw phone number as put in by the user prior to parsing", 168 Type: graphql.String, 169 }, 170 "internationalFormatted": &graphql.Field{ 171 Name: "Input", 172 Description: "The parsed phone number in the international format", 173 Type: graphql.String, 174 }, 175 "nationalFormatted": &graphql.Field{ 176 Name: "Input", 177 Description: "The parsed phone number in the national format", 178 Type: graphql.String, 179 }, 180 "national": &graphql.Field{ 181 Name: "Input", 182 Description: "The parsed phone number in the national format", 183 Type: graphql.Int, 184 }, 185 "valid": &graphql.Field{ 186 Name: "Input", 187 Description: "Whether the phone number could be successfully parsed and was considered valid by the parser", 188 Type: graphql.Boolean, 189 }, 190 "countryCode": &graphql.Field{ 191 Name: "Input", 192 Description: "The parsed country code, i.e. the leading numbers identifing the country in an international format", 193 Type: graphql.Int, 194 }, 195 "defaultCountry": &graphql.Field{ 196 Name: "Input", 197 Description: "The defaultCountry as put in by the user. (This is used to help parse national numbers into an international format)", 198 Type: graphql.String, 199 }, 200 }, 201 }) 202 } 203 204 func buildGetClassField(classObject *graphql.Object, 205 class *models.Class, modulesProvider ModulesProvider, fusionEnum *graphql.Enum, 206 ) graphql.Field { 207 field := graphql.Field{ 208 Type: graphql.NewList(classObject), 209 Description: class.Description, 210 Args: graphql.FieldConfigArgument{ 211 "after": &graphql.ArgumentConfig{ 212 Description: descriptions.AfterID, 213 Type: graphql.String, 214 }, 215 "limit": &graphql.ArgumentConfig{ 216 Description: descriptions.Limit, 217 Type: graphql.Int, 218 }, 219 "offset": &graphql.ArgumentConfig{ 220 Description: descriptions.After, 221 Type: graphql.Int, 222 }, 223 "autocut": &graphql.ArgumentConfig{ 224 Description: "Cut off number of results after the Nth extrema. Off by default, negative numbers mean off.", 225 Type: graphql.Int, 226 }, 227 228 "sort": sortArgument(class.Class), 229 "nearVector": nearVectorArgument(class.Class), 230 "nearObject": nearObjectArgument(class.Class), 231 "where": whereArgument(class.Class), 232 "group": groupArgument(class.Class), 233 "groupBy": groupByArgument(class.Class), 234 }, 235 Resolve: newResolver(modulesProvider).makeResolveGetClass(class.Class), 236 } 237 238 field.Args["bm25"] = bm25Argument(class.Class) 239 field.Args["hybrid"] = hybridArgument(classObject, class, modulesProvider, fusionEnum) 240 241 if modulesProvider != nil { 242 for name, argument := range modulesProvider.GetArguments(class) { 243 field.Args[name] = argument 244 } 245 } 246 247 if replicationEnabled(class) { 248 field.Args["consistencyLevel"] = consistencyLevelArgument(class) 249 } 250 251 if schema.MultiTenancyEnabled(class) { 252 field.Args["tenant"] = tenantArgument() 253 } 254 255 return field 256 } 257 258 func resolveGeoCoordinates(p graphql.ResolveParams) (interface{}, error) { 259 field := p.Source.(map[string]interface{})[p.Info.FieldName] 260 if field == nil { 261 return nil, nil 262 } 263 264 geo, ok := field.(*models.GeoCoordinates) 265 if !ok { 266 return nil, fmt.Errorf("expected a *models.GeoCoordinates, but got: %T", field) 267 } 268 269 return map[string]interface{}{ 270 "latitude": geo.Latitude, 271 "longitude": geo.Longitude, 272 }, nil 273 } 274 275 func resolvePhoneNumber(p graphql.ResolveParams) (interface{}, error) { 276 field := p.Source.(map[string]interface{})[p.Info.FieldName] 277 if field == nil { 278 return nil, nil 279 } 280 281 phone, ok := field.(*models.PhoneNumber) 282 if !ok { 283 return nil, fmt.Errorf("expected a *models.PhoneNumber, but got: %T", field) 284 } 285 286 return map[string]interface{}{ 287 "input": phone.Input, 288 "internationalFormatted": phone.InternationalFormatted, 289 "nationalFormatted": phone.NationalFormatted, 290 "national": phone.National, 291 "valid": phone.Valid, 292 "countryCode": phone.CountryCode, 293 "defaultCountry": phone.DefaultCountry, 294 }, nil 295 } 296 297 func whereArgument(className string) *graphql.ArgumentConfig { 298 return &graphql.ArgumentConfig{ 299 Description: descriptions.GetWhere, 300 Type: graphql.NewInputObject( 301 graphql.InputObjectConfig{ 302 Name: fmt.Sprintf("GetObjects%sWhereInpObj", className), 303 Fields: common_filters.BuildNew(fmt.Sprintf("GetObjects%s", className)), 304 Description: descriptions.GetWhereInpObj, 305 }, 306 ), 307 } 308 } 309 310 type resolver struct { 311 modulesProvider ModulesProvider 312 } 313 314 func newResolver(modulesProvider ModulesProvider) *resolver { 315 return &resolver{modulesProvider} 316 } 317 318 func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn { 319 return func(p graphql.ResolveParams) (interface{}, error) { 320 result, err := r.resolveGet(p, className) 321 if err != nil { 322 return result, enterrors.NewErrGraphQLUser(err, "Get", className) 323 } 324 return result, nil 325 } 326 } 327 328 func (r *resolver) resolveGet(p graphql.ResolveParams, className string) (interface{}, error) { 329 source, ok := p.Source.(map[string]interface{}) 330 if !ok { 331 return nil, fmt.Errorf("expected graphql root to be a map, but was %T", p.Source) 332 } 333 334 resolver, ok := source["Resolver"].(Resolver) 335 if !ok { 336 return nil, fmt.Errorf("expected source map to have a usable Resolver, but got %#v", source["Resolver"]) 337 } 338 339 pagination, err := filters.ExtractPaginationFromArgs(p.Args) 340 if err != nil { 341 return nil, err 342 } 343 344 cursor, err := filters.ExtractCursorFromArgs(p.Args) 345 if err != nil { 346 return nil, err 347 } 348 349 // There can only be exactly one ast.Field; it is the class name. 350 if len(p.Info.FieldASTs) != 1 { 351 panic("Only one Field expected here") 352 } 353 354 selectionsOfClass := p.Info.FieldASTs[0].SelectionSet 355 356 properties, addlProps, err := extractProperties( 357 className, selectionsOfClass, p.Info.Fragments, r.modulesProvider) 358 if err != nil { 359 return nil, err 360 } 361 362 var sort []filters.Sort 363 if sortArg, ok := p.Args["sort"]; ok { 364 sort = filters.ExtractSortFromArgs(sortArg.([]interface{})) 365 } 366 367 filters, err := common_filters.ExtractFilters(p.Args, p.Info.FieldName) 368 if err != nil { 369 return nil, fmt.Errorf("could not extract filters: %s", err) 370 } 371 372 var nearVectorParams *searchparams.NearVector 373 if nearVector, ok := p.Args["nearVector"]; ok { 374 p, err := common_filters.ExtractNearVector(nearVector.(map[string]interface{})) 375 if err != nil { 376 return nil, fmt.Errorf("failed to extract nearVector params: %s", err) 377 } 378 nearVectorParams = &p 379 } 380 381 var nearObjectParams *searchparams.NearObject 382 if nearObject, ok := p.Args["nearObject"]; ok { 383 p, err := common_filters.ExtractNearObject(nearObject.(map[string]interface{})) 384 if err != nil { 385 return nil, fmt.Errorf("failed to extract nearObject params: %s", err) 386 } 387 nearObjectParams = &p 388 } 389 390 var moduleParams map[string]interface{} 391 if r.modulesProvider != nil { 392 extractedParams := r.modulesProvider.ExtractSearchParams(p.Args, className) 393 if len(extractedParams) > 0 { 394 moduleParams = extractedParams 395 } 396 } 397 398 // extracts bm25 (sparseSearch) from the query 399 var keywordRankingParams *searchparams.KeywordRanking 400 if bm25, ok := p.Args["bm25"]; ok { 401 if len(sort) > 0 { 402 return nil, fmt.Errorf("bm25 search is not compatible with sort") 403 } 404 p := common_filters.ExtractBM25(bm25.(map[string]interface{}), addlProps.ExplainScore) 405 keywordRankingParams = &p 406 } 407 408 // Extract hybrid search params from the processed query 409 // Everything hybrid can go in another namespace AFTER modulesprovider is 410 // refactored 411 var hybridParams *searchparams.HybridSearch 412 if hybrid, ok := p.Args["hybrid"]; ok { 413 if len(sort) > 0 { 414 return nil, fmt.Errorf("hybrid search is not compatible with sort") 415 } 416 p, err := common_filters.ExtractHybridSearch(hybrid.(map[string]interface{}), addlProps.ExplainScore) 417 if err != nil { 418 return nil, fmt.Errorf("failed to extract hybrid params: %w", err) 419 } 420 hybridParams = p 421 } 422 423 var replProps *additional.ReplicationProperties 424 if cl, ok := p.Args["consistencyLevel"]; ok { 425 replProps = &additional.ReplicationProperties{ 426 ConsistencyLevel: cl.(string), 427 } 428 } 429 430 group := extractGroup(p.Args) 431 432 var groupByParams *searchparams.GroupBy 433 if groupBy, ok := p.Args["groupBy"]; ok { 434 p := common_filters.ExtractGroupBy(groupBy.(map[string]interface{})) 435 groupByParams = &p 436 } 437 438 var tenant string 439 if tk, ok := p.Args["tenant"]; ok { 440 tenant = tk.(string) 441 } 442 443 params := dto.GetParams{ 444 Filters: filters, 445 ClassName: className, 446 Pagination: pagination, 447 Cursor: cursor, 448 Properties: properties, 449 Sort: sort, 450 NearVector: nearVectorParams, 451 NearObject: nearObjectParams, 452 Group: group, 453 ModuleParams: moduleParams, 454 AdditionalProperties: addlProps, 455 KeywordRanking: keywordRankingParams, 456 HybridSearch: hybridParams, 457 ReplicationProperties: replProps, 458 GroupBy: groupByParams, 459 Tenant: tenant, 460 } 461 462 // need to perform vector search by distance 463 // under certain conditions 464 setLimitBasedOnVectorSearchParams(¶ms) 465 466 return func() (interface{}, error) { 467 result, err := resolver.GetClass(p.Context, principalFromContext(p.Context), params) 468 if err != nil { 469 return result, enterrors.NewErrGraphQLUser(err, "Get", params.ClassName) 470 } 471 return result, nil 472 }, nil 473 } 474 475 // the limit needs to be set according to the vector search parameters. 476 // for example, if a certainty is provided by any of the near* options, 477 // and no limit was provided, weaviate will want to execute a vector 478 // search by distance. it knows to do this by watching for a limit 479 // flag, specifically filters.LimitFlagSearchByDistance 480 func setLimitBasedOnVectorSearchParams(params *dto.GetParams) { 481 setLimit := func(params *dto.GetParams) { 482 if params.Pagination == nil { 483 // limit was omitted entirely, implicitly 484 // indicating to do unlimited search 485 params.Pagination = &filters.Pagination{ 486 Limit: filters.LimitFlagSearchByDist, 487 } 488 } else if params.Pagination.Limit < 0 { 489 // a negative limit was set, explicitly 490 // indicating to do unlimited search 491 params.Pagination.Limit = filters.LimitFlagSearchByDist 492 } 493 } 494 495 if params.NearVector != nil && 496 (params.NearVector.Certainty != 0 || params.NearVector.WithDistance) { 497 setLimit(params) 498 return 499 } 500 501 if params.NearObject != nil && 502 (params.NearObject.Certainty != 0 || params.NearObject.WithDistance) { 503 setLimit(params) 504 return 505 } 506 507 for _, param := range params.ModuleParams { 508 nearParam, ok := param.(modulecapabilities.NearParam) 509 if ok && nearParam.SimilarityMetricProvided() { 510 setLimit(params) 511 return 512 } 513 } 514 } 515 516 func extractGroup(args map[string]interface{}) *dto.GroupParams { 517 group, ok := args["group"] 518 if !ok { 519 return nil 520 } 521 522 asMap := group.(map[string]interface{}) // guaranteed by graphql 523 strategy := asMap["type"].(string) 524 force := asMap["force"].(float64) 525 return &dto.GroupParams{ 526 Strategy: strategy, 527 Force: float32(force), 528 } 529 } 530 531 func principalFromContext(ctx context.Context) *models.Principal { 532 principal := ctx.Value("principal") 533 if principal == nil { 534 return nil 535 } 536 537 return principal.(*models.Principal) 538 } 539 540 func isPrimitive(selectionSet *ast.SelectionSet) bool { 541 if selectionSet == nil { 542 return true 543 } 544 545 // if there is a selection set it could either be a cross-ref or a map-type 546 // field like GeoCoordinates or PhoneNumber 547 for _, subSelection := range selectionSet.Selections { 548 if subsectionField, ok := subSelection.(*ast.Field); ok { 549 if fieldNameIsOfObjectButNonReferenceType(subsectionField.Name.Value) { 550 return true 551 } 552 } 553 } 554 555 // must be a ref field 556 return false 557 } 558 559 type additionalCheck struct { 560 modulesProvider ModulesProvider 561 } 562 563 func (ac *additionalCheck) isAdditional(parentName, name string) bool { 564 if parentName == "_additional" { 565 if name == "classification" || name == "certainty" || 566 name == "distance" || name == "id" || name == "vector" || name == "vectors" || 567 name == "creationTimeUnix" || name == "lastUpdateTimeUnix" || 568 name == "score" || name == "explainScore" || name == "isConsistent" || 569 name == "group" { 570 return true 571 } 572 if ac.isModuleAdditional(name) { 573 return true 574 } 575 } 576 return false 577 } 578 579 func (ac *additionalCheck) isModuleAdditional(name string) bool { 580 if ac.modulesProvider != nil { 581 if len(ac.modulesProvider.GraphQLAdditionalFieldNames()) > 0 { 582 for _, moduleAdditionalProperty := range ac.modulesProvider.GraphQLAdditionalFieldNames() { 583 if name == moduleAdditionalProperty { 584 return true 585 } 586 } 587 } 588 } 589 return false 590 } 591 592 func fieldNameIsOfObjectButNonReferenceType(field string) bool { 593 switch field { 594 case "latitude", "longitude": 595 // must be a geo prop 596 return true 597 case "input", "internationalFormatted", "nationalFormatted", "national", 598 "valid", "countryCode", "defaultCountry": 599 // must be a phone number 600 return true 601 default: 602 return false 603 } 604 } 605 606 func extractProperties(className string, selections *ast.SelectionSet, 607 fragments map[string]ast.Definition, 608 modulesProvider ModulesProvider, 609 ) ([]search.SelectProperty, additional.Properties, error) { 610 var properties []search.SelectProperty 611 var additionalProps additional.Properties 612 additionalCheck := &additionalCheck{modulesProvider} 613 614 for _, selection := range selections.Selections { 615 field := selection.(*ast.Field) 616 name := field.Name.Value 617 property := search.SelectProperty{Name: name} 618 619 property.IsPrimitive = isPrimitive(field.SelectionSet) 620 if !property.IsPrimitive { 621 // We can interpret this property in different ways 622 for _, subSelection := range field.SelectionSet.Selections { 623 switch s := subSelection.(type) { 624 case *ast.Field: 625 // Is it a field with the name __typename? 626 if s.Name.Value == "__typename" { 627 property.IncludeTypeName = true 628 continue 629 } else if additionalCheck.isAdditional(name, s.Name.Value) { 630 additionalProperty := s.Name.Value 631 if additionalProperty == "classification" { 632 additionalProps.Classification = true 633 continue 634 } 635 if additionalProperty == "certainty" { 636 additionalProps.Certainty = true 637 continue 638 } 639 if additionalProperty == "distance" { 640 additionalProps.Distance = true 641 continue 642 } 643 if additionalProperty == "id" { 644 additionalProps.ID = true 645 continue 646 } 647 if additionalProperty == "vector" { 648 additionalProps.Vector = true 649 continue 650 } 651 if additionalProperty == "vectors" { 652 if s.SelectionSet != nil && len(s.SelectionSet.Selections) > 0 { 653 vectors := make([]string, len(s.SelectionSet.Selections)) 654 for i, selection := range s.SelectionSet.Selections { 655 if field, ok := selection.(*ast.Field); ok { 656 vectors[i] = field.Name.Value 657 } 658 } 659 additionalProps.Vectors = vectors 660 } 661 continue 662 } 663 if additionalProperty == "creationTimeUnix" { 664 additionalProps.CreationTimeUnix = true 665 continue 666 } 667 if additionalProperty == "score" { 668 additionalProps.Score = true 669 continue 670 } 671 if additionalProperty == "explainScore" { 672 additionalProps.ExplainScore = true 673 continue 674 } 675 if additionalProperty == "lastUpdateTimeUnix" { 676 additionalProps.LastUpdateTimeUnix = true 677 continue 678 } 679 if additionalProperty == "isConsistent" { 680 additionalProps.IsConsistent = true 681 continue 682 } 683 if additionalProperty == "group" { 684 additionalProps.Group = true 685 additionalGroupHitProperties, err := extractGroupHitProperties(className, additionalProps, subSelection, fragments, modulesProvider) 686 if err != nil { 687 return nil, additionalProps, err 688 } 689 properties = append(properties, additionalGroupHitProperties...) 690 continue 691 } 692 if modulesProvider != nil { 693 if additionalCheck.isModuleAdditional(additionalProperty) { 694 additionalProps.ModuleParams = getModuleParams(additionalProps.ModuleParams) 695 additionalProps.ModuleParams[additionalProperty] = modulesProvider.ExtractAdditionalField(className, additionalProperty, s.Arguments) 696 continue 697 } 698 } 699 } else { 700 // It's an object / object array property 701 continue 702 } 703 704 case *ast.FragmentSpread: 705 ref, err := extractFragmentSpread(className, s, fragments, modulesProvider) 706 if err != nil { 707 return nil, additionalProps, err 708 } 709 710 property.Refs = append(property.Refs, ref) 711 712 case *ast.InlineFragment: 713 ref, err := extractInlineFragment(className, s, fragments, modulesProvider) 714 if err != nil { 715 return nil, additionalProps, err 716 } 717 718 property.Refs = append(property.Refs, ref) 719 720 default: 721 return nil, additionalProps, fmt.Errorf("unrecoginzed type in subs-selection: %T", subSelection) 722 } 723 } 724 } 725 726 if name == "_additional" { 727 continue 728 } 729 730 properties = append(properties, property) 731 } 732 733 return properties, additionalProps, nil 734 } 735 736 func extractGroupHitProperties( 737 className string, 738 additionalProps additional.Properties, 739 subSelection ast.Selection, 740 fragments map[string]ast.Definition, 741 modulesProvider ModulesProvider, 742 ) ([]search.SelectProperty, error) { 743 additionalGroupProperties := []search.SelectProperty{} 744 if subSelection != nil { 745 if selectionSet := subSelection.GetSelectionSet(); selectionSet != nil { 746 for _, groupSubSelection := range selectionSet.Selections { 747 if groupSubSelection != nil { 748 if groupSubSelectionField, ok := groupSubSelection.(*ast.Field); ok { 749 if groupSubSelectionField.Name.Value == "hits" && groupSubSelectionField.SelectionSet != nil { 750 for _, groupHitsSubSelection := range groupSubSelectionField.SelectionSet.Selections { 751 if hf, ok := groupHitsSubSelection.(*ast.Field); ok { 752 if hf.SelectionSet != nil { 753 for _, ss := range hf.SelectionSet.Selections { 754 if inlineFrag, ok := ss.(*ast.InlineFragment); ok { 755 ref, err := extractInlineFragment(className, inlineFrag, fragments, modulesProvider) 756 if err != nil { 757 return nil, err 758 } 759 760 additionalGroupHitProp := search.SelectProperty{Name: fmt.Sprintf("_additional:group:hits:%v", hf.Name.Value)} 761 additionalGroupHitProp.Refs = append(additionalGroupHitProp.Refs, ref) 762 additionalGroupProperties = append(additionalGroupProperties, additionalGroupHitProp) 763 } 764 } 765 } 766 } 767 } 768 } 769 } 770 } 771 } 772 } 773 } 774 return additionalGroupProperties, nil 775 } 776 777 func getModuleParams(moduleParams map[string]interface{}) map[string]interface{} { 778 if moduleParams == nil { 779 return map[string]interface{}{} 780 } 781 return moduleParams 782 } 783 784 func extractInlineFragment(class string, fragment *ast.InlineFragment, 785 fragments map[string]ast.Definition, 786 modulesProvider ModulesProvider, 787 ) (search.SelectClass, error) { 788 var className schema.ClassName 789 var err error 790 var result search.SelectClass 791 792 if strings.Contains(fragment.TypeCondition.Name.Value, "__") { 793 // is a helper type for a network ref 794 // don't validate anything as of now 795 className = schema.ClassName(fragment.TypeCondition.Name.Value) 796 } else { 797 className, err = schema.ValidateClassName(fragment.TypeCondition.Name.Value) 798 if err != nil { 799 return result, fmt.Errorf("the inline fragment type name '%s' is not a valid class name", fragment.TypeCondition.Name.Value) 800 } 801 } 802 803 if className == "Beacon" { 804 return result, fmt.Errorf("retrieving cross-refs by beacon is not supported yet - coming soon!") 805 } 806 807 subProperties, additionalProperties, err := extractProperties(class, fragment.SelectionSet, fragments, modulesProvider) 808 if err != nil { 809 return result, err 810 } 811 812 result.ClassName = string(className) 813 result.RefProperties = subProperties 814 result.AdditionalProperties = additionalProperties 815 return result, nil 816 } 817 818 func extractFragmentSpread(class string, spread *ast.FragmentSpread, 819 fragments map[string]ast.Definition, 820 modulesProvider ModulesProvider, 821 ) (search.SelectClass, error) { 822 var result search.SelectClass 823 name := spread.Name.Value 824 825 def, ok := fragments[name] 826 if !ok { 827 return result, fmt.Errorf("spread fragment '%s' refers to unknown fragment", name) 828 } 829 830 className, err := hackyWorkaroundToExtractClassName(def, name) 831 if err != nil { 832 return result, err 833 } 834 835 subProperties, additionalProperties, err := extractProperties(class, def.GetSelectionSet(), fragments, modulesProvider) 836 if err != nil { 837 return result, err 838 } 839 840 result.ClassName = string(className) 841 result.RefProperties = subProperties 842 result.AdditionalProperties = additionalProperties 843 return result, nil 844 } 845 846 // It seems there's no proper way to extract this info unfortunately: 847 // https://github.com/tailor-inc/graphql/issues/455 848 func hackyWorkaroundToExtractClassName(def ast.Definition, name string) (string, error) { 849 loc := def.GetLoc() 850 raw := loc.Source.Body[loc.Start:loc.End] 851 r := regexp.MustCompile(fmt.Sprintf(`fragment\s*%s\s*on\s*(\w*)\s*{`, name)) 852 matches := r.FindSubmatch(raw) 853 if len(matches) < 2 { 854 return "", fmt.Errorf("could not extract a className from fragment") 855 } 856 857 return string(matches[1]), nil 858 }