github.com/kerryoscer/gqlgen@v0.17.29/codegen/field.go (about) 1 package codegen 2 3 import ( 4 "errors" 5 "fmt" 6 goast "go/ast" 7 "go/types" 8 "log" 9 "reflect" 10 "strconv" 11 "strings" 12 13 "github.com/kerryoscer/gqlgen/codegen/config" 14 "github.com/kerryoscer/gqlgen/codegen/templates" 15 "github.com/vektah/gqlparser/v2/ast" 16 "golang.org/x/text/cases" 17 "golang.org/x/text/language" 18 ) 19 20 type Field struct { 21 *ast.FieldDefinition 22 23 TypeReference *config.TypeReference 24 GoFieldType GoFieldType // The field type in go, if any 25 GoReceiverName string // The name of method & var receiver in go, if any 26 GoFieldName string // The name of the method or var in go, if any 27 IsResolver bool // Does this field need a resolver 28 Args []*FieldArgument // A list of arguments to be passed to this field 29 MethodHasContext bool // If this is bound to a go method, does the method also take a context 30 NoErr bool // If this is bound to a go method, does that method have an error as the second argument 31 VOkFunc bool // If this is bound to a go method, is it of shape (interface{}, bool) 32 Object *Object // A link back to the parent object 33 Default interface{} // The default value 34 Stream bool // does this field return a channel? 35 Directives []*Directive 36 } 37 38 func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { 39 dirs, err := b.getDirectives(field.Directives) 40 if err != nil { 41 return nil, err 42 } 43 44 f := Field{ 45 FieldDefinition: field, 46 Object: obj, 47 Directives: dirs, 48 GoFieldName: templates.ToGo(field.Name), 49 GoFieldType: GoFieldVariable, 50 GoReceiverName: "obj", 51 } 52 53 if field.DefaultValue != nil { 54 var err error 55 f.Default, err = field.DefaultValue.Value(nil) 56 if err != nil { 57 return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err) 58 } 59 } 60 61 for _, arg := range field.Arguments { 62 newArg, err := b.buildArg(obj, arg) 63 if err != nil { 64 return nil, err 65 } 66 f.Args = append(f.Args, newArg) 67 } 68 69 if err = b.bindField(obj, &f); err != nil { 70 f.IsResolver = true 71 if errors.Is(err, config.ErrTypeNotFound) { 72 return nil, err 73 } 74 log.Println(err.Error()) 75 } 76 77 if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() { 78 f.TypeReference = b.Binder.PointerTo(f.TypeReference) 79 } 80 81 return &f, nil 82 } 83 84 func (b *builder) bindField(obj *Object, f *Field) (errret error) { 85 defer func() { 86 if f.TypeReference == nil { 87 tr, err := b.Binder.TypeReference(f.Type, nil) 88 if err != nil { 89 errret = err 90 } 91 f.TypeReference = tr 92 } 93 if f.TypeReference != nil { 94 dirs, err := b.getDirectives(f.TypeReference.Definition.Directives) 95 if err != nil { 96 errret = err 97 } 98 for _, dir := range obj.Directives { 99 if dir.IsLocation(ast.LocationInputObject) { 100 dirs = append(dirs, dir) 101 } 102 } 103 f.Directives = append(dirs, f.Directives...) 104 } 105 }() 106 107 f.Stream = obj.Stream 108 109 switch { 110 case f.Name == "__schema": 111 f.GoFieldType = GoFieldMethod 112 f.GoReceiverName = "ec" 113 f.GoFieldName = "introspectSchema" 114 return nil 115 case f.Name == "__type": 116 f.GoFieldType = GoFieldMethod 117 f.GoReceiverName = "ec" 118 f.GoFieldName = "introspectType" 119 return nil 120 case f.Name == "_entities": 121 f.GoFieldType = GoFieldMethod 122 f.GoReceiverName = "ec" 123 f.GoFieldName = "__resolve_entities" 124 f.MethodHasContext = true 125 f.NoErr = true 126 return nil 127 case f.Name == "_service": 128 f.GoFieldType = GoFieldMethod 129 f.GoReceiverName = "ec" 130 f.GoFieldName = "__resolve__service" 131 f.MethodHasContext = true 132 return nil 133 case obj.Root: 134 f.IsResolver = true 135 return nil 136 case b.Config.Models[obj.Name].Fields[f.Name].Resolver: 137 f.IsResolver = true 138 return nil 139 case obj.Type == config.MapType: 140 f.GoFieldType = GoFieldMap 141 return nil 142 case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "": 143 f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName 144 } 145 146 target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName) 147 if err != nil { 148 return err 149 } 150 151 pos := b.Binder.ObjectPosition(target) 152 153 switch target := target.(type) { 154 case nil: 155 objPos := b.Binder.TypePosition(obj.Type) 156 return fmt.Errorf( 157 "%s:%d adding resolver method for %s.%s, nothing matched", 158 objPos.Filename, 159 objPos.Line, 160 obj.Name, 161 f.Name, 162 ) 163 164 case *types.Func: 165 sig := target.Type().(*types.Signature) 166 if sig.Results().Len() == 1 { 167 f.NoErr = true 168 } else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" { 169 f.VOkFunc = true 170 } else if sig.Results().Len() != 2 { 171 return fmt.Errorf("method has wrong number of args") 172 } 173 params := sig.Params() 174 // If the first argument is the context, remove it from the comparison and set 175 // the MethodHasContext flag so that the context will be passed to this model's method 176 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { 177 f.MethodHasContext = true 178 vars := make([]*types.Var, params.Len()-1) 179 for i := 1; i < params.Len(); i++ { 180 vars[i-1] = params.At(i) 181 } 182 params = types.NewTuple(vars...) 183 } 184 185 // Try to match target function's arguments with GraphQL field arguments. 186 newArgs, err := b.bindArgs(f, sig, params) 187 if err != nil { 188 return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err) 189 } 190 191 // Try to match target function's return types with GraphQL field return type 192 result := sig.Results().At(0) 193 tr, err := b.Binder.TypeReference(f.Type, result.Type()) 194 if err != nil { 195 return err 196 } 197 198 // success, args and return type match. Bind to method 199 f.GoFieldType = GoFieldMethod 200 f.GoReceiverName = "obj" 201 f.GoFieldName = target.Name() 202 f.Args = newArgs 203 f.TypeReference = tr 204 205 return nil 206 207 case *types.Var: 208 tr, err := b.Binder.TypeReference(f.Type, target.Type()) 209 if err != nil { 210 return err 211 } 212 213 // success, bind to var 214 f.GoFieldType = GoFieldVariable 215 f.GoReceiverName = "obj" 216 f.GoFieldName = target.Name() 217 f.TypeReference = tr 218 219 return nil 220 default: 221 panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name)) 222 } 223 } 224 225 // findBindTarget attempts to match the name to a field or method on a Type 226 // with the following priorites: 227 // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found 228 // 2. Any method or field with a matching name. Errors if more than one match is found 229 // 3. Same logic again for embedded fields 230 func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) { 231 // NOTE: a struct tag will override both methods and fields 232 // Bind to struct tag 233 found, err := b.findBindStructTagTarget(t, name) 234 if found != nil || err != nil { 235 return found, err 236 } 237 238 // Search for a method to bind to 239 foundMethod, err := b.findBindMethodTarget(t, name) 240 if err != nil { 241 return nil, err 242 } 243 244 // Search for a field to bind to 245 foundField, err := b.findBindFieldTarget(t, name) 246 if err != nil { 247 return nil, err 248 } 249 250 switch { 251 case foundField == nil && foundMethod != nil: 252 // Bind to method 253 return foundMethod, nil 254 case foundField != nil && foundMethod == nil: 255 // Bind to field 256 return foundField, nil 257 case foundField != nil && foundMethod != nil: 258 // Error 259 return nil, fmt.Errorf("found more than one way to bind for %s", name) 260 } 261 262 // Search embeds 263 return b.findBindEmbedsTarget(t, name) 264 } 265 266 func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { 267 if b.Config.StructTag == "" { 268 return nil, nil 269 } 270 271 switch t := in.(type) { 272 case *types.Named: 273 return b.findBindStructTagTarget(t.Underlying(), name) 274 case *types.Struct: 275 var found types.Object 276 for i := 0; i < t.NumFields(); i++ { 277 field := t.Field(i) 278 if !field.Exported() || field.Embedded() { 279 continue 280 } 281 tags := reflect.StructTag(t.Tag(i)) 282 if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { 283 if found != nil { 284 return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) 285 } 286 287 found = field 288 } 289 } 290 291 return found, nil 292 } 293 294 return nil, nil 295 } 296 297 func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { 298 switch t := in.(type) { 299 case *types.Named: 300 if _, ok := t.Underlying().(*types.Interface); ok { 301 return b.findBindMethodTarget(t.Underlying(), name) 302 } 303 304 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 305 case *types.Interface: 306 // FIX-ME: Should use ExplicitMethod here? What's the difference? 307 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 308 } 309 310 return nil, nil 311 } 312 313 func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) { 314 var found types.Object 315 for i := 0; i < methodCount; i++ { 316 method := methodFunc(i) 317 if !method.Exported() || !strings.EqualFold(method.Name(), name) { 318 continue 319 } 320 321 if found != nil { 322 return nil, fmt.Errorf("found more than one matching method to bind for %s", name) 323 } 324 325 found = method 326 } 327 328 return found, nil 329 } 330 331 func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { 332 switch t := in.(type) { 333 case *types.Named: 334 return b.findBindFieldTarget(t.Underlying(), name) 335 case *types.Struct: 336 var found types.Object 337 for i := 0; i < t.NumFields(); i++ { 338 field := t.Field(i) 339 if !field.Exported() || !equalFieldName(field.Name(), name) { 340 continue 341 } 342 343 if found != nil { 344 return nil, fmt.Errorf("found more than one matching field to bind for %s", name) 345 } 346 347 found = field 348 } 349 350 return found, nil 351 } 352 353 return nil, nil 354 } 355 356 func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { 357 switch t := in.(type) { 358 case *types.Named: 359 return b.findBindEmbedsTarget(t.Underlying(), name) 360 case *types.Struct: 361 return b.findBindStructEmbedsTarget(t, name) 362 case *types.Interface: 363 return b.findBindInterfaceEmbedsTarget(t, name) 364 } 365 366 return nil, nil 367 } 368 369 func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) { 370 var found types.Object 371 for i := 0; i < strukt.NumFields(); i++ { 372 field := strukt.Field(i) 373 if !field.Embedded() { 374 continue 375 } 376 377 fieldType := field.Type() 378 if ptr, ok := fieldType.(*types.Pointer); ok { 379 fieldType = ptr.Elem() 380 } 381 382 f, err := b.findBindTarget(fieldType, name) 383 if err != nil { 384 return nil, err 385 } 386 387 if f != nil && found != nil { 388 return nil, fmt.Errorf("found more than one way to bind for %s", name) 389 } 390 391 if f != nil { 392 found = f 393 } 394 } 395 396 return found, nil 397 } 398 399 func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) { 400 var found types.Object 401 for i := 0; i < iface.NumEmbeddeds(); i++ { 402 embeddedType := iface.EmbeddedType(i) 403 404 f, err := b.findBindTarget(embeddedType, name) 405 if err != nil { 406 return nil, err 407 } 408 409 if f != nil && found != nil { 410 return nil, fmt.Errorf("found more than one way to bind for %s", name) 411 } 412 413 if f != nil { 414 found = f 415 } 416 } 417 418 return found, nil 419 } 420 421 func (f *Field) HasDirectives() bool { 422 return len(f.ImplDirectives()) > 0 423 } 424 425 func (f *Field) DirectiveObjName() string { 426 if f.Object.Root { 427 return "nil" 428 } 429 return f.GoReceiverName 430 } 431 432 func (f *Field) ImplDirectives() []*Directive { 433 var d []*Directive 434 loc := ast.LocationFieldDefinition 435 if f.Object.IsInputType() { 436 loc = ast.LocationInputFieldDefinition 437 } 438 for i := range f.Directives { 439 if !f.Directives[i].Builtin && 440 (f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) { 441 d = append(d, f.Directives[i]) 442 } 443 } 444 return d 445 } 446 447 func (f *Field) IsReserved() bool { 448 return strings.HasPrefix(f.Name, "__") 449 } 450 451 func (f *Field) IsMethod() bool { 452 return f.GoFieldType == GoFieldMethod 453 } 454 455 func (f *Field) IsVariable() bool { 456 return f.GoFieldType == GoFieldVariable 457 } 458 459 func (f *Field) IsMap() bool { 460 return f.GoFieldType == GoFieldMap 461 } 462 463 func (f *Field) IsConcurrent() bool { 464 if f.Object.DisableConcurrency { 465 return false 466 } 467 return f.MethodHasContext || f.IsResolver 468 } 469 470 func (f *Field) GoNameUnexported() string { 471 return templates.ToGoPrivate(f.Name) 472 } 473 474 func (f *Field) ShortInvocation() string { 475 caser := cases.Title(language.English, cases.NoLower) 476 if f.Object.Kind == ast.InputObject { 477 return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName) 478 } 479 return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs()) 480 } 481 482 func (f *Field) ArgsFunc() string { 483 if len(f.Args) == 0 { 484 return "" 485 } 486 487 return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args" 488 } 489 490 func (f *Field) FieldContextFunc() string { 491 return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name 492 } 493 494 func (f *Field) ChildFieldContextFunc(name string) string { 495 return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name 496 } 497 498 func (f *Field) ResolverType() string { 499 if !f.IsResolver { 500 return "" 501 } 502 503 return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) 504 } 505 506 func (f *Field) IsInputObject() bool { 507 return f.Object.Kind == ast.InputObject 508 } 509 510 func (f *Field) IsRoot() bool { 511 return f.Object.Root 512 } 513 514 func (f *Field) ShortResolverDeclaration() string { 515 return f.ShortResolverSignature(nil) 516 } 517 518 // ShortResolverSignature is identical to ShortResolverDeclaration, 519 // but respects previous naming (return) conventions, if any. 520 func (f *Field) ShortResolverSignature(ft *goast.FuncType) string { 521 if f.Object.Kind == ast.InputObject { 522 return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error", 523 templates.CurrentImports.LookupType(f.Object.Reference()), 524 templates.CurrentImports.LookupType(f.TypeReference.GO), 525 ) 526 } 527 528 res := "(ctx context.Context" 529 530 if !f.Object.Root { 531 res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference())) 532 } 533 for _, arg := range f.Args { 534 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 535 } 536 537 result := templates.CurrentImports.LookupType(f.TypeReference.GO) 538 if f.Object.Stream { 539 result = "<-chan " + result 540 } 541 // Named return. 542 var namedV, namedE string 543 if ft != nil { 544 if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 { 545 namedV = ft.Results.List[0].Names[0].Name 546 } 547 if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 { 548 namedE = ft.Results.List[1].Names[0].Name 549 } 550 } 551 res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE) 552 return res 553 } 554 555 func (f *Field) GoResultName() (string, bool) { 556 name := fmt.Sprintf("%v", f.TypeReference.GO) 557 splits := strings.Split(name, "/") 558 559 return splits[len(splits)-1], strings.HasPrefix(name, "[]") 560 } 561 562 func (f *Field) ComplexitySignature() string { 563 res := "func(childComplexity int" 564 for _, arg := range f.Args { 565 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 566 } 567 res += ") int" 568 return res 569 } 570 571 func (f *Field) ComplexityArgs() string { 572 args := make([]string, len(f.Args)) 573 for i, arg := range f.Args { 574 args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 575 } 576 577 return strings.Join(args, ", ") 578 } 579 580 func (f *Field) CallArgs() string { 581 args := make([]string, 0, len(f.Args)+2) 582 583 if f.IsResolver { 584 args = append(args, "rctx") 585 586 if !f.Object.Root { 587 args = append(args, "obj") 588 } 589 } else if f.MethodHasContext { 590 args = append(args, "ctx") 591 } 592 593 for _, arg := range f.Args { 594 tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 595 596 if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() { 597 tmp = fmt.Sprintf(` 598 func () interface{} { 599 if fc.Args["%s"] == nil { 600 return nil 601 } 602 return fc.Args["%s"].(interface{}) 603 }()`, arg.Name, arg.Name, 604 ) 605 } 606 607 args = append(args, tmp) 608 } 609 610 return strings.Join(args, ", ") 611 }