github.com/operandinc/gqlgen@v0.16.1/codegen/field.go (about) 1 package codegen 2 3 import ( 4 "errors" 5 "fmt" 6 "go/types" 7 "log" 8 "reflect" 9 "strconv" 10 "strings" 11 12 "github.com/operandinc/gqlgen/codegen/config" 13 "github.com/operandinc/gqlgen/codegen/templates" 14 "github.com/vektah/gqlparser/v2/ast" 15 ) 16 17 type Field struct { 18 *ast.FieldDefinition 19 20 TypeReference *config.TypeReference 21 GoFieldType GoFieldType // The field type in go, if any 22 GoReceiverName string // The name of method & var receiver in go, if any 23 GoFieldName string // The name of the method or var in go, if any 24 IsResolver bool // Does this field need a resolver 25 Args []*FieldArgument // A list of arguments to be passed to this field 26 MethodHasContext bool // If this is bound to a go method, does the method also take a context 27 NoErr bool // If this is bound to a go method, does that method have an error as the second argument 28 VOkFunc bool // If this is bound to a go method, is it of shape (interface{}, bool) 29 Object *Object // A link back to the parent object 30 Default interface{} // The default value 31 Stream bool // does this field return a channel? 32 Directives []*Directive 33 } 34 35 func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { 36 dirs, err := b.getDirectives(field.Directives) 37 if err != nil { 38 return nil, err 39 } 40 41 f := Field{ 42 FieldDefinition: field, 43 Object: obj, 44 Directives: dirs, 45 GoFieldName: templates.ToGo(field.Name), 46 GoFieldType: GoFieldVariable, 47 GoReceiverName: "obj", 48 } 49 50 if field.DefaultValue != nil { 51 var err error 52 f.Default, err = field.DefaultValue.Value(nil) 53 if err != nil { 54 return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err) 55 } 56 } 57 58 for _, arg := range field.Arguments { 59 newArg, err := b.buildArg(obj, arg) 60 if err != nil { 61 return nil, err 62 } 63 f.Args = append(f.Args, newArg) 64 } 65 66 if err = b.bindField(obj, &f); err != nil { 67 f.IsResolver = true 68 if errors.Is(err, config.ErrTypeNotFound) { 69 return nil, err 70 } 71 log.Println(err.Error()) 72 } 73 74 if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() { 75 f.TypeReference = b.Binder.PointerTo(f.TypeReference) 76 } 77 78 return &f, nil 79 } 80 81 func (b *builder) bindField(obj *Object, f *Field) (errret error) { 82 defer func() { 83 if f.TypeReference == nil { 84 tr, err := b.Binder.TypeReference(f.Type, nil) 85 if err != nil { 86 errret = err 87 } 88 f.TypeReference = tr 89 } 90 if f.TypeReference != nil { 91 dirs, err := b.getDirectives(f.TypeReference.Definition.Directives) 92 if err != nil { 93 errret = err 94 } 95 for _, dir := range obj.Directives { 96 if dir.IsLocation(ast.LocationInputObject) { 97 dirs = append(dirs, dir) 98 } 99 } 100 f.Directives = append(dirs, f.Directives...) 101 } 102 }() 103 104 f.Stream = obj.Stream 105 106 switch { 107 case f.Name == "__schema": 108 f.GoFieldType = GoFieldMethod 109 f.GoReceiverName = "ec" 110 f.GoFieldName = "introspectSchema" 111 return nil 112 case f.Name == "__type": 113 f.GoFieldType = GoFieldMethod 114 f.GoReceiverName = "ec" 115 f.GoFieldName = "introspectType" 116 return nil 117 case f.Name == "_entities": 118 f.GoFieldType = GoFieldMethod 119 f.GoReceiverName = "ec" 120 f.GoFieldName = "__resolve_entities" 121 f.MethodHasContext = true 122 f.NoErr = true 123 return nil 124 case f.Name == "_service": 125 f.GoFieldType = GoFieldMethod 126 f.GoReceiverName = "ec" 127 f.GoFieldName = "__resolve__service" 128 f.MethodHasContext = true 129 return nil 130 case obj.Root: 131 f.IsResolver = true 132 return nil 133 case b.Config.Models[obj.Name].Fields[f.Name].Resolver: 134 f.IsResolver = true 135 return nil 136 case obj.Type == config.MapType: 137 f.GoFieldType = GoFieldMap 138 return nil 139 case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "": 140 f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName 141 } 142 143 target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName) 144 if err != nil { 145 return err 146 } 147 148 pos := b.Binder.ObjectPosition(target) 149 150 switch target := target.(type) { 151 case nil: 152 objPos := b.Binder.TypePosition(obj.Type) 153 return fmt.Errorf( 154 "%s:%d adding resolver method for %s.%s, nothing matched", 155 objPos.Filename, 156 objPos.Line, 157 obj.Name, 158 f.Name, 159 ) 160 161 case *types.Func: 162 sig := target.Type().(*types.Signature) 163 if sig.Results().Len() == 1 { 164 f.NoErr = true 165 } else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" { 166 f.VOkFunc = true 167 } else if sig.Results().Len() != 2 { 168 return fmt.Errorf("method has wrong number of args") 169 } 170 params := sig.Params() 171 // If the first argument is the context, remove it from the comparison and set 172 // the MethodHasContext flag so that the context will be passed to this model's method 173 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { 174 f.MethodHasContext = true 175 vars := make([]*types.Var, params.Len()-1) 176 for i := 1; i < params.Len(); i++ { 177 vars[i-1] = params.At(i) 178 } 179 params = types.NewTuple(vars...) 180 } 181 182 // Try to match target function's arguments with GraphQL field arguments 183 newArgs, err := b.bindArgs(f, params) 184 if err != nil { 185 return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err) 186 } 187 188 // Try to match target function's return types with GraphQL field return type 189 result := sig.Results().At(0) 190 tr, err := b.Binder.TypeReference(f.Type, result.Type()) 191 if err != nil { 192 return err 193 } 194 195 // success, args and return type match. Bind to method 196 f.GoFieldType = GoFieldMethod 197 f.GoReceiverName = "obj" 198 f.GoFieldName = target.Name() 199 f.Args = newArgs 200 f.TypeReference = tr 201 202 return nil 203 204 case *types.Var: 205 tr, err := b.Binder.TypeReference(f.Type, target.Type()) 206 if err != nil { 207 return err 208 } 209 210 // success, bind to var 211 f.GoFieldType = GoFieldVariable 212 f.GoReceiverName = "obj" 213 f.GoFieldName = target.Name() 214 f.TypeReference = tr 215 216 return nil 217 default: 218 panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name)) 219 } 220 } 221 222 // findBindTarget attempts to match the name to a field or method on a Type 223 // with the following priorites: 224 // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found 225 // 2. Any method or field with a matching name. Errors if more than one match is found 226 // 3. Same logic again for embedded fields 227 func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) { 228 // NOTE: a struct tag will override both methods and fields 229 // Bind to struct tag 230 found, err := b.findBindStructTagTarget(t, name) 231 if found != nil || err != nil { 232 return found, err 233 } 234 235 // Search for a method to bind to 236 foundMethod, err := b.findBindMethodTarget(t, name) 237 if err != nil { 238 return nil, err 239 } 240 241 // Search for a field to bind to 242 foundField, err := b.findBindFieldTarget(t, name) 243 if err != nil { 244 return nil, err 245 } 246 247 switch { 248 case foundField == nil && foundMethod != nil: 249 // Bind to method 250 return foundMethod, nil 251 case foundField != nil && foundMethod == nil: 252 // Bind to field 253 return foundField, nil 254 case foundField != nil && foundMethod != nil: 255 // Error 256 return nil, fmt.Errorf("found more than one way to bind for %s", name) 257 } 258 259 // Search embeds 260 return b.findBindEmbedsTarget(t, name) 261 } 262 263 func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { 264 if b.Config.StructTag == "" { 265 return nil, nil 266 } 267 268 switch t := in.(type) { 269 case *types.Named: 270 return b.findBindStructTagTarget(t.Underlying(), name) 271 case *types.Struct: 272 var found types.Object 273 for i := 0; i < t.NumFields(); i++ { 274 field := t.Field(i) 275 if !field.Exported() || field.Embedded() { 276 continue 277 } 278 tags := reflect.StructTag(t.Tag(i)) 279 if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { 280 if found != nil { 281 return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) 282 } 283 284 found = field 285 } 286 } 287 288 return found, nil 289 } 290 291 return nil, nil 292 } 293 294 func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { 295 switch t := in.(type) { 296 case *types.Named: 297 if _, ok := t.Underlying().(*types.Interface); ok { 298 return b.findBindMethodTarget(t.Underlying(), name) 299 } 300 301 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 302 case *types.Interface: 303 // FIX-ME: Should use ExplicitMethod here? What's the difference? 304 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 305 } 306 307 return nil, nil 308 } 309 310 func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) { 311 var found types.Object 312 for i := 0; i < methodCount; i++ { 313 method := methodFunc(i) 314 if !method.Exported() || !strings.EqualFold(method.Name(), name) { 315 continue 316 } 317 318 if found != nil { 319 return nil, fmt.Errorf("found more than one matching method to bind for %s", name) 320 } 321 322 found = method 323 } 324 325 return found, nil 326 } 327 328 func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { 329 switch t := in.(type) { 330 case *types.Named: 331 return b.findBindFieldTarget(t.Underlying(), name) 332 case *types.Struct: 333 var found types.Object 334 for i := 0; i < t.NumFields(); i++ { 335 field := t.Field(i) 336 if !field.Exported() || !equalFieldName(field.Name(), name) { 337 continue 338 } 339 340 if found != nil { 341 return nil, fmt.Errorf("found more than one matching field to bind for %s", name) 342 } 343 344 found = field 345 } 346 347 return found, nil 348 } 349 350 return nil, nil 351 } 352 353 func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { 354 switch t := in.(type) { 355 case *types.Named: 356 return b.findBindEmbedsTarget(t.Underlying(), name) 357 case *types.Struct: 358 return b.findBindStructEmbedsTarget(t, name) 359 case *types.Interface: 360 return b.findBindInterfaceEmbedsTarget(t, name) 361 } 362 363 return nil, nil 364 } 365 366 func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) { 367 var found types.Object 368 for i := 0; i < strukt.NumFields(); i++ { 369 field := strukt.Field(i) 370 if !field.Embedded() { 371 continue 372 } 373 374 fieldType := field.Type() 375 if ptr, ok := fieldType.(*types.Pointer); ok { 376 fieldType = ptr.Elem() 377 } 378 379 f, err := b.findBindTarget(fieldType, name) 380 if err != nil { 381 return nil, err 382 } 383 384 if f != nil && found != nil { 385 return nil, fmt.Errorf("found more than one way to bind for %s", name) 386 } 387 388 if f != nil { 389 found = f 390 } 391 } 392 393 return found, nil 394 } 395 396 func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) { 397 var found types.Object 398 for i := 0; i < iface.NumEmbeddeds(); i++ { 399 embeddedType := iface.EmbeddedType(i) 400 401 f, err := b.findBindTarget(embeddedType, name) 402 if err != nil { 403 return nil, err 404 } 405 406 if f != nil && found != nil { 407 return nil, fmt.Errorf("found more than one way to bind for %s", name) 408 } 409 410 if f != nil { 411 found = f 412 } 413 } 414 415 return found, nil 416 } 417 418 func (f *Field) HasDirectives() bool { 419 return len(f.ImplDirectives()) > 0 420 } 421 422 func (f *Field) DirectiveObjName() string { 423 if f.Object.Root { 424 return "nil" 425 } 426 return f.GoReceiverName 427 } 428 429 func (f *Field) ImplDirectives() []*Directive { 430 var d []*Directive 431 loc := ast.LocationFieldDefinition 432 if f.Object.IsInputType() { 433 loc = ast.LocationInputFieldDefinition 434 } 435 for i := range f.Directives { 436 if !f.Directives[i].Builtin && 437 (f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) { 438 d = append(d, f.Directives[i]) 439 } 440 } 441 return d 442 } 443 444 func (f *Field) IsReserved() bool { 445 return strings.HasPrefix(f.Name, "__") 446 } 447 448 func (f *Field) IsMethod() bool { 449 return f.GoFieldType == GoFieldMethod 450 } 451 452 func (f *Field) IsVariable() bool { 453 return f.GoFieldType == GoFieldVariable 454 } 455 456 func (f *Field) IsMap() bool { 457 return f.GoFieldType == GoFieldMap 458 } 459 460 func (f *Field) IsConcurrent() bool { 461 if f.Object.DisableConcurrency { 462 return false 463 } 464 return f.MethodHasContext || f.IsResolver 465 } 466 467 func (f *Field) GoNameUnexported() string { 468 return templates.ToGoPrivate(f.Name) 469 } 470 471 func (f *Field) ShortInvocation() string { 472 if f.Object.Kind == ast.InputObject { 473 return fmt.Sprintf("%s().%s(ctx, &it, data)", strings.Title(f.Object.Definition.Name), f.GoFieldName) 474 } 475 return fmt.Sprintf("%s().%s(%s)", strings.Title(f.Object.Definition.Name), f.GoFieldName, f.CallArgs()) 476 } 477 478 func (f *Field) ArgsFunc() string { 479 if len(f.Args) == 0 { 480 return "" 481 } 482 483 return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args" 484 } 485 486 func (f *Field) ResolverType() string { 487 if !f.IsResolver { 488 return "" 489 } 490 491 return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) 492 } 493 494 func (f *Field) ShortResolverDeclaration() string { 495 if f.Object.Kind == ast.InputObject { 496 return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error", 497 templates.CurrentImports.LookupType(f.Object.Reference()), 498 templates.CurrentImports.LookupType(f.TypeReference.GO), 499 ) 500 } 501 502 res := "(ctx context.Context" 503 504 if !f.Object.Root { 505 res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference())) 506 } 507 for _, arg := range f.Args { 508 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 509 } 510 511 result := templates.CurrentImports.LookupType(f.TypeReference.GO) 512 if f.Object.Stream { 513 result = "<-chan " + result 514 } 515 516 res += fmt.Sprintf(") (%s, error)", result) 517 return res 518 } 519 520 func (f *Field) ComplexitySignature() string { 521 res := "func(childComplexity int" 522 for _, arg := range f.Args { 523 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 524 } 525 res += ") int" 526 return res 527 } 528 529 func (f *Field) ComplexityArgs() string { 530 args := make([]string, len(f.Args)) 531 for i, arg := range f.Args { 532 args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 533 } 534 535 return strings.Join(args, ", ") 536 } 537 538 func (f *Field) CallArgs() string { 539 args := make([]string, 0, len(f.Args)+2) 540 541 if f.IsResolver { 542 args = append(args, "rctx") 543 544 if !f.Object.Root { 545 args = append(args, "obj") 546 } 547 } else if f.MethodHasContext { 548 args = append(args, "ctx") 549 } 550 551 for _, arg := range f.Args { 552 args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")") 553 } 554 555 return strings.Join(args, ", ") 556 }