github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/codegen/config/binder.go (about) 1 package config 2 3 import ( 4 "errors" 5 "fmt" 6 "go/token" 7 "go/types" 8 9 "github.com/vektah/gqlparser/v2/ast" 10 "golang.org/x/tools/go/packages" 11 12 "github.com/niko0xdev/gqlgen/codegen/templates" 13 "github.com/niko0xdev/gqlgen/internal/code" 14 ) 15 16 var ErrTypeNotFound = errors.New("unable to find type") 17 18 // Binder connects graphql types to golang types using static analysis 19 type Binder struct { 20 pkgs *code.Packages 21 schema *ast.Schema 22 cfg *Config 23 tctx *types.Context 24 References []*TypeReference 25 SawInvalid bool 26 objectCache map[string]map[string]types.Object 27 } 28 29 func (c *Config) NewBinder() *Binder { 30 return &Binder{ 31 pkgs: c.Packages, 32 schema: c.Schema, 33 cfg: c, 34 } 35 } 36 37 func (b *Binder) TypePosition(typ types.Type) token.Position { 38 named, isNamed := typ.(*types.Named) 39 if !isNamed { 40 return token.Position{ 41 Filename: "unknown", 42 } 43 } 44 45 return b.ObjectPosition(named.Obj()) 46 } 47 48 func (b *Binder) ObjectPosition(typ types.Object) token.Position { 49 if typ == nil { 50 return token.Position{ 51 Filename: "unknown", 52 } 53 } 54 pkg := b.pkgs.Load(typ.Pkg().Path()) 55 return pkg.Fset.Position(typ.Pos()) 56 } 57 58 func (b *Binder) FindTypeFromName(name string) (types.Type, error) { 59 pkgName, typeName := code.PkgAndType(name) 60 return b.FindType(pkgName, typeName) 61 } 62 63 func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { 64 if pkgName == "" { 65 if typeName == "map[string]interface{}" { 66 return MapType, nil 67 } 68 69 if typeName == "interface{}" { 70 return InterfaceType, nil 71 } 72 } 73 74 obj, err := b.FindObject(pkgName, typeName) 75 if err != nil { 76 return nil, err 77 } 78 79 if fun, isFunc := obj.(*types.Func); isFunc { 80 return fun.Type().(*types.Signature).Params().At(0).Type(), nil 81 } 82 return obj.Type(), nil 83 } 84 85 func (b *Binder) InstantiateType(orig types.Type, targs []types.Type) (types.Type, error) { 86 if b.tctx == nil { 87 b.tctx = types.NewContext() 88 } 89 90 return types.Instantiate(b.tctx, orig, targs, false) 91 } 92 93 var ( 94 MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()) 95 InterfaceType = types.NewInterfaceType(nil, nil) 96 ) 97 98 func (b *Binder) DefaultUserObject(name string) (types.Type, error) { 99 models := b.cfg.Models[name].Model 100 if len(models) == 0 { 101 return nil, fmt.Errorf(name + " not found in typemap") 102 } 103 104 if models[0] == "map[string]interface{}" { 105 return MapType, nil 106 } 107 108 if models[0] == "interface{}" { 109 return InterfaceType, nil 110 } 111 112 pkgName, typeName := code.PkgAndType(models[0]) 113 if pkgName == "" { 114 return nil, fmt.Errorf("missing package name for %s", name) 115 } 116 117 obj, err := b.FindObject(pkgName, typeName) 118 if err != nil { 119 return nil, err 120 } 121 122 return obj.Type(), nil 123 } 124 125 func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) { 126 if pkgName == "" { 127 return nil, fmt.Errorf("package cannot be nil") 128 } 129 130 pkg := b.pkgs.LoadWithTypes(pkgName) 131 if pkg == nil { 132 err := b.pkgs.Errors() 133 if err != nil { 134 return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err) 135 } 136 return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName) 137 } 138 139 if b.objectCache == nil { 140 b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count()) 141 } 142 143 defsIndex, ok := b.objectCache[pkgName] 144 if !ok { 145 defsIndex = indexDefs(pkg) 146 b.objectCache[pkgName] = defsIndex 147 } 148 149 // function based marshalers take precedence 150 if val, ok := defsIndex["Marshal"+typeName]; ok { 151 return val, nil 152 } 153 154 if val, ok := defsIndex[typeName]; ok { 155 return val, nil 156 } 157 158 return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName) 159 } 160 161 func indexDefs(pkg *packages.Package) map[string]types.Object { 162 res := make(map[string]types.Object) 163 164 scope := pkg.Types.Scope() 165 for astNode, def := range pkg.TypesInfo.Defs { 166 // only look at defs in the top scope 167 if def == nil { 168 continue 169 } 170 parent := def.Parent() 171 if parent == nil || parent != scope { 172 continue 173 } 174 175 if _, ok := res[astNode.Name]; !ok { 176 // The above check may not be really needed, it is only here to have a consistent behavior with 177 // previous implementation of FindObject() function which only honored the first inclusion of a def. 178 // If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups. 179 res[astNode.Name] = def 180 } 181 } 182 183 return res 184 } 185 186 func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { 187 newRef := *ref 188 newRef.GO = types.NewPointer(ref.GO) 189 b.References = append(b.References, &newRef) 190 return &newRef 191 } 192 193 // TypeReference is used by args and field types. The Definition can refer to both input and output types. 194 type TypeReference struct { 195 Definition *ast.Definition 196 GQL *ast.Type 197 GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target. 198 Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields. 199 CastType types.Type // Before calling marshalling functions cast from/to this base type 200 Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function 201 Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function 202 IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler 203 IsOmittable bool // Is the type wrapped with Omittable 204 IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety. 205 PointersInUmarshalInput bool // Inverse values and pointers in return. 206 IsRoot bool // Is the type a root level definition such as Query, Mutation or Subscription 207 } 208 209 func (ref *TypeReference) Elem() *TypeReference { 210 if p, isPtr := ref.GO.(*types.Pointer); isPtr { 211 newRef := *ref 212 newRef.GO = p.Elem() 213 return &newRef 214 } 215 216 if ref.IsSlice() { 217 newRef := *ref 218 newRef.GO = ref.GO.(*types.Slice).Elem() 219 newRef.GQL = ref.GQL.Elem 220 return &newRef 221 } 222 return nil 223 } 224 225 func (ref *TypeReference) IsPtr() bool { 226 _, isPtr := ref.GO.(*types.Pointer) 227 return isPtr 228 } 229 230 // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful) 231 func (ref *TypeReference) IsPtrToPtr() bool { 232 if p, isPtr := ref.GO.(*types.Pointer); isPtr { 233 _, isPtr := p.Elem().(*types.Pointer) 234 return isPtr 235 } 236 return false 237 } 238 239 func (ref *TypeReference) IsNilable() bool { 240 return IsNilable(ref.GO) 241 } 242 243 func (ref *TypeReference) IsSlice() bool { 244 _, isSlice := ref.GO.(*types.Slice) 245 return ref.GQL.Elem != nil && isSlice 246 } 247 248 func (ref *TypeReference) IsPtrToSlice() bool { 249 if ref.IsPtr() { 250 _, isPointerToSlice := ref.GO.(*types.Pointer).Elem().(*types.Slice) 251 return isPointerToSlice 252 } 253 return false 254 } 255 256 func (ref *TypeReference) IsPtrToIntf() bool { 257 if ref.IsPtr() { 258 _, isPointerToInterface := ref.GO.(*types.Pointer).Elem().(*types.Interface) 259 return isPointerToInterface 260 } 261 return false 262 } 263 264 func (ref *TypeReference) IsNamed() bool { 265 _, isSlice := ref.GO.(*types.Named) 266 return isSlice 267 } 268 269 func (ref *TypeReference) IsStruct() bool { 270 _, isStruct := ref.GO.Underlying().(*types.Struct) 271 return isStruct 272 } 273 274 func (ref *TypeReference) IsScalar() bool { 275 return ref.Definition.Kind == ast.Scalar 276 } 277 278 func (ref *TypeReference) IsMap() bool { 279 return ref.GO == MapType 280 } 281 282 func (ref *TypeReference) UniquenessKey() string { 283 nullability := "O" 284 if ref.GQL.NonNull { 285 nullability = "N" 286 } 287 288 elemNullability := "" 289 if ref.GQL.Elem != nil && ref.GQL.Elem.NonNull { 290 // Fix for #896 291 elemNullability = "áš„" 292 } 293 return nullability + ref.Definition.Name + "2" + templates.TypeIdentifier(ref.GO) + elemNullability 294 } 295 296 func (ref *TypeReference) MarshalFunc() string { 297 if ref.Definition == nil { 298 panic(errors.New("Definition missing for " + ref.GQL.Name())) 299 } 300 301 if ref.Definition.Kind == ast.InputObject { 302 return "" 303 } 304 305 return "marshal" + ref.UniquenessKey() 306 } 307 308 func (ref *TypeReference) UnmarshalFunc() string { 309 if ref.Definition == nil { 310 panic(errors.New("Definition missing for " + ref.GQL.Name())) 311 } 312 313 if !ref.Definition.IsInputType() { 314 return "" 315 } 316 317 return "unmarshal" + ref.UniquenessKey() 318 } 319 320 func (ref *TypeReference) IsTargetNilable() bool { 321 return IsNilable(ref.Target) 322 } 323 324 func (b *Binder) PushRef(ret *TypeReference) { 325 b.References = append(b.References, ret) 326 } 327 328 func isMap(t types.Type) bool { 329 if t == nil { 330 return true 331 } 332 _, ok := t.(*types.Map) 333 return ok 334 } 335 336 func isIntf(t types.Type) bool { 337 if t == nil { 338 return true 339 } 340 _, ok := t.(*types.Interface) 341 return ok 342 } 343 344 func unwrapOmittable(t types.Type) (types.Type, bool) { 345 if t == nil { 346 return t, false 347 } 348 named, ok := t.(*types.Named) 349 if !ok { 350 return t, false 351 } 352 if named.Origin().String() != "github.com/niko0xdev/gqlgen/graphql.Omittable[T any]" { 353 return t, false 354 } 355 return named.TypeArgs().At(0), true 356 } 357 358 func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) { 359 if innerType, ok := unwrapOmittable(bindTarget); ok { 360 if schemaType.NonNull { 361 return nil, fmt.Errorf("%s is wrapped with Omittable but non-null", schemaType.Name()) 362 } 363 364 ref, err := b.TypeReference(schemaType, innerType) 365 if err != nil { 366 return nil, err 367 } 368 369 ref.IsOmittable = true 370 return ref, err 371 } 372 373 if !isValid(bindTarget) { 374 b.SawInvalid = true 375 return nil, fmt.Errorf("%s has an invalid type", schemaType.Name()) 376 } 377 378 var pkgName, typeName string 379 def := b.schema.Types[schemaType.Name()] 380 defer func() { 381 if err == nil && ret != nil { 382 b.PushRef(ret) 383 } 384 }() 385 386 if len(b.cfg.Models[schemaType.Name()].Model) == 0 { 387 return nil, fmt.Errorf("%s was not found", schemaType.Name()) 388 } 389 390 for _, model := range b.cfg.Models[schemaType.Name()].Model { 391 if model == "map[string]interface{}" { 392 if !isMap(bindTarget) { 393 continue 394 } 395 return &TypeReference{ 396 Definition: def, 397 GQL: schemaType, 398 GO: MapType, 399 IsRoot: b.cfg.IsRoot(def), 400 }, nil 401 } 402 403 if model == "interface{}" { 404 if !isIntf(bindTarget) { 405 continue 406 } 407 return &TypeReference{ 408 Definition: def, 409 GQL: schemaType, 410 GO: InterfaceType, 411 IsRoot: b.cfg.IsRoot(def), 412 }, nil 413 } 414 415 pkgName, typeName = code.PkgAndType(model) 416 if pkgName == "" { 417 return nil, fmt.Errorf("missing package name for %s", schemaType.Name()) 418 } 419 420 ref := &TypeReference{ 421 Definition: def, 422 GQL: schemaType, 423 IsRoot: b.cfg.IsRoot(def), 424 } 425 426 obj, err := b.FindObject(pkgName, typeName) 427 if err != nil { 428 return nil, err 429 } 430 431 if fun, isFunc := obj.(*types.Func); isFunc { 432 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() 433 ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/niko0xdev/gqlgen/graphql.ContextMarshaler" 434 ref.Marshaler = fun 435 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) 436 } else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") { 437 ref.GO = obj.Type() 438 ref.IsContext = true 439 ref.IsMarshaler = true 440 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { 441 ref.GO = obj.Type() 442 ref.IsMarshaler = true 443 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { 444 // TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595) 445 446 ref.GO = obj.Type() 447 ref.CastType = underlying 448 449 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) 450 if err != nil { 451 return nil, err 452 } 453 454 ref.Marshaler = underlyingRef.Marshaler 455 ref.Unmarshaler = underlyingRef.Unmarshaler 456 } else { 457 ref.GO = obj.Type() 458 } 459 460 ref.Target = ref.GO 461 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO) 462 463 if bindTarget != nil { 464 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil { 465 continue 466 } 467 ref.GO = bindTarget 468 } 469 470 ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput 471 472 return ref, nil 473 } 474 475 return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String()) 476 } 477 478 func isValid(t types.Type) bool { 479 basic, isBasic := t.(*types.Basic) 480 if !isBasic { 481 return true 482 } 483 return basic.Kind() != types.Invalid 484 } 485 486 func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { 487 if t.Elem != nil { 488 child := b.CopyModifiersFromAst(t.Elem, base) 489 if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers { 490 child = types.NewPointer(child) 491 } 492 return types.NewSlice(child) 493 } 494 495 var isInterface bool 496 if named, ok := base.(*types.Named); ok { 497 _, isInterface = named.Underlying().(*types.Interface) 498 } 499 500 if !isInterface && !IsNilable(base) && !t.NonNull { 501 return types.NewPointer(base) 502 } 503 504 return base 505 } 506 507 func IsNilable(t types.Type) bool { 508 if namedType, isNamed := t.(*types.Named); isNamed { 509 return IsNilable(namedType.Underlying()) 510 } 511 _, isPtr := t.(*types.Pointer) 512 _, isMap := t.(*types.Map) 513 _, isInterface := t.(*types.Interface) 514 _, isSlice := t.(*types.Slice) 515 _, isChan := t.(*types.Chan) 516 return isPtr || isMap || isInterface || isSlice || isChan 517 } 518 519 func hasMethod(it types.Type, name string) bool { 520 if ptr, isPtr := it.(*types.Pointer); isPtr { 521 it = ptr.Elem() 522 } 523 namedType, ok := it.(*types.Named) 524 if !ok { 525 return false 526 } 527 528 for i := 0; i < namedType.NumMethods(); i++ { 529 if namedType.Method(i).Name() == name { 530 return true 531 } 532 } 533 return false 534 } 535 536 func basicUnderlying(it types.Type) *types.Basic { 537 if ptr, isPtr := it.(*types.Pointer); isPtr { 538 it = ptr.Elem() 539 } 540 namedType, ok := it.(*types.Named) 541 if !ok { 542 return nil 543 } 544 545 if basic, ok := namedType.Underlying().(*types.Basic); ok { 546 return basic 547 } 548 549 return nil 550 }