git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/codegen/field.go (about) 1 package codegen 2 3 import ( 4 "fmt" 5 "go/types" 6 "log" 7 "reflect" 8 "strconv" 9 "strings" 10 11 "git.sr.ht/~sircmpwn/gqlgen/codegen/config" 12 "git.sr.ht/~sircmpwn/gqlgen/codegen/templates" 13 "github.com/pkg/errors" 14 "github.com/vektah/gqlparser/v2/ast" 15 ) 16 17 type Field struct { 18 *ast.FieldDefinition 19 20 TypeReference *config.TypeReference 21 GoFieldType GoFieldType // The field type in go, if any 22 GoReceiverName string // The name of method & var receiver in go, if any 23 GoFieldName string // The name of the method or var in go, if any 24 IsResolver bool // Does this field need a resolver 25 Args []*FieldArgument // A list of arguments to be passed to this field 26 MethodHasContext bool // If this is bound to a go method, does the method also take a context 27 NoErr bool // If this is bound to a go method, does that method have an error as the second argument 28 Object *Object // A link back to the parent object 29 Default interface{} // The default value 30 Stream bool // does this field return a channel? 31 Directives []*Directive 32 } 33 34 func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { 35 dirs, err := b.getDirectives(field.Directives) 36 if err != nil { 37 return nil, err 38 } 39 40 f := Field{ 41 FieldDefinition: field, 42 Object: obj, 43 Directives: dirs, 44 GoFieldName: templates.ToGo(field.Name), 45 GoFieldType: GoFieldVariable, 46 GoReceiverName: "obj", 47 } 48 49 if field.DefaultValue != nil { 50 var err error 51 f.Default, err = field.DefaultValue.Value(nil) 52 if err != nil { 53 return nil, errors.Errorf("default value %s is not valid: %s", field.Name, err.Error()) 54 } 55 } 56 57 for _, arg := range field.Arguments { 58 newArg, err := b.buildArg(obj, arg) 59 if err != nil { 60 return nil, err 61 } 62 f.Args = append(f.Args, newArg) 63 } 64 65 if err = b.bindField(obj, &f); err != nil { 66 f.IsResolver = true 67 log.Println(err.Error()) 68 } 69 70 if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() { 71 f.TypeReference = b.Binder.PointerTo(f.TypeReference) 72 } 73 74 return &f, nil 75 } 76 77 func (b *builder) bindField(obj *Object, f *Field) (errret error) { 78 defer func() { 79 if f.TypeReference == nil { 80 tr, err := b.Binder.TypeReference(f.Type, nil) 81 if err != nil { 82 errret = err 83 } 84 f.TypeReference = tr 85 } 86 }() 87 88 f.Stream = obj.Stream 89 90 switch { 91 case f.Name == "__schema": 92 f.GoFieldType = GoFieldMethod 93 f.GoReceiverName = "ec" 94 f.GoFieldName = "introspectSchema" 95 return nil 96 case f.Name == "__type": 97 f.GoFieldType = GoFieldMethod 98 f.GoReceiverName = "ec" 99 f.GoFieldName = "introspectType" 100 return nil 101 case f.Name == "_entities": 102 f.GoFieldType = GoFieldMethod 103 f.GoReceiverName = "ec" 104 f.GoFieldName = "__resolve_entities" 105 f.MethodHasContext = true 106 return nil 107 case f.Name == "_service": 108 f.GoFieldType = GoFieldMethod 109 f.GoReceiverName = "ec" 110 f.GoFieldName = "__resolve__service" 111 f.MethodHasContext = true 112 return nil 113 case obj.Root: 114 f.IsResolver = true 115 return nil 116 case b.Config.Models[obj.Name].Fields[f.Name].Resolver: 117 f.IsResolver = true 118 return nil 119 case obj.Type == config.MapType: 120 f.GoFieldType = GoFieldMap 121 return nil 122 case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "": 123 f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName 124 } 125 126 target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName) 127 if err != nil { 128 return err 129 } 130 131 pos := b.Binder.ObjectPosition(target) 132 133 switch target := target.(type) { 134 case nil: 135 objPos := b.Binder.TypePosition(obj.Type) 136 return fmt.Errorf( 137 "%s:%d adding resolver method for %s.%s, nothing matched", 138 objPos.Filename, 139 objPos.Line, 140 obj.Name, 141 f.Name, 142 ) 143 144 case *types.Func: 145 sig := target.Type().(*types.Signature) 146 if sig.Results().Len() == 1 { 147 f.NoErr = true 148 } else if sig.Results().Len() != 2 { 149 return fmt.Errorf("method has wrong number of args") 150 } 151 params := sig.Params() 152 // If the first argument is the context, remove it from the comparison and set 153 // the MethodHasContext flag so that the context will be passed to this model's method 154 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { 155 f.MethodHasContext = true 156 vars := make([]*types.Var, params.Len()-1) 157 for i := 1; i < params.Len(); i++ { 158 vars[i-1] = params.At(i) 159 } 160 params = types.NewTuple(vars...) 161 } 162 163 if err = b.bindArgs(f, params); err != nil { 164 return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line) 165 } 166 167 result := sig.Results().At(0) 168 tr, err := b.Binder.TypeReference(f.Type, result.Type()) 169 if err != nil { 170 return err 171 } 172 173 // success, args and return type match. Bind to method 174 f.GoFieldType = GoFieldMethod 175 f.GoReceiverName = "obj" 176 f.GoFieldName = target.Name() 177 f.TypeReference = tr 178 179 return nil 180 181 case *types.Var: 182 tr, err := b.Binder.TypeReference(f.Type, target.Type()) 183 if err != nil { 184 return err 185 } 186 187 // success, bind to var 188 f.GoFieldType = GoFieldVariable 189 f.GoReceiverName = "obj" 190 f.GoFieldName = target.Name() 191 f.TypeReference = tr 192 193 return nil 194 default: 195 panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name)) 196 } 197 } 198 199 // findBindTarget attempts to match the name to a field or method on a Type 200 // with the following priorites: 201 // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found 202 // 2. Any method or field with a matching name. Errors if more than one match is found 203 // 3. Same logic again for embedded fields 204 func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) { 205 // NOTE: a struct tag will override both methods and fields 206 // Bind to struct tag 207 found, err := b.findBindStructTagTarget(t, name) 208 if found != nil || err != nil { 209 return found, err 210 } 211 212 // Search for a method to bind to 213 foundMethod, err := b.findBindMethodTarget(t, name) 214 if err != nil { 215 return nil, err 216 } 217 218 // Search for a field to bind to 219 foundField, err := b.findBindFieldTarget(t, name) 220 if err != nil { 221 return nil, err 222 } 223 224 switch { 225 case foundField == nil && foundMethod != nil: 226 // Bind to method 227 return foundMethod, nil 228 case foundField != nil && foundMethod == nil: 229 // Bind to field 230 return foundField, nil 231 case foundField != nil && foundMethod != nil: 232 // Error 233 return nil, errors.Errorf("found more than one way to bind for %s", name) 234 } 235 236 // Search embeds 237 return b.findBindEmbedsTarget(t, name) 238 } 239 240 func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { 241 if b.Config.StructTag == "" { 242 return nil, nil 243 } 244 245 switch t := in.(type) { 246 case *types.Named: 247 return b.findBindStructTagTarget(t.Underlying(), name) 248 case *types.Struct: 249 var found types.Object 250 for i := 0; i < t.NumFields(); i++ { 251 field := t.Field(i) 252 if !field.Exported() || field.Embedded() { 253 continue 254 } 255 tags := reflect.StructTag(t.Tag(i)) 256 if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { 257 if found != nil { 258 return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) 259 } 260 261 found = field 262 } 263 } 264 265 return found, nil 266 } 267 268 return nil, nil 269 } 270 271 func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { 272 switch t := in.(type) { 273 case *types.Named: 274 if _, ok := t.Underlying().(*types.Interface); ok { 275 return b.findBindMethodTarget(t.Underlying(), name) 276 } 277 278 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 279 case *types.Interface: 280 // FIX-ME: Should use ExplicitMethod here? What's the difference? 281 return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) 282 } 283 284 return nil, nil 285 } 286 287 func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) { 288 var found types.Object 289 for i := 0; i < methodCount; i++ { 290 method := methodFunc(i) 291 if !method.Exported() || !strings.EqualFold(method.Name(), name) { 292 continue 293 } 294 295 if found != nil { 296 return nil, errors.Errorf("found more than one matching method to bind for %s", name) 297 } 298 299 found = method 300 } 301 302 return found, nil 303 } 304 305 func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { 306 switch t := in.(type) { 307 case *types.Named: 308 return b.findBindFieldTarget(t.Underlying(), name) 309 case *types.Struct: 310 var found types.Object 311 for i := 0; i < t.NumFields(); i++ { 312 field := t.Field(i) 313 if !field.Exported() || !equalFieldName(field.Name(), name) { 314 continue 315 } 316 317 if found != nil { 318 return nil, errors.Errorf("found more than one matching field to bind for %s", name) 319 } 320 321 found = field 322 } 323 324 return found, nil 325 } 326 327 return nil, nil 328 } 329 330 func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { 331 switch t := in.(type) { 332 case *types.Named: 333 return b.findBindEmbedsTarget(t.Underlying(), name) 334 case *types.Struct: 335 return b.findBindStructEmbedsTarget(t, name) 336 case *types.Interface: 337 return b.findBindInterfaceEmbedsTarget(t, name) 338 } 339 340 return nil, nil 341 } 342 343 func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) { 344 var found types.Object 345 for i := 0; i < strukt.NumFields(); i++ { 346 field := strukt.Field(i) 347 if !field.Embedded() { 348 continue 349 } 350 351 fieldType := field.Type() 352 if ptr, ok := fieldType.(*types.Pointer); ok { 353 fieldType = ptr.Elem() 354 } 355 356 f, err := b.findBindTarget(fieldType, name) 357 if err != nil { 358 return nil, err 359 } 360 361 if f != nil && found != nil { 362 return nil, errors.Errorf("found more than one way to bind for %s", name) 363 } 364 365 if f != nil { 366 found = f 367 } 368 } 369 370 return found, nil 371 } 372 373 func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) { 374 var found types.Object 375 for i := 0; i < iface.NumEmbeddeds(); i++ { 376 embeddedType := iface.EmbeddedType(i) 377 378 f, err := b.findBindTarget(embeddedType, name) 379 if err != nil { 380 return nil, err 381 } 382 383 if f != nil && found != nil { 384 return nil, errors.Errorf("found more than one way to bind for %s", name) 385 } 386 387 if f != nil { 388 found = f 389 } 390 } 391 392 return found, nil 393 } 394 395 func (f *Field) HasDirectives() bool { 396 return len(f.ImplDirectives()) > 0 397 } 398 399 func (f *Field) DirectiveObjName() string { 400 if f.Object.Root { 401 return "nil" 402 } 403 return f.GoReceiverName 404 } 405 406 func (f *Field) ImplDirectives() []*Directive { 407 var d []*Directive 408 loc := ast.LocationFieldDefinition 409 if f.Object.IsInputType() { 410 loc = ast.LocationInputFieldDefinition 411 } 412 for i := range f.Directives { 413 if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc) { 414 d = append(d, f.Directives[i]) 415 } 416 } 417 return d 418 } 419 420 func (f *Field) IsReserved() bool { 421 return strings.HasPrefix(f.Name, "__") 422 } 423 424 func (f *Field) IsMethod() bool { 425 return f.GoFieldType == GoFieldMethod 426 } 427 428 func (f *Field) IsVariable() bool { 429 return f.GoFieldType == GoFieldVariable 430 } 431 432 func (f *Field) IsMap() bool { 433 return f.GoFieldType == GoFieldMap 434 } 435 436 func (f *Field) IsConcurrent() bool { 437 if f.Object.DisableConcurrency { 438 return false 439 } 440 return f.MethodHasContext || f.IsResolver 441 } 442 443 func (f *Field) GoNameUnexported() string { 444 return templates.ToGoPrivate(f.Name) 445 } 446 447 func (f *Field) ShortInvocation() string { 448 return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) 449 } 450 451 func (f *Field) ArgsFunc() string { 452 if len(f.Args) == 0 { 453 return "" 454 } 455 456 return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args" 457 } 458 459 func (f *Field) ResolverType() string { 460 if !f.IsResolver { 461 return "" 462 } 463 464 return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) 465 } 466 467 func (f *Field) ShortResolverDeclaration() string { 468 res := "(ctx context.Context" 469 470 if !f.Object.Root { 471 res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference())) 472 } 473 for _, arg := range f.Args { 474 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 475 } 476 477 result := templates.CurrentImports.LookupType(f.TypeReference.GO) 478 if f.Object.Stream { 479 result = "<-chan " + result 480 } 481 482 res += fmt.Sprintf(") (%s, error)", result) 483 return res 484 } 485 486 func (f *Field) ComplexitySignature() string { 487 res := fmt.Sprintf("func(childComplexity int") 488 for _, arg := range f.Args { 489 res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) 490 } 491 res += ") int" 492 return res 493 } 494 495 func (f *Field) ComplexityArgs() string { 496 args := make([]string, len(f.Args)) 497 for i, arg := range f.Args { 498 args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" 499 } 500 501 return strings.Join(args, ", ") 502 } 503 504 func (f *Field) CallArgs() string { 505 args := make([]string, 0, len(f.Args)+2) 506 507 if f.IsResolver { 508 args = append(args, "rctx") 509 510 if !f.Object.Root { 511 args = append(args, "obj") 512 } 513 } else if f.MethodHasContext { 514 args = append(args, "ctx") 515 } 516 517 for _, arg := range f.Args { 518 args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")") 519 } 520 521 return strings.Join(args, ", ") 522 }