github.com/kerryoscer/gqlgen@v0.17.29/codegen/config/binder.go (about) 1 package config 2 3 import ( 4 "errors" 5 "fmt" 6 "go/token" 7 "go/types" 8 9 "golang.org/x/tools/go/packages" 10 11 "github.com/kerryoscer/gqlgen/codegen/templates" 12 "github.com/kerryoscer/gqlgen/internal/code" 13 "github.com/vektah/gqlparser/v2/ast" 14 ) 15 16 var ErrTypeNotFound = errors.New("unable to find type") 17 18 // Binder connects graphql types to golang types using static analysis 19 type Binder struct { 20 pkgs *code.Packages 21 schema *ast.Schema 22 cfg *Config 23 References []*TypeReference 24 SawInvalid bool 25 objectCache map[string]map[string]types.Object 26 } 27 28 func (c *Config) NewBinder() *Binder { 29 return &Binder{ 30 pkgs: c.Packages, 31 schema: c.Schema, 32 cfg: c, 33 } 34 } 35 36 func (b *Binder) TypePosition(typ types.Type) token.Position { 37 named, isNamed := typ.(*types.Named) 38 if !isNamed { 39 return token.Position{ 40 Filename: "unknown", 41 } 42 } 43 44 return b.ObjectPosition(named.Obj()) 45 } 46 47 func (b *Binder) ObjectPosition(typ types.Object) token.Position { 48 if typ == nil { 49 return token.Position{ 50 Filename: "unknown", 51 } 52 } 53 pkg := b.pkgs.Load(typ.Pkg().Path()) 54 return pkg.Fset.Position(typ.Pos()) 55 } 56 57 func (b *Binder) FindTypeFromName(name string) (types.Type, error) { 58 pkgName, typeName := code.PkgAndType(name) 59 return b.FindType(pkgName, typeName) 60 } 61 62 func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { 63 if pkgName == "" { 64 if typeName == "map[string]interface{}" { 65 return MapType, nil 66 } 67 68 if typeName == "interface{}" { 69 return InterfaceType, nil 70 } 71 } 72 73 obj, err := b.FindObject(pkgName, typeName) 74 if err != nil { 75 return nil, err 76 } 77 78 if fun, isFunc := obj.(*types.Func); isFunc { 79 return fun.Type().(*types.Signature).Params().At(0).Type(), nil 80 } 81 return obj.Type(), nil 82 } 83 84 var ( 85 MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()) 86 InterfaceType = types.NewInterfaceType(nil, nil) 87 ) 88 89 func (b *Binder) DefaultUserObject(name string) (types.Type, error) { 90 models := b.cfg.Models[name].Model 91 if len(models) == 0 { 92 return nil, fmt.Errorf(name + " not found in typemap") 93 } 94 95 if models[0] == "map[string]interface{}" { 96 return MapType, nil 97 } 98 99 if models[0] == "interface{}" { 100 return InterfaceType, nil 101 } 102 103 pkgName, typeName := code.PkgAndType(models[0]) 104 if pkgName == "" { 105 return nil, fmt.Errorf("missing package name for %s", name) 106 } 107 108 obj, err := b.FindObject(pkgName, typeName) 109 if err != nil { 110 return nil, err 111 } 112 113 return obj.Type(), nil 114 } 115 116 func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) { 117 if pkgName == "" { 118 return nil, fmt.Errorf("package cannot be nil") 119 } 120 121 pkg := b.pkgs.LoadWithTypes(pkgName) 122 if pkg == nil { 123 err := b.pkgs.Errors() 124 if err != nil { 125 return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err) 126 } 127 return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName) 128 } 129 130 if b.objectCache == nil { 131 b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count()) 132 } 133 134 defsIndex, ok := b.objectCache[pkgName] 135 if !ok { 136 defsIndex = indexDefs(pkg) 137 b.objectCache[pkgName] = defsIndex 138 } 139 140 // function based marshalers take precedence 141 if val, ok := defsIndex["Marshal"+typeName]; ok { 142 return val, nil 143 } 144 145 if val, ok := defsIndex[typeName]; ok { 146 return val, nil 147 } 148 149 return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName) 150 } 151 152 func indexDefs(pkg *packages.Package) map[string]types.Object { 153 res := make(map[string]types.Object) 154 155 scope := pkg.Types.Scope() 156 for astNode, def := range pkg.TypesInfo.Defs { 157 // only look at defs in the top scope 158 if def == nil { 159 continue 160 } 161 parent := def.Parent() 162 if parent == nil || parent != scope { 163 continue 164 } 165 166 if _, ok := res[astNode.Name]; !ok { 167 // The above check may not be really needed, it is only here to have a consistent behavior with 168 // previous implementation of FindObject() function which only honored the first inclusion of a def. 169 // If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups. 170 res[astNode.Name] = def 171 } 172 } 173 174 return res 175 } 176 177 func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { 178 newRef := *ref 179 newRef.GO = types.NewPointer(ref.GO) 180 b.References = append(b.References, &newRef) 181 return &newRef 182 } 183 184 // TypeReference is used by args and field types. The Definition can refer to both input and output types. 185 type TypeReference struct { 186 Definition *ast.Definition 187 GQL *ast.Type 188 GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target. 189 Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields. 190 CastType types.Type // Before calling marshalling functions cast from/to this base type 191 Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function 192 Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function 193 IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler 194 IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety. 195 PointersInUmarshalInput bool // Inverse values and pointers in return. 196 } 197 198 func (ref *TypeReference) Elem() *TypeReference { 199 if p, isPtr := ref.GO.(*types.Pointer); isPtr { 200 newRef := *ref 201 newRef.GO = p.Elem() 202 return &newRef 203 } 204 205 if ref.IsSlice() { 206 newRef := *ref 207 newRef.GO = ref.GO.(*types.Slice).Elem() 208 newRef.GQL = ref.GQL.Elem 209 return &newRef 210 } 211 return nil 212 } 213 214 func (t *TypeReference) IsPtr() bool { 215 _, isPtr := t.GO.(*types.Pointer) 216 return isPtr 217 } 218 219 // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful) 220 func (t *TypeReference) IsPtrToPtr() bool { 221 if p, isPtr := t.GO.(*types.Pointer); isPtr { 222 _, isPtr := p.Elem().(*types.Pointer) 223 return isPtr 224 } 225 return false 226 } 227 228 func (t *TypeReference) IsNilable() bool { 229 return IsNilable(t.GO) 230 } 231 232 func (t *TypeReference) IsSlice() bool { 233 _, isSlice := t.GO.(*types.Slice) 234 return t.GQL.Elem != nil && isSlice 235 } 236 237 func (t *TypeReference) IsPtrToSlice() bool { 238 if t.IsPtr() { 239 _, isPointerToSlice := t.GO.(*types.Pointer).Elem().(*types.Slice) 240 return isPointerToSlice 241 } 242 return false 243 } 244 245 func (t *TypeReference) IsNamed() bool { 246 _, isSlice := t.GO.(*types.Named) 247 return isSlice 248 } 249 250 func (t *TypeReference) IsStruct() bool { 251 _, isStruct := t.GO.Underlying().(*types.Struct) 252 return isStruct 253 } 254 255 func (t *TypeReference) IsScalar() bool { 256 return t.Definition.Kind == ast.Scalar 257 } 258 259 func (t *TypeReference) UniquenessKey() string { 260 nullability := "O" 261 if t.GQL.NonNull { 262 nullability = "N" 263 } 264 265 elemNullability := "" 266 if t.GQL.Elem != nil && t.GQL.Elem.NonNull { 267 // Fix for #896 268 elemNullability = "áš„" 269 } 270 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability 271 } 272 273 func (t *TypeReference) MarshalFunc() string { 274 if t.Definition == nil { 275 panic(errors.New("Definition missing for " + t.GQL.Name())) 276 } 277 278 if t.Definition.Kind == ast.InputObject { 279 return "" 280 } 281 282 return "marshal" + t.UniquenessKey() 283 } 284 285 func (t *TypeReference) UnmarshalFunc() string { 286 if t.Definition == nil { 287 panic(errors.New("Definition missing for " + t.GQL.Name())) 288 } 289 290 if !t.Definition.IsInputType() { 291 return "" 292 } 293 294 return "unmarshal" + t.UniquenessKey() 295 } 296 297 func (t *TypeReference) IsTargetNilable() bool { 298 return IsNilable(t.Target) 299 } 300 301 func (b *Binder) PushRef(ret *TypeReference) { 302 b.References = append(b.References, ret) 303 } 304 305 func isMap(t types.Type) bool { 306 if t == nil { 307 return true 308 } 309 _, ok := t.(*types.Map) 310 return ok 311 } 312 313 func isIntf(t types.Type) bool { 314 if t == nil { 315 return true 316 } 317 _, ok := t.(*types.Interface) 318 return ok 319 } 320 321 func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) { 322 if !isValid(bindTarget) { 323 b.SawInvalid = true 324 return nil, fmt.Errorf("%s has an invalid type", schemaType.Name()) 325 } 326 327 var pkgName, typeName string 328 def := b.schema.Types[schemaType.Name()] 329 defer func() { 330 if err == nil && ret != nil { 331 b.PushRef(ret) 332 } 333 }() 334 335 if len(b.cfg.Models[schemaType.Name()].Model) == 0 { 336 return nil, fmt.Errorf("%s was not found", schemaType.Name()) 337 } 338 339 for _, model := range b.cfg.Models[schemaType.Name()].Model { 340 if model == "map[string]interface{}" { 341 if !isMap(bindTarget) { 342 continue 343 } 344 return &TypeReference{ 345 Definition: def, 346 GQL: schemaType, 347 GO: MapType, 348 }, nil 349 } 350 351 if model == "interface{}" { 352 if !isIntf(bindTarget) { 353 continue 354 } 355 return &TypeReference{ 356 Definition: def, 357 GQL: schemaType, 358 GO: InterfaceType, 359 }, nil 360 } 361 362 pkgName, typeName = code.PkgAndType(model) 363 if pkgName == "" { 364 return nil, fmt.Errorf("missing package name for %s", schemaType.Name()) 365 } 366 367 ref := &TypeReference{ 368 Definition: def, 369 GQL: schemaType, 370 } 371 372 obj, err := b.FindObject(pkgName, typeName) 373 if err != nil { 374 return nil, err 375 } 376 377 if fun, isFunc := obj.(*types.Func); isFunc { 378 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() 379 ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/kerryoscer/gqlgen/graphql.ContextMarshaler" 380 ref.Marshaler = fun 381 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) 382 } else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") { 383 ref.GO = obj.Type() 384 ref.IsContext = true 385 ref.IsMarshaler = true 386 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { 387 ref.GO = obj.Type() 388 ref.IsMarshaler = true 389 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { 390 // TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595) 391 392 ref.GO = obj.Type() 393 ref.CastType = underlying 394 395 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) 396 if err != nil { 397 return nil, err 398 } 399 400 ref.Marshaler = underlyingRef.Marshaler 401 ref.Unmarshaler = underlyingRef.Unmarshaler 402 } else { 403 ref.GO = obj.Type() 404 } 405 406 ref.Target = ref.GO 407 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO) 408 409 if bindTarget != nil { 410 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil { 411 continue 412 } 413 ref.GO = bindTarget 414 } 415 416 ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput 417 418 return ref, nil 419 } 420 421 return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String()) 422 } 423 424 func isValid(t types.Type) bool { 425 basic, isBasic := t.(*types.Basic) 426 if !isBasic { 427 return true 428 } 429 return basic.Kind() != types.Invalid 430 } 431 432 func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { 433 if t.Elem != nil { 434 child := b.CopyModifiersFromAst(t.Elem, base) 435 if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers { 436 child = types.NewPointer(child) 437 } 438 return types.NewSlice(child) 439 } 440 441 var isInterface bool 442 if named, ok := base.(*types.Named); ok { 443 _, isInterface = named.Underlying().(*types.Interface) 444 } 445 446 if !isInterface && !IsNilable(base) && !t.NonNull { 447 return types.NewPointer(base) 448 } 449 450 return base 451 } 452 453 func IsNilable(t types.Type) bool { 454 if namedType, isNamed := t.(*types.Named); isNamed { 455 return IsNilable(namedType.Underlying()) 456 } 457 _, isPtr := t.(*types.Pointer) 458 _, isMap := t.(*types.Map) 459 _, isInterface := t.(*types.Interface) 460 _, isSlice := t.(*types.Slice) 461 _, isChan := t.(*types.Chan) 462 return isPtr || isMap || isInterface || isSlice || isChan 463 } 464 465 func hasMethod(it types.Type, name string) bool { 466 if ptr, isPtr := it.(*types.Pointer); isPtr { 467 it = ptr.Elem() 468 } 469 namedType, ok := it.(*types.Named) 470 if !ok { 471 return false 472 } 473 474 for i := 0; i < namedType.NumMethods(); i++ { 475 if namedType.Method(i).Name() == name { 476 return true 477 } 478 } 479 return false 480 } 481 482 func basicUnderlying(it types.Type) *types.Basic { 483 if ptr, isPtr := it.(*types.Pointer); isPtr { 484 it = ptr.Elem() 485 } 486 namedType, ok := it.(*types.Named) 487 if !ok { 488 return nil 489 } 490 491 if basic, ok := namedType.Underlying().(*types.Basic); ok { 492 return basic 493 } 494 495 return nil 496 }