github.com/shish-dev/gqlgen@v0.99.0/codegen/config/binder.go (about) 1 package config 2 3 import ( 4 "errors" 5 "fmt" 6 "go/token" 7 "go/types" 8 9 "golang.org/x/tools/go/packages" 10 11 "github.com/shish-dev/gqlgen/codegen/templates" 12 "github.com/shish-dev/gqlgen/internal/code" 13 "github.com/vektah/gqlparser/v2/ast" 14 ) 15 16 var ErrTypeNotFound = errors.New("unable to find type") 17 18 // Binder connects graphql types to golang types using static analysis 19 type Binder struct { 20 pkgs *code.Packages 21 schema *ast.Schema 22 cfg *Config 23 References []*TypeReference 24 SawInvalid bool 25 objectCache map[string]map[string]types.Object 26 } 27 28 func (c *Config) NewBinder() *Binder { 29 return &Binder{ 30 pkgs: c.Packages, 31 schema: c.Schema, 32 cfg: c, 33 } 34 } 35 36 func (b *Binder) TypePosition(typ types.Type) token.Position { 37 named, isNamed := typ.(*types.Named) 38 if !isNamed { 39 return token.Position{ 40 Filename: "unknown", 41 } 42 } 43 44 return b.ObjectPosition(named.Obj()) 45 } 46 47 func (b *Binder) ObjectPosition(typ types.Object) token.Position { 48 if typ == nil { 49 return token.Position{ 50 Filename: "unknown", 51 } 52 } 53 pkg := b.pkgs.Load(typ.Pkg().Path()) 54 return pkg.Fset.Position(typ.Pos()) 55 } 56 57 func (b *Binder) FindTypeFromName(name string) (types.Type, error) { 58 pkgName, typeName := code.PkgAndType(name) 59 return b.FindType(pkgName, typeName) 60 } 61 62 func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { 63 if pkgName == "" { 64 if typeName == "map[string]interface{}" { 65 return MapType, nil 66 } 67 68 if typeName == "interface{}" { 69 return InterfaceType, nil 70 } 71 } 72 73 obj, err := b.FindObject(pkgName, typeName) 74 if err != nil { 75 return nil, err 76 } 77 78 if fun, isFunc := obj.(*types.Func); isFunc { 79 return fun.Type().(*types.Signature).Params().At(0).Type(), nil 80 } 81 return obj.Type(), nil 82 } 83 84 var ( 85 MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()) 86 InterfaceType = types.NewInterfaceType(nil, nil) 87 ) 88 89 func (b *Binder) DefaultUserObject(name string) (types.Type, error) { 90 models := b.cfg.Models[name].Model 91 if len(models) == 0 { 92 return nil, fmt.Errorf(name + " not found in typemap") 93 } 94 95 if models[0] == "map[string]interface{}" { 96 return MapType, nil 97 } 98 99 if models[0] == "interface{}" { 100 return InterfaceType, nil 101 } 102 103 pkgName, typeName := code.PkgAndType(models[0]) 104 if pkgName == "" { 105 return nil, fmt.Errorf("missing package name for %s", name) 106 } 107 108 obj, err := b.FindObject(pkgName, typeName) 109 if err != nil { 110 return nil, err 111 } 112 113 return obj.Type(), nil 114 } 115 116 func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) { 117 if pkgName == "" { 118 return nil, fmt.Errorf("package cannot be nil") 119 } 120 121 pkg := b.pkgs.LoadWithTypes(pkgName) 122 if pkg == nil { 123 err := b.pkgs.Errors() 124 if err != nil { 125 return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err) 126 } 127 return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName) 128 } 129 130 if b.objectCache == nil { 131 b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count()) 132 } 133 134 defsIndex, ok := b.objectCache[pkgName] 135 if !ok { 136 defsIndex = indexDefs(pkg) 137 b.objectCache[pkgName] = defsIndex 138 } 139 140 // function based marshalers take precedence 141 if val, ok := defsIndex["Marshal"+typeName]; ok { 142 return val, nil 143 } 144 145 if val, ok := defsIndex[typeName]; ok { 146 return val, nil 147 } 148 149 return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName) 150 } 151 152 func indexDefs(pkg *packages.Package) map[string]types.Object { 153 res := make(map[string]types.Object) 154 155 scope := pkg.Types.Scope() 156 for astNode, def := range pkg.TypesInfo.Defs { 157 // only look at defs in the top scope 158 if def == nil { 159 continue 160 } 161 parent := def.Parent() 162 if parent == nil || parent != scope { 163 continue 164 } 165 166 if _, ok := res[astNode.Name]; !ok { 167 // The above check may not be really needed, it is only here to have a consistent behavior with 168 // previous implementation of FindObject() function which only honored the first inclusion of a def. 169 // If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups. 170 res[astNode.Name] = def 171 } 172 } 173 174 return res 175 } 176 177 func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { 178 newRef := *ref 179 newRef.GO = types.NewPointer(ref.GO) 180 b.References = append(b.References, &newRef) 181 return &newRef 182 } 183 184 // TypeReference is used by args and field types. The Definition can refer to both input and output types. 185 type TypeReference struct { 186 Definition *ast.Definition 187 GQL *ast.Type 188 GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target. 189 Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields. 190 CastType types.Type // Before calling marshalling functions cast from/to this base type 191 Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function 192 Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function 193 IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler 194 IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety. 195 } 196 197 func (ref *TypeReference) Elem() *TypeReference { 198 if p, isPtr := ref.GO.(*types.Pointer); isPtr { 199 newRef := *ref 200 newRef.GO = p.Elem() 201 return &newRef 202 } 203 204 if ref.IsSlice() { 205 newRef := *ref 206 newRef.GO = ref.GO.(*types.Slice).Elem() 207 newRef.GQL = ref.GQL.Elem 208 return &newRef 209 } 210 return nil 211 } 212 213 func (t *TypeReference) IsPtr() bool { 214 _, isPtr := t.GO.(*types.Pointer) 215 return isPtr 216 } 217 218 // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful) 219 // 220 func (t *TypeReference) IsPtrToPtr() bool { 221 if p, isPtr := t.GO.(*types.Pointer); isPtr { 222 _, isPtr := p.Elem().(*types.Pointer) 223 return isPtr 224 } 225 return false 226 } 227 228 func (t *TypeReference) IsNilable() bool { 229 return IsNilable(t.GO) 230 } 231 232 func (t *TypeReference) IsSlice() bool { 233 _, isSlice := t.GO.(*types.Slice) 234 return t.GQL.Elem != nil && isSlice 235 } 236 237 func (t *TypeReference) IsPtrToSlice() bool { 238 if t.IsPtr() { 239 _, isPointerToSlice := t.GO.(*types.Pointer).Elem().(*types.Slice) 240 return isPointerToSlice 241 } 242 return false 243 } 244 245 func (t *TypeReference) IsNamed() bool { 246 _, isSlice := t.GO.(*types.Named) 247 return isSlice 248 } 249 250 func (t *TypeReference) IsStruct() bool { 251 _, isStruct := t.GO.Underlying().(*types.Struct) 252 return isStruct 253 } 254 255 func (t *TypeReference) IsScalar() bool { 256 return t.Definition.Kind == ast.Scalar 257 } 258 259 func (t *TypeReference) UniquenessKey() string { 260 nullability := "O" 261 if t.GQL.NonNull { 262 nullability = "N" 263 } 264 265 elemNullability := "" 266 if t.GQL.Elem != nil && t.GQL.Elem.NonNull { 267 // Fix for #896 268 elemNullability = "áš„" 269 } 270 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability 271 } 272 273 func (t *TypeReference) MarshalFunc() string { 274 if t.Definition == nil { 275 panic(errors.New("Definition missing for " + t.GQL.Name())) 276 } 277 278 if t.Definition.Kind == ast.InputObject { 279 return "" 280 } 281 282 return "marshal" + t.UniquenessKey() 283 } 284 285 func (t *TypeReference) UnmarshalFunc() string { 286 if t.Definition == nil { 287 panic(errors.New("Definition missing for " + t.GQL.Name())) 288 } 289 290 if !t.Definition.IsInputType() { 291 return "" 292 } 293 294 return "unmarshal" + t.UniquenessKey() 295 } 296 297 func (t *TypeReference) IsTargetNilable() bool { 298 return IsNilable(t.Target) 299 } 300 301 func (b *Binder) PushRef(ret *TypeReference) { 302 b.References = append(b.References, ret) 303 } 304 305 func isMap(t types.Type) bool { 306 if t == nil { 307 return true 308 } 309 _, ok := t.(*types.Map) 310 return ok 311 } 312 313 func isIntf(t types.Type) bool { 314 if t == nil { 315 return true 316 } 317 _, ok := t.(*types.Interface) 318 return ok 319 } 320 321 func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) { 322 if !isValid(bindTarget) { 323 b.SawInvalid = true 324 return nil, fmt.Errorf("%s has an invalid type", schemaType.Name()) 325 } 326 327 var pkgName, typeName string 328 def := b.schema.Types[schemaType.Name()] 329 defer func() { 330 if err == nil && ret != nil { 331 b.PushRef(ret) 332 } 333 }() 334 335 if len(b.cfg.Models[schemaType.Name()].Model) == 0 { 336 return nil, fmt.Errorf("%s was not found", schemaType.Name()) 337 } 338 339 for _, model := range b.cfg.Models[schemaType.Name()].Model { 340 if model == "map[string]interface{}" { 341 if !isMap(bindTarget) { 342 continue 343 } 344 return &TypeReference{ 345 Definition: def, 346 GQL: schemaType, 347 GO: MapType, 348 }, nil 349 } 350 351 if model == "interface{}" { 352 if !isIntf(bindTarget) { 353 continue 354 } 355 return &TypeReference{ 356 Definition: def, 357 GQL: schemaType, 358 GO: InterfaceType, 359 }, nil 360 } 361 362 pkgName, typeName = code.PkgAndType(model) 363 if pkgName == "" { 364 return nil, fmt.Errorf("missing package name for %s", schemaType.Name()) 365 } 366 367 ref := &TypeReference{ 368 Definition: def, 369 GQL: schemaType, 370 } 371 372 obj, err := b.FindObject(pkgName, typeName) 373 if err != nil { 374 return nil, err 375 } 376 377 if fun, isFunc := obj.(*types.Func); isFunc { 378 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() 379 ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler" 380 ref.Marshaler = fun 381 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) 382 } else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") { 383 ref.GO = obj.Type() 384 ref.IsContext = true 385 ref.IsMarshaler = true 386 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { 387 ref.GO = obj.Type() 388 ref.IsMarshaler = true 389 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { 390 // TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595) 391 392 ref.GO = obj.Type() 393 ref.CastType = underlying 394 395 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) 396 if err != nil { 397 return nil, err 398 } 399 400 ref.Marshaler = underlyingRef.Marshaler 401 ref.Unmarshaler = underlyingRef.Unmarshaler 402 } else { 403 ref.GO = obj.Type() 404 } 405 406 ref.Target = ref.GO 407 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO) 408 409 if bindTarget != nil { 410 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil { 411 continue 412 } 413 ref.GO = bindTarget 414 } 415 416 return ref, nil 417 } 418 419 return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String()) 420 } 421 422 func isValid(t types.Type) bool { 423 basic, isBasic := t.(*types.Basic) 424 if !isBasic { 425 return true 426 } 427 return basic.Kind() != types.Invalid 428 } 429 430 func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { 431 if t.Elem != nil { 432 child := b.CopyModifiersFromAst(t.Elem, base) 433 if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers { 434 child = types.NewPointer(child) 435 } 436 return types.NewSlice(child) 437 } 438 439 var isInterface bool 440 if named, ok := base.(*types.Named); ok { 441 _, isInterface = named.Underlying().(*types.Interface) 442 } 443 444 if !isInterface && !IsNilable(base) && !t.NonNull { 445 return types.NewPointer(base) 446 } 447 448 return base 449 } 450 451 func IsNilable(t types.Type) bool { 452 if namedType, isNamed := t.(*types.Named); isNamed { 453 return IsNilable(namedType.Underlying()) 454 } 455 _, isPtr := t.(*types.Pointer) 456 _, isMap := t.(*types.Map) 457 _, isInterface := t.(*types.Interface) 458 _, isSlice := t.(*types.Slice) 459 _, isChan := t.(*types.Chan) 460 return isPtr || isMap || isInterface || isSlice || isChan 461 } 462 463 func hasMethod(it types.Type, name string) bool { 464 if ptr, isPtr := it.(*types.Pointer); isPtr { 465 it = ptr.Elem() 466 } 467 namedType, ok := it.(*types.Named) 468 if !ok { 469 return false 470 } 471 472 for i := 0; i < namedType.NumMethods(); i++ { 473 if namedType.Method(i).Name() == name { 474 return true 475 } 476 } 477 return false 478 } 479 480 func basicUnderlying(it types.Type) *types.Basic { 481 if ptr, isPtr := it.(*types.Pointer); isPtr { 482 it = ptr.Elem() 483 } 484 namedType, ok := it.(*types.Named) 485 if !ok { 486 return nil 487 } 488 489 if basic, ok := namedType.Underlying().(*types.Basic); ok { 490 return basic 491 } 492 493 return nil 494 }