github.com/senomas/gqlgen@v0.17.11-0.20220626120754-9aee61b0716a/codegen/config/binder.go (about)

     1  package config
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/token"
     7  	"go/types"
     8  
     9  	"golang.org/x/tools/go/packages"
    10  
    11  	"github.com/99designs/gqlgen/codegen/templates"
    12  	"github.com/99designs/gqlgen/internal/code"
    13  	"github.com/vektah/gqlparser/v2/ast"
    14  )
    15  
    16  var ErrTypeNotFound = errors.New("unable to find type")
    17  
// Binder connects graphql types to golang types using static analysis.
//
// Config.NewBinder constructs one from the config's loaded packages and schema.
//   - References collects every TypeReference produced by TypeReference,
//     PointerTo and PushRef.
//   - SawInvalid is set to true when TypeReference is asked to bind to an
//     invalid Go type.
//   - objectCache maps a package path to an index of that package's
//     package-scope definitions (built by indexDefs); FindObject creates and
//     fills it lazily.
type Binder struct {
	pkgs        *code.Packages
	schema      *ast.Schema
	cfg         *Config
	References  []*TypeReference
	SawInvalid  bool
	objectCache map[string]map[string]types.Object
}
    27  
    28  func (c *Config) NewBinder() *Binder {
    29  	return &Binder{
    30  		pkgs:   c.Packages,
    31  		schema: c.Schema,
    32  		cfg:    c,
    33  	}
    34  }
    35  
    36  func (b *Binder) TypePosition(typ types.Type) token.Position {
    37  	named, isNamed := typ.(*types.Named)
    38  	if !isNamed {
    39  		return token.Position{
    40  			Filename: "unknown",
    41  		}
    42  	}
    43  
    44  	return b.ObjectPosition(named.Obj())
    45  }
    46  
    47  func (b *Binder) ObjectPosition(typ types.Object) token.Position {
    48  	if typ == nil {
    49  		return token.Position{
    50  			Filename: "unknown",
    51  		}
    52  	}
    53  	pkg := b.pkgs.Load(typ.Pkg().Path())
    54  	return pkg.Fset.Position(typ.Pos())
    55  }
    56  
    57  func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
    58  	pkgName, typeName := code.PkgAndType(name)
    59  	return b.FindType(pkgName, typeName)
    60  }
    61  
    62  func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
    63  	if pkgName == "" {
    64  		if typeName == "map[string]interface{}" {
    65  			return MapType, nil
    66  		}
    67  
    68  		if typeName == "interface{}" {
    69  			return InterfaceType, nil
    70  		}
    71  	}
    72  
    73  	obj, err := b.FindObject(pkgName, typeName)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	if fun, isFunc := obj.(*types.Func); isFunc {
    79  		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
    80  	}
    81  	return obj.Type(), nil
    82  }
    83  
    84  var (
    85  	MapType       = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
    86  	InterfaceType = types.NewInterfaceType(nil, nil)
    87  )
    88  
    89  func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
    90  	models := b.cfg.Models[name].Model
    91  	if len(models) == 0 {
    92  		return nil, fmt.Errorf(name + " not found in typemap")
    93  	}
    94  
    95  	if models[0] == "map[string]interface{}" {
    96  		return MapType, nil
    97  	}
    98  
    99  	if models[0] == "interface{}" {
   100  		return InterfaceType, nil
   101  	}
   102  
   103  	pkgName, typeName := code.PkgAndType(models[0])
   104  	if pkgName == "" {
   105  		return nil, fmt.Errorf("missing package name for %s", name)
   106  	}
   107  
   108  	obj, err := b.FindObject(pkgName, typeName)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	return obj.Type(), nil
   114  }
   115  
   116  func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
   117  	if pkgName == "" {
   118  		return nil, fmt.Errorf("package cannot be nil")
   119  	}
   120  
   121  	pkg := b.pkgs.LoadWithTypes(pkgName)
   122  	if pkg == nil {
   123  		err := b.pkgs.Errors()
   124  		if err != nil {
   125  			return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err)
   126  		}
   127  		return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName)
   128  	}
   129  
   130  	if b.objectCache == nil {
   131  		b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count())
   132  	}
   133  
   134  	defsIndex, ok := b.objectCache[pkgName]
   135  	if !ok {
   136  		defsIndex = indexDefs(pkg)
   137  		b.objectCache[pkgName] = defsIndex
   138  	}
   139  
   140  	// function based marshalers take precedence
   141  	if val, ok := defsIndex["Marshal"+typeName]; ok {
   142  		return val, nil
   143  	}
   144  
   145  	if val, ok := defsIndex[typeName]; ok {
   146  		return val, nil
   147  	}
   148  
   149  	return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName)
   150  }
   151  
   152  func indexDefs(pkg *packages.Package) map[string]types.Object {
   153  	res := make(map[string]types.Object)
   154  
   155  	scope := pkg.Types.Scope()
   156  	for astNode, def := range pkg.TypesInfo.Defs {
   157  		// only look at defs in the top scope
   158  		if def == nil {
   159  			continue
   160  		}
   161  		parent := def.Parent()
   162  		if parent == nil || parent != scope {
   163  			continue
   164  		}
   165  
   166  		if _, ok := res[astNode.Name]; !ok {
   167  			// The above check may not be really needed, it is only here to have a consistent behavior with
   168  			// previous implementation of FindObject() function which only honored the first inclusion of a def.
   169  			// If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups.
   170  			res[astNode.Name] = def
   171  		}
   172  	}
   173  
   174  	return res
   175  }
   176  
   177  func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
   178  	newRef := *ref
   179  	newRef.GO = types.NewPointer(ref.GO)
   180  	b.References = append(b.References, &newRef)
   181  	return &newRef
   182  }
   183  
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// A TypeReference pairs a GraphQL type reference (GQL and its schema
// Definition) with the Go type it is bound to and the marshaling strategy that
// Binder.TypeReference selected for it.
type TypeReference struct {
	// Definition is the schema definition for GQL's named type. It can be nil if
	// that name is absent from the schema; MarshalFunc and UnmarshalFunc panic then.
	Definition  *ast.Definition
	// GQL is the GraphQL type reference, including its list (Elem) and NonNull modifiers.
	GQL         *ast.Type
	GO          types.Type  // Type of the field being bound. Could be a pointer or a value type of Target.
	Target      types.Type  // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
	// In this file CastType is only set for named types over string that are
	// bound to leaf GraphQL types (backwards-compatibility case, see #595).
	CastType    types.Type  // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func // When using external marshalling functions this will point to the Marshal function
	// For function-based marshalers, Unmarshaler is a placeholder Func named
	// "Unmarshal<Type>" in the marshaler's package, created without a signature.
	Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool        // Does the type implement graphql.Marshaler and graphql.Unmarshaler
	IsContext   bool        // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
}
   196  
   197  func (ref *TypeReference) Elem() *TypeReference {
   198  	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
   199  		newRef := *ref
   200  		newRef.GO = p.Elem()
   201  		return &newRef
   202  	}
   203  
   204  	if ref.IsSlice() {
   205  		newRef := *ref
   206  		newRef.GO = ref.GO.(*types.Slice).Elem()
   207  		newRef.GQL = ref.GQL.Elem
   208  		return &newRef
   209  	}
   210  	return nil
   211  }
   212  
   213  func (t *TypeReference) IsPtr() bool {
   214  	_, isPtr := t.GO.(*types.Pointer)
   215  	return isPtr
   216  }
   217  
   218  // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful)
   219  //
   220  func (t *TypeReference) IsPtrToPtr() bool {
   221  	if p, isPtr := t.GO.(*types.Pointer); isPtr {
   222  		_, isPtr := p.Elem().(*types.Pointer)
   223  		return isPtr
   224  	}
   225  	return false
   226  }
   227  
   228  func (t *TypeReference) IsNilable() bool {
   229  	return IsNilable(t.GO)
   230  }
   231  
   232  func (t *TypeReference) IsSlice() bool {
   233  	_, isSlice := t.GO.(*types.Slice)
   234  	return t.GQL.Elem != nil && isSlice
   235  }
   236  
   237  func (t *TypeReference) IsPtrToSlice() bool {
   238  	if t.IsPtr() {
   239  		_, isPointerToSlice := t.GO.(*types.Pointer).Elem().(*types.Slice)
   240  		return isPointerToSlice
   241  	}
   242  	return false
   243  }
   244  
   245  func (t *TypeReference) IsNamed() bool {
   246  	_, isSlice := t.GO.(*types.Named)
   247  	return isSlice
   248  }
   249  
   250  func (t *TypeReference) IsStruct() bool {
   251  	_, isStruct := t.GO.Underlying().(*types.Struct)
   252  	return isStruct
   253  }
   254  
   255  func (t *TypeReference) IsScalar() bool {
   256  	return t.Definition.Kind == ast.Scalar
   257  }
   258  
   259  func (t *TypeReference) UniquenessKey() string {
   260  	nullability := "O"
   261  	if t.GQL.NonNull {
   262  		nullability = "N"
   263  	}
   264  
   265  	elemNullability := ""
   266  	if t.GQL.Elem != nil && t.GQL.Elem.NonNull {
   267  		// Fix for #896
   268  		elemNullability = "áš„"
   269  	}
   270  	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability
   271  }
   272  
   273  func (t *TypeReference) MarshalFunc() string {
   274  	if t.Definition == nil {
   275  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   276  	}
   277  
   278  	if t.Definition.Kind == ast.InputObject {
   279  		return ""
   280  	}
   281  
   282  	return "marshal" + t.UniquenessKey()
   283  }
   284  
   285  func (t *TypeReference) UnmarshalFunc() string {
   286  	if t.Definition == nil {
   287  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   288  	}
   289  
   290  	if !t.Definition.IsInputType() {
   291  		return ""
   292  	}
   293  
   294  	return "unmarshal" + t.UniquenessKey()
   295  }
   296  
   297  func (t *TypeReference) IsTargetNilable() bool {
   298  	return IsNilable(t.Target)
   299  }
   300  
   301  func (b *Binder) PushRef(ret *TypeReference) {
   302  	b.References = append(b.References, ret)
   303  }
   304  
   305  func isMap(t types.Type) bool {
   306  	if t == nil {
   307  		return true
   308  	}
   309  	_, ok := t.(*types.Map)
   310  	return ok
   311  }
   312  
   313  func isIntf(t types.Type) bool {
   314  	if t == nil {
   315  		return true
   316  	}
   317  	_, ok := t.(*types.Interface)
   318  	return ok
   319  }
   320  
   321  func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
   322  	if !isValid(bindTarget) {
   323  		b.SawInvalid = true
   324  		return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
   325  	}
   326  
   327  	var pkgName, typeName string
   328  	def := b.schema.Types[schemaType.Name()]
   329  	defer func() {
   330  		if err == nil && ret != nil {
   331  			b.PushRef(ret)
   332  		}
   333  	}()
   334  
   335  	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
   336  		return nil, fmt.Errorf("%s was not found", schemaType.Name())
   337  	}
   338  
   339  	for _, model := range b.cfg.Models[schemaType.Name()].Model {
   340  		if model == "map[string]interface{}" {
   341  			if !isMap(bindTarget) {
   342  				continue
   343  			}
   344  			return &TypeReference{
   345  				Definition: def,
   346  				GQL:        schemaType,
   347  				GO:         MapType,
   348  			}, nil
   349  		}
   350  
   351  		if model == "interface{}" {
   352  			if !isIntf(bindTarget) {
   353  				continue
   354  			}
   355  			return &TypeReference{
   356  				Definition: def,
   357  				GQL:        schemaType,
   358  				GO:         InterfaceType,
   359  			}, nil
   360  		}
   361  
   362  		pkgName, typeName = code.PkgAndType(model)
   363  		if pkgName == "" {
   364  			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
   365  		}
   366  
   367  		ref := &TypeReference{
   368  			Definition: def,
   369  			GQL:        schemaType,
   370  		}
   371  
   372  		obj, err := b.FindObject(pkgName, typeName)
   373  		if err != nil {
   374  			return nil, err
   375  		}
   376  
   377  		if fun, isFunc := obj.(*types.Func); isFunc {
   378  			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
   379  			ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler"
   380  			ref.Marshaler = fun
   381  			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
   382  		} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
   383  			ref.GO = obj.Type()
   384  			ref.IsContext = true
   385  			ref.IsMarshaler = true
   386  		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
   387  			ref.GO = obj.Type()
   388  			ref.IsMarshaler = true
   389  		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
   390  			// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)
   391  
   392  			ref.GO = obj.Type()
   393  			ref.CastType = underlying
   394  
   395  			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
   396  			if err != nil {
   397  				return nil, err
   398  			}
   399  
   400  			ref.Marshaler = underlyingRef.Marshaler
   401  			ref.Unmarshaler = underlyingRef.Unmarshaler
   402  		} else {
   403  			ref.GO = obj.Type()
   404  		}
   405  
   406  		ref.Target = ref.GO
   407  		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
   408  
   409  		if bindTarget != nil {
   410  			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
   411  				continue
   412  			}
   413  			ref.GO = bindTarget
   414  		}
   415  
   416  		return ref, nil
   417  	}
   418  
   419  	return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
   420  }
   421  
   422  func isValid(t types.Type) bool {
   423  	basic, isBasic := t.(*types.Basic)
   424  	if !isBasic {
   425  		return true
   426  	}
   427  	return basic.Kind() != types.Invalid
   428  }
   429  
   430  func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
   431  	if t.Elem != nil {
   432  		child := b.CopyModifiersFromAst(t.Elem, base)
   433  		if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
   434  			child = types.NewPointer(child)
   435  		}
   436  		return types.NewSlice(child)
   437  	}
   438  
   439  	var isInterface bool
   440  	if named, ok := base.(*types.Named); ok {
   441  		_, isInterface = named.Underlying().(*types.Interface)
   442  	}
   443  
   444  	if !isInterface && !IsNilable(base) && !t.NonNull {
   445  		return types.NewPointer(base)
   446  	}
   447  
   448  	return base
   449  }
   450  
   451  func IsNilable(t types.Type) bool {
   452  	if namedType, isNamed := t.(*types.Named); isNamed {
   453  		return IsNilable(namedType.Underlying())
   454  	}
   455  	_, isPtr := t.(*types.Pointer)
   456  	_, isMap := t.(*types.Map)
   457  	_, isInterface := t.(*types.Interface)
   458  	_, isSlice := t.(*types.Slice)
   459  	_, isChan := t.(*types.Chan)
   460  	return isPtr || isMap || isInterface || isSlice || isChan
   461  }
   462  
   463  func hasMethod(it types.Type, name string) bool {
   464  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   465  		it = ptr.Elem()
   466  	}
   467  	namedType, ok := it.(*types.Named)
   468  	if !ok {
   469  		return false
   470  	}
   471  
   472  	for i := 0; i < namedType.NumMethods(); i++ {
   473  		if namedType.Method(i).Name() == name {
   474  			return true
   475  		}
   476  	}
   477  	return false
   478  }
   479  
   480  func basicUnderlying(it types.Type) *types.Basic {
   481  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   482  		it = ptr.Elem()
   483  	}
   484  	namedType, ok := it.(*types.Named)
   485  	if !ok {
   486  		return nil
   487  	}
   488  
   489  	if basic, ok := namedType.Underlying().(*types.Basic); ok {
   490  		return basic
   491  	}
   492  
   493  	return nil
   494  }