github.com/shippio/gqlgen@v0.0.0-20220912092219-633ea699ef07/codegen/field.go (about) 1 package codegen 2 3 import ( 4 "errors" 5 "fmt" 6 "go/types" 7 "log" 8 "reflect" 9 "strconv" 10 "strings" 11 12 "github.com/99designs/gqlgen/codegen/config" 13 "github.com/99designs/gqlgen/codegen/templates" 14 "github.com/vektah/gqlparser/v2/ast" 15 "golang.org/x/text/cases" 16 "golang.org/x/text/language" 17 ) 18 19 type Field struct { 20 *ast.FieldDefinition 21 22 TypeReference *config.TypeReference 23 GoFieldType GoFieldType // The field type in go, if any 24 GoReceiverName string // The name of method & var receiver in go, if any 25 GoFieldName string // The name of the method or var in go, if any 26 IsResolver bool // Does this field need a resolver 27 Args []*FieldArgument // A list of arguments to be passed to this field 28 MethodHasContext bool // If this is bound to a go method, does the method also take a context 29 NoErr bool // If this is bound to a go method, does that method have an error as the second argument 30 VOkFunc bool // If this is bound to a go method, is it of shape (interface{}, bool) 31 Object *Object // A link back to the parent object 32 Default interface{} // The default value 33 Stream bool // does this field return a channel? 34 Directives []*Directive 35 } 36 37 func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { 38 dirs, err := b.getDirectives(field.Directives) 39 if err != nil { 40 return nil, err 41 } 42 43 f := Field{ 44 FieldDefinition: field, 45 Object: obj, 46 Directives: dirs, 47 GoFieldName: templates.ToGo(field.Name), 48 GoFieldType: GoFieldVariable, 49 GoReceiverName: "obj", 50 } 51 52 if field.DefaultValue != nil { 53 var err error 54 f.Default, err = field.DefaultValue.Value(nil) 55 if err != nil { 56 return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err) 57 } 58 } 59 60 for _, arg := range field.Arguments { 61 newArg, err := b.buildArg(obj, arg) 62 if err != nil { 63 return nil, err 64 } 65 f.Args = append(f.Args, newArg) 66 } 67 68 if err = b.bindField(obj, &f); err != nil { 69 f.IsResolver = true 70 if errors.Is(err, config.ErrTypeNotFound) { 71 return nil, err 72 } 73 log.Println(err.Error()) 74 } 75 76 if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() { 77 f.TypeReference = b.Binder.PointerTo(f.TypeReference) 78 } 79 80 return &f, nil 81 } 82 83 func (b *builder) bindField(obj *Object, f *Field) (errret error) { 84 defer func() { 85 if f.TypeReference == nil { 86 tr, err := b.Binder.TypeReference(f.Type, nil) 87 if err != nil { 88 errret = err 89 } 90 f.TypeReference = tr 91 } 92 if f.TypeReference != nil { 93 dirs, err := b.getDirectives(f.TypeReference.Definition.Directives) 94 if err != nil { 95 errret = err 96 } 97 for _, dir := range obj.Directives { 98 if dir.IsLocation(ast.LocationInputObject) { 99 dirs = append(dirs, dir) 100 } 101 } 102 f.Directives = append(dirs, f.Directives...) 103 } 104 }() 105 106 f.Stream = obj.Stream 107 108 switch { 109 case f.Name == "__schema": 110 f.GoFieldType = GoFieldMethod 111 f.GoReceiverName = "ec" 112 f.GoFieldName = "introspectSchema" 113 return nil 114 case f.Name == "__type": 115 f.GoFieldType = GoFieldMethod 116 f.GoReceiverName = "ec" 117 f.GoFieldName = "introspectType" 118 return nil 119 case f.Name == "_entities": 120 f.GoFieldType = GoFieldMethod 121 f.GoReceiverName = "ec" 122 f.GoFieldName = "__resolve_entities" 123 f.MethodHasContext = true 124 f.NoErr = true 125 return nil 126 case f.Name == "_service": 127 f.GoFieldType = GoFieldMethod 128 f.GoReceiverName = "ec" 129 f.GoFieldName = "__resolve__service" 130 f.MethodHasContext = true 131 return nil 132 case obj.Root: 133 f.IsResolver = true 134 return nil 135 case b.Config.Models[obj.Name].Fields[f.Name].Resolver: 136 f.IsResolver = true 137 return nil 138 case obj.Type == config.MapType: 139 f.GoFieldType = GoFieldMap 140 return nil 141 case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "": 142 f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName 143 } 144 145 target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName) 146 if err != nil { 147 return err 148 } 149 150 pos := b.Binder.ObjectPosition(target) 151 152 switch target := target.(type) { 153 case nil: 154 objPos := b.Binder.TypePosition(obj.Type) 155 return fmt.Errorf( 156 "%s:%d adding resolver method for %s.%s, nothing matched", 157 objPos.Filename, 158 objPos.Line, 159 obj.Name, 160 f.Name, 161 ) 162 163 case *types.Func: 164 sig := target.Type().(*types.Signature) 165 if sig.Results().Len() == 1 { 166 f.NoErr = true 167 } else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" { 168 f.VOkFunc = true 169 } else if sig.Results().Len() != 2 { 170 return fmt.Errorf("method has wrong number of args") 171 } 172 params := sig.Params() 173 // If the first argument is the context, remove it from the comparison and set 174 // the MethodHasContext flag so that the context will be passed to this model's method 175 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { 176 f.MethodHasContext = true 177 vars := make([]*types.Var, params.Len()-1) 178 for i := 1; i < params.Len(); i++ { 179 vars[i-1] = params.At(i) 180 } 181 params = types.NewTuple(vars...) 182 } 183 184 // Try to match target function's arguments with GraphQL field arguments. 185 newArgs, err := b.bindArgs(f, sig, params) 186 if err != nil { 187 return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err) 188 } 189 190 // Try to match target function's return types with GraphQL field return type 191 result := sig.Results().At(0) 192 tr, err := b.Binder.TypeReference(f.Type, result.Type()) 193 if err != nil { 194 return err 195 } 196 197 // success, args and return type match. Bind to method 198 f.GoFieldType = GoFieldMethod 199 f.GoReceiverName = "obj" 200 f.GoFieldName = target.Name() 201 f.Args = newArgs 202 f.TypeReference = tr 203 204 return nil 205 206 case *types.Var: 207 tr, err := b.Binder.TypeReference(f.Type, target.Type()) 208 if err != nil { 209 return err 210 } 211 212 // success, bind to var 213 f.GoFieldType = GoFieldVariable 214 f.GoReceiverName = "obj" 215 f.GoFieldName = target.Name() 216 f.TypeReference = tr 217 218 return nil 219 default: 220 panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name)) 221 } 222 } 223 224 // findBindTarget attempts to match the name to a field or method on a Type 225 // with the following priorites: 226 // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found 227 // 2. Any method or field with a matching name. Errors if more than one match is found 228 // 3. Same logic again for embedded fields 229 func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) { 230 // NOTE: a struct tag will override both methods and fields 231 // Bind to struct tag 232 found, err := b.findBindStructTagTarget(t, name) 233 if found != nil || err != nil { 234 return found, err 235 } 236 237 // Search for a method to bind to 238 foundMethod, err := b.findBindMethodTarget(t, name) 239 if err != nil { 240 return nil, err 241 } 242 243 // Search for a field to bind to 244 foundField, err := b.findBindFieldTarget(t, name) 245 if err != nil { 246 return nil, err 247 } 248 249 switch { 250 case foundField == nil && foundMethod != nil: 251 // Bind to method 252 return foundMethod, nil 253 case foundField != nil && foundMethod == nil: 254 // Bind to field 255 return foundField, nil 256 case foundField != nil && foundMethod != nil: 257 // Error 258 return nil, fmt.Errorf("found more than one way to bind for %s", name) 259 } 260 261 // Search embeds 262 return b.findBindEmbedsTarget(t, name) 263 } 264 265 func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { 266 if b.Config.StructTag == "" { 267 return nil, nil 268 } 269 270 switch t := in.(type) { 271 case *types.Named: 272 return b.findBindStructTagTarget(t.Underlying(), name) 273 case *types.Struct: 274 var found types.Object 275 for i := 0; i < t.NumFields(); i++ { 276 field := t.Field(i) 277 if !field.Exported() || field.Embedded() { 278 continue 279 } 280 tags := reflect.StructTag(t.Tag(i)) 281 if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { 282 if found != nil { 283 return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) 284 } 285 286 found = field 287 } 288 } 289 290 return found, nil 291 } 292 293 return nil, nil 294 } 295 296 func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { 297 switch t := in.(type) { 298 case *types.Named: 299 if _, ok := t.Underlying().(*types.Interface); ok { 300 return b.findBindMethodTarget(t.Underlying(), name) 301 } 302 303 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 304 case *types.Interface: 305 // FIX-ME: Should use ExplicitMethod here? What's the difference? 306 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 307 } 308 309 return nil, nil 310 } 311 312 func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) { 313 var found types.Object 314 for i := 0; i < methodCount; i++ { 315 method := methodFunc(i) 316 if !method.Exported() || !strings.EqualFold(method.Name(), name) { 317 continue 318 } 319 320 if found != nil { 321 return nil, fmt.Errorf("found more than one matching method to bind for %s", name) 322 } 323 324 found = method 325 } 326 327 return found, nil 328 } 329 330 func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { 331 switch t := in.(type) { 332 case *types.Named: 333 return b.findBindFieldTarget(t.Underlying(), name) 334 case *types.Struct: 335 var found types.Object 336 for i := 0; i < t.NumFields(); i++ { 337 field := t.Field(i) 338 if !field.Exported() || !equalFieldName(field.Name(), name) { 339 continue 340 } 341 342 if found != nil { 343 return nil, fmt.Errorf("found more than one matching field to bind for %s", name) 344 } 345 346 found = field 347 } 348 349 return found, nil 350 } 351 352 return nil, nil 353 } 354 355 func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { 356 switch t := in.(type) { 357 case *types.Named: 358 return b.findBindEmbedsTarget(t.Underlying(), name) 359 case *types.Struct: 360 return b.findBindStructEmbedsTarget(t, name) 361 case *types.Interface: 362 return b.findBindInterfaceEmbedsTarget(t, name) 363 } 364 365 return nil, nil 366 } 367 368 func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) { 369 var found types.Object 370 for i := 0; i < strukt.NumFields(); i++ { 371 field := strukt.Field(i) 372 if !field.Embedded() { 373 continue 374 } 375 376 fieldType := field.Type() 377 if ptr, ok := fieldType.(*types.Pointer); ok { 378 fieldType = ptr.Elem() 379 } 380 381 f, err := b.findBindTarget(fieldType, name) 382 if err != nil { 383 return nil, err 384 } 385 386 if f != nil && found != nil { 387 return nil, fmt.Errorf("found more than one way to bind for %s", name) 388 } 389 390 if f != nil { 391 found = f 392 } 393 } 394 395 return found, nil 396 } 397 398 func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) { 399 var found types.Object 400 for i := 0; i < iface.NumEmbeddeds(); i++ { 401 embeddedType := iface.EmbeddedType(i) 402 403 f, err := b.findBindTarget(embeddedType, name) 404 if err != nil { 405 return nil, err 406 } 407 408 if f != nil && found != nil { 409 return nil, fmt.Errorf("found more than one way to bind for %s", name) 410 } 411 412 if f != nil { 413 found = f 414 } 415 } 416 417 return found, nil 418 } 419 420 func (f *Field) HasDirectives() bool { 421 return len(f.ImplDirectives()) > 0 422 } 423 424 func (f *Field) DirectiveObjName() string { 425 if f.Object.Root { 426 return "nil" 427 } 428 return f.GoReceiverName 429 } 430 431 func (f *Field) ImplDirectives() []*Directive { 432 var d []*Directive 433 loc := ast.LocationFieldDefinition 434 if f.Object.IsInputType() { 435 loc = ast.LocationInputFieldDefinition 436 } 437 for i := range f.Directives { 438 if !f.Directives[i].Builtin && 439 (f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) { 440 d = append(d, f.Directives[i]) 441 } 442 } 443 return d 444 } 445 446 func (f *Field) IsReserved() bool { 447 return strings.HasPrefix(f.Name, "__") 448 } 449 450 func (f *Field) IsMethod() bool { 451 return f.GoFieldType == GoFieldMethod 452 } 453 454 func (f *Field) IsVariable() bool { 455 return f.GoFieldType == GoFieldVariable 456 } 457 458 func (f *Field) IsMap() bool { 459 return f.GoFieldType == GoFieldMap 460 } 461 462 func (f *Field) IsConcurrent() bool { 463 if f.Object.DisableConcurrency { 464 return false 465 } 466 return f.MethodHasContext || f.IsResolver 467 } 468 469 func (f *Field) GoNameUnexported() string { 470 return templates.ToGoPrivate(f.Name) 471 } 472 473 func (f *Field) ShortInvocation() string { 474 caser := cases.Title(language.English, cases.NoLower) 475 if f.Object.Kind == ast.InputObject { 476 return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName) 477 } 478 return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs()) 479 } 480 481 func (f *Field) ArgsFunc() string { 482 if len(f.Args) == 0 { 483 return "" 484 } 485 486 return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args" 487 } 488 489 func (f *Field) FieldContextFunc() string { 490 return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name 491 } 492 493 func (f *Field) ChildFieldContextFunc(name string) string { 494 return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name 495 } 496 497 func (f *Field) ResolverType() string { 498 if !f.IsResolver { 499 return "" 500 } 501 502 return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) 503 } 504 505 func (f *Field) ShortResolverDeclaration() string { 506 if f.Object.Kind == ast.InputObject { 507 return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error", 508 templates.CurrentImports.LookupType(f.Object.Reference()), 509 templates.CurrentImports.LookupType(f.TypeReference.GO), 510 ) 511 } 512 513 res := "(ctx context.Context" 514 515 if !f.Object.Root { 516 res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference())) 517 } 518 for _, arg := range f.Args { 519 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 520 } 521 522 result := templates.CurrentImports.LookupType(f.TypeReference.GO) 523 if f.Object.Stream { 524 result = "<-chan " + result 525 } 526 527 res += fmt.Sprintf(") (%s, error)", result) 528 return res 529 } 530 531 func (f *Field) ComplexitySignature() string { 532 res := "func(childComplexity int" 533 for _, arg := range f.Args { 534 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 535 } 536 res += ") int" 537 return res 538 } 539 540 func (f *Field) ComplexityArgs() string { 541 args := make([]string, len(f.Args)) 542 for i, arg := range f.Args { 543 args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 544 } 545 546 return strings.Join(args, ", ") 547 } 548 549 func (f *Field) CallArgs() string { 550 args := make([]string, 0, len(f.Args)+2) 551 552 if f.IsResolver { 553 args = append(args, "rctx") 554 555 if !f.Object.Root { 556 args = append(args, "obj") 557 } 558 } else if f.MethodHasContext { 559 args = append(args, "ctx") 560 } 561 562 for _, arg := range f.Args { 563 tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 564 565 if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() { 566 tmp = fmt.Sprintf(` 567 func () interface{} { 568 if fc.Args["%s"] == nil { 569 return nil 570 } 571 return fc.Args["%s"].(interface{}) 572 }()`, arg.Name, arg.Name, 573 ) 574 } 575 576 args = append(args, tmp) 577 } 578 579 return strings.Join(args, ", ") 580 }