github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/codegen/field.go (about) 1 package codegen 2 3 import ( 4 "errors" 5 "fmt" 6 goast "go/ast" 7 "go/types" 8 "log" 9 "reflect" 10 "strconv" 11 "strings" 12 13 "github.com/vektah/gqlparser/v2/ast" 14 "golang.org/x/text/cases" 15 "golang.org/x/text/language" 16 17 "github.com/niko0xdev/gqlgen/codegen/config" 18 "github.com/niko0xdev/gqlgen/codegen/templates" 19 ) 20 21 type Field struct { 22 *ast.FieldDefinition 23 24 TypeReference *config.TypeReference 25 GoFieldType GoFieldType // The field type in go, if any 26 GoReceiverName string // The name of method & var receiver in go, if any 27 GoFieldName string // The name of the method or var in go, if any 28 IsResolver bool // Does this field need a resolver 29 Args []*FieldArgument // A list of arguments to be passed to this field 30 MethodHasContext bool // If this is bound to a go method, does the method also take a context 31 NoErr bool // If this is bound to a go method, does that method have an error as the second argument 32 VOkFunc bool // If this is bound to a go method, is it of shape (interface{}, bool) 33 Object *Object // A link back to the parent object 34 Default interface{} // The default value 35 Stream bool // does this field return a channel? 36 Directives []*Directive 37 } 38 39 func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { 40 dirs, err := b.getDirectives(field.Directives) 41 if err != nil { 42 return nil, err 43 } 44 45 f := Field{ 46 FieldDefinition: field, 47 Object: obj, 48 Directives: dirs, 49 GoFieldName: templates.ToGo(field.Name), 50 GoFieldType: GoFieldVariable, 51 GoReceiverName: "obj", 52 } 53 54 if field.DefaultValue != nil { 55 var err error 56 f.Default, err = field.DefaultValue.Value(nil) 57 if err != nil { 58 return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err) 59 } 60 } 61 62 for _, arg := range field.Arguments { 63 newArg, err := b.buildArg(obj, arg) 64 if err != nil { 65 return nil, err 66 } 67 f.Args = append(f.Args, newArg) 68 } 69 70 if err = b.bindField(obj, &f); err != nil { 71 f.IsResolver = true 72 if errors.Is(err, config.ErrTypeNotFound) { 73 return nil, err 74 } 75 log.Println(err.Error()) 76 } 77 78 if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() { 79 f.TypeReference = b.Binder.PointerTo(f.TypeReference) 80 } 81 82 return &f, nil 83 } 84 85 func (b *builder) bindField(obj *Object, f *Field) (errret error) { 86 defer func() { 87 if f.TypeReference == nil { 88 tr, err := b.Binder.TypeReference(f.Type, nil) 89 if err != nil { 90 errret = err 91 } 92 f.TypeReference = tr 93 } 94 if f.TypeReference != nil { 95 dirs, err := b.getDirectives(f.TypeReference.Definition.Directives) 96 if err != nil { 97 errret = err 98 } 99 for _, dir := range obj.Directives { 100 if dir.IsLocation(ast.LocationInputObject) { 101 dirs = append(dirs, dir) 102 } 103 } 104 f.Directives = append(dirs, f.Directives...) 105 } 106 }() 107 108 f.Stream = obj.Stream 109 110 switch { 111 case f.Name == "__schema": 112 f.GoFieldType = GoFieldMethod 113 f.GoReceiverName = "ec" 114 f.GoFieldName = "introspectSchema" 115 return nil 116 case f.Name == "__type": 117 f.GoFieldType = GoFieldMethod 118 f.GoReceiverName = "ec" 119 f.GoFieldName = "introspectType" 120 return nil 121 case f.Name == "_entities": 122 f.GoFieldType = GoFieldMethod 123 f.GoReceiverName = "ec" 124 f.GoFieldName = "__resolve_entities" 125 f.MethodHasContext = true 126 f.NoErr = true 127 return nil 128 case f.Name == "_service": 129 f.GoFieldType = GoFieldMethod 130 f.GoReceiverName = "ec" 131 f.GoFieldName = "__resolve__service" 132 f.MethodHasContext = true 133 return nil 134 case obj.Root: 135 f.IsResolver = true 136 return nil 137 case b.Config.Models[obj.Name].Fields[f.Name].Resolver: 138 f.IsResolver = true 139 return nil 140 case obj.Type == config.MapType: 141 f.GoFieldType = GoFieldMap 142 return nil 143 case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "": 144 f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName 145 } 146 147 target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName) 148 if err != nil { 149 return err 150 } 151 152 pos := b.Binder.ObjectPosition(target) 153 154 switch target := target.(type) { 155 case nil: 156 // Skips creating a resolver for any root types 157 if b.Config.IsRoot(b.Schema.Types[f.Type.Name()]) { 158 return nil 159 } 160 161 objPos := b.Binder.TypePosition(obj.Type) 162 return fmt.Errorf( 163 "%s:%d adding resolver method for %s.%s, nothing matched", 164 objPos.Filename, 165 objPos.Line, 166 obj.Name, 167 f.Name, 168 ) 169 170 case *types.Func: 171 sig := target.Type().(*types.Signature) 172 if sig.Results().Len() == 1 { 173 f.NoErr = true 174 } else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" { 175 f.VOkFunc = true 176 } else if sig.Results().Len() != 2 { 177 return fmt.Errorf("method has wrong number of args") 178 } 179 params := sig.Params() 180 // If the first argument is the context, remove it from the comparison and set 181 // the MethodHasContext flag so that the context will be passed to this model's method 182 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { 183 f.MethodHasContext = true 184 vars := make([]*types.Var, params.Len()-1) 185 for i := 1; i < params.Len(); i++ { 186 vars[i-1] = params.At(i) 187 } 188 params = types.NewTuple(vars...) 189 } 190 191 // Try to match target function's arguments with GraphQL field arguments. 192 newArgs, err := b.bindArgs(f, sig, params) 193 if err != nil { 194 return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err) 195 } 196 197 // Try to match target function's return types with GraphQL field return type 198 result := sig.Results().At(0) 199 tr, err := b.Binder.TypeReference(f.Type, result.Type()) 200 if err != nil { 201 return err 202 } 203 204 // success, args and return type match. Bind to method 205 f.GoFieldType = GoFieldMethod 206 f.GoReceiverName = "obj" 207 f.GoFieldName = target.Name() 208 f.Args = newArgs 209 f.TypeReference = tr 210 211 return nil 212 213 case *types.Var: 214 tr, err := b.Binder.TypeReference(f.Type, target.Type()) 215 if err != nil { 216 return err 217 } 218 219 // success, bind to var 220 f.GoFieldType = GoFieldVariable 221 f.GoReceiverName = "obj" 222 f.GoFieldName = target.Name() 223 f.TypeReference = tr 224 225 return nil 226 default: 227 panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name)) 228 } 229 } 230 231 // findBindTarget attempts to match the name to a field or method on a Type 232 // with the following priorites: 233 // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found 234 // 2. Any method or field with a matching name. Errors if more than one match is found 235 // 3. Same logic again for embedded fields 236 func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) { 237 // NOTE: a struct tag will override both methods and fields 238 // Bind to struct tag 239 found, err := b.findBindStructTagTarget(t, name) 240 if found != nil || err != nil { 241 return found, err 242 } 243 244 // Search for a method to bind to 245 foundMethod, err := b.findBindMethodTarget(t, name) 246 if err != nil { 247 return nil, err 248 } 249 250 // Search for a field to bind to 251 foundField, err := b.findBindFieldTarget(t, name) 252 if err != nil { 253 return nil, err 254 } 255 256 switch { 257 case foundField == nil && foundMethod != nil: 258 // Bind to method 259 return foundMethod, nil 260 case foundField != nil && foundMethod == nil: 261 // Bind to field 262 return foundField, nil 263 case foundField != nil && foundMethod != nil: 264 // Error 265 return nil, fmt.Errorf("found more than one way to bind for %s", name) 266 } 267 268 // Search embeds 269 return b.findBindEmbedsTarget(t, name) 270 } 271 272 func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { 273 if b.Config.StructTag == "" { 274 return nil, nil 275 } 276 277 switch t := in.(type) { 278 case *types.Named: 279 return b.findBindStructTagTarget(t.Underlying(), name) 280 case *types.Struct: 281 var found types.Object 282 for i := 0; i < t.NumFields(); i++ { 283 field := t.Field(i) 284 if !field.Exported() || field.Embedded() { 285 continue 286 } 287 tags := reflect.StructTag(t.Tag(i)) 288 if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { 289 if found != nil { 290 return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) 291 } 292 293 found = field 294 } 295 } 296 297 return found, nil 298 } 299 300 return nil, nil 301 } 302 303 func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { 304 switch t := in.(type) { 305 case *types.Named: 306 if _, ok := t.Underlying().(*types.Interface); ok { 307 return b.findBindMethodTarget(t.Underlying(), name) 308 } 309 310 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 311 case *types.Interface: 312 // FIX-ME: Should use ExplicitMethod here? What's the difference? 313 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 314 } 315 316 return nil, nil 317 } 318 319 func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) { 320 var found types.Object 321 for i := 0; i < methodCount; i++ { 322 method := methodFunc(i) 323 if !method.Exported() || !strings.EqualFold(method.Name(), name) { 324 continue 325 } 326 327 if found != nil { 328 return nil, fmt.Errorf("found more than one matching method to bind for %s", name) 329 } 330 331 found = method 332 } 333 334 return found, nil 335 } 336 337 func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { 338 switch t := in.(type) { 339 case *types.Named: 340 return b.findBindFieldTarget(t.Underlying(), name) 341 case *types.Struct: 342 var found types.Object 343 for i := 0; i < t.NumFields(); i++ { 344 field := t.Field(i) 345 if !field.Exported() || !equalFieldName(field.Name(), name) { 346 continue 347 } 348 349 if found != nil { 350 return nil, fmt.Errorf("found more than one matching field to bind for %s", name) 351 } 352 353 found = field 354 } 355 356 return found, nil 357 } 358 359 return nil, nil 360 } 361 362 func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { 363 switch t := in.(type) { 364 case *types.Named: 365 return b.findBindEmbedsTarget(t.Underlying(), name) 366 case *types.Struct: 367 return b.findBindStructEmbedsTarget(t, name) 368 case *types.Interface: 369 return b.findBindInterfaceEmbedsTarget(t, name) 370 } 371 372 return nil, nil 373 } 374 375 func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) { 376 var found types.Object 377 for i := 0; i < strukt.NumFields(); i++ { 378 field := strukt.Field(i) 379 if !field.Embedded() { 380 continue 381 } 382 383 fieldType := field.Type() 384 if ptr, ok := fieldType.(*types.Pointer); ok { 385 fieldType = ptr.Elem() 386 } 387 388 f, err := b.findBindTarget(fieldType, name) 389 if err != nil { 390 return nil, err 391 } 392 393 if f != nil && found != nil { 394 return nil, fmt.Errorf("found more than one way to bind for %s", name) 395 } 396 397 if f != nil { 398 found = f 399 } 400 } 401 402 return found, nil 403 } 404 405 func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) { 406 var found types.Object 407 for i := 0; i < iface.NumEmbeddeds(); i++ { 408 embeddedType := iface.EmbeddedType(i) 409 410 f, err := b.findBindTarget(embeddedType, name) 411 if err != nil { 412 return nil, err 413 } 414 415 if f != nil && found != nil { 416 return nil, fmt.Errorf("found more than one way to bind for %s", name) 417 } 418 419 if f != nil { 420 found = f 421 } 422 } 423 424 return found, nil 425 } 426 427 func (f *Field) HasDirectives() bool { 428 return len(f.ImplDirectives()) > 0 429 } 430 431 func (f *Field) DirectiveObjName() string { 432 if f.Object.Root { 433 return "nil" 434 } 435 return f.GoReceiverName 436 } 437 438 func (f *Field) ImplDirectives() []*Directive { 439 var d []*Directive 440 loc := ast.LocationFieldDefinition 441 if f.Object.IsInputType() { 442 loc = ast.LocationInputFieldDefinition 443 } 444 for i := range f.Directives { 445 if !f.Directives[i].Builtin && 446 (f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) { 447 d = append(d, f.Directives[i]) 448 } 449 } 450 return d 451 } 452 453 func (f *Field) IsReserved() bool { 454 return strings.HasPrefix(f.Name, "__") 455 } 456 457 func (f *Field) IsMethod() bool { 458 return f.GoFieldType == GoFieldMethod 459 } 460 461 func (f *Field) IsVariable() bool { 462 return f.GoFieldType == GoFieldVariable 463 } 464 465 func (f *Field) IsMap() bool { 466 return f.GoFieldType == GoFieldMap 467 } 468 469 func (f *Field) IsConcurrent() bool { 470 if f.Object.DisableConcurrency { 471 return false 472 } 473 return f.MethodHasContext || f.IsResolver 474 } 475 476 func (f *Field) GoNameUnexported() string { 477 return templates.ToGoPrivate(f.Name) 478 } 479 480 func (f *Field) ShortInvocation() string { 481 caser := cases.Title(language.English, cases.NoLower) 482 if f.Object.Kind == ast.InputObject { 483 return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName) 484 } 485 return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs()) 486 } 487 488 func (f *Field) ArgsFunc() string { 489 if len(f.Args) == 0 { 490 return "" 491 } 492 493 return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args" 494 } 495 496 func (f *Field) FieldContextFunc() string { 497 return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name 498 } 499 500 func (f *Field) ChildFieldContextFunc(name string) string { 501 return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name 502 } 503 504 func (f *Field) ResolverType() string { 505 if !f.IsResolver { 506 return "" 507 } 508 509 return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) 510 } 511 512 func (f *Field) IsInputObject() bool { 513 return f.Object.Kind == ast.InputObject 514 } 515 516 func (f *Field) IsRoot() bool { 517 return f.Object.Root 518 } 519 520 func (f *Field) ShortResolverDeclaration() string { 521 return f.ShortResolverSignature(nil) 522 } 523 524 // ShortResolverSignature is identical to ShortResolverDeclaration, 525 // but respects previous naming (return) conventions, if any. 526 func (f *Field) ShortResolverSignature(ft *goast.FuncType) string { 527 if f.Object.Kind == ast.InputObject { 528 return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error", 529 templates.CurrentImports.LookupType(f.Object.Reference()), 530 templates.CurrentImports.LookupType(f.TypeReference.GO), 531 ) 532 } 533 534 res := "(ctx context.Context" 535 536 if !f.Object.Root { 537 res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference())) 538 } 539 for _, arg := range f.Args { 540 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 541 } 542 543 result := templates.CurrentImports.LookupType(f.TypeReference.GO) 544 if f.Object.Stream { 545 result = "<-chan " + result 546 } 547 // Named return. 548 var namedV, namedE string 549 if ft != nil { 550 if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 { 551 namedV = ft.Results.List[0].Names[0].Name 552 } 553 if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 { 554 namedE = ft.Results.List[1].Names[0].Name 555 } 556 } 557 res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE) 558 return res 559 } 560 561 func (f *Field) GoResultName() (string, bool) { 562 name := fmt.Sprintf("%v", f.TypeReference.GO) 563 splits := strings.Split(name, "/") 564 565 return splits[len(splits)-1], strings.HasPrefix(name, "[]") 566 } 567 568 func (f *Field) ComplexitySignature() string { 569 res := "func(childComplexity int" 570 for _, arg := range f.Args { 571 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 572 } 573 res += ") int" 574 return res 575 } 576 577 func (f *Field) ComplexityArgs() string { 578 args := make([]string, len(f.Args)) 579 for i, arg := range f.Args { 580 args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 581 } 582 583 return strings.Join(args, ", ") 584 } 585 586 func (f *Field) CallArgs() string { 587 args := make([]string, 0, len(f.Args)+2) 588 589 if f.IsResolver { 590 args = append(args, "rctx") 591 592 if !f.Object.Root { 593 args = append(args, "obj") 594 } 595 } else if f.MethodHasContext { 596 args = append(args, "ctx") 597 } 598 599 for _, arg := range f.Args { 600 tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 601 602 if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() { 603 tmp = fmt.Sprintf(` 604 func () interface{} { 605 if fc.Args["%s"] == nil { 606 return nil 607 } 608 return fc.Args["%s"].(interface{}) 609 }()`, arg.Name, arg.Name, 610 ) 611 } 612 613 args = append(args, tmp) 614 } 615 616 return strings.Join(args, ", ") 617 }