github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/codegen/config/binder.go (about)

     1  package config
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/token"
     7  	"go/types"
     8  
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  	"golang.org/x/tools/go/packages"
    11  
    12  	"github.com/niko0xdev/gqlgen/codegen/templates"
    13  	"github.com/niko0xdev/gqlgen/internal/code"
    14  )
    15  
    16  var ErrTypeNotFound = errors.New("unable to find type")
    17  
// Binder connects graphql types to golang types using static analysis.
// It resolves model names from the Config against the loaded Go packages and
// records every TypeReference it produces. Create one with Config.NewBinder.
type Binder struct {
	// pkgs is the set of loaded Go packages that lookups resolve against.
	pkgs        *code.Packages
	// schema is the parsed GraphQL schema; definitions are looked up by name.
	schema      *ast.Schema
	// cfg supplies the model mappings and binding options.
	cfg         *Config
	// tctx is the instantiation context, created lazily by InstantiateType.
	tctx        *types.Context
	// References accumulates every reference produced (see PushRef/PointerTo).
	References  []*TypeReference
	// SawInvalid is set once TypeReference is asked to bind an invalid Go type.
	SawInvalid  bool
	// objectCache maps package path -> top-level name -> object; filled by FindObject.
	objectCache map[string]map[string]types.Object
}
    28  
    29  func (c *Config) NewBinder() *Binder {
    30  	return &Binder{
    31  		pkgs:   c.Packages,
    32  		schema: c.Schema,
    33  		cfg:    c,
    34  	}
    35  }
    36  
    37  func (b *Binder) TypePosition(typ types.Type) token.Position {
    38  	named, isNamed := typ.(*types.Named)
    39  	if !isNamed {
    40  		return token.Position{
    41  			Filename: "unknown",
    42  		}
    43  	}
    44  
    45  	return b.ObjectPosition(named.Obj())
    46  }
    47  
    48  func (b *Binder) ObjectPosition(typ types.Object) token.Position {
    49  	if typ == nil {
    50  		return token.Position{
    51  			Filename: "unknown",
    52  		}
    53  	}
    54  	pkg := b.pkgs.Load(typ.Pkg().Path())
    55  	return pkg.Fset.Position(typ.Pos())
    56  }
    57  
    58  func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
    59  	pkgName, typeName := code.PkgAndType(name)
    60  	return b.FindType(pkgName, typeName)
    61  }
    62  
    63  func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
    64  	if pkgName == "" {
    65  		if typeName == "map[string]interface{}" {
    66  			return MapType, nil
    67  		}
    68  
    69  		if typeName == "interface{}" {
    70  			return InterfaceType, nil
    71  		}
    72  	}
    73  
    74  	obj, err := b.FindObject(pkgName, typeName)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	if fun, isFunc := obj.(*types.Func); isFunc {
    80  		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
    81  	}
    82  	return obj.Type(), nil
    83  }
    84  
    85  func (b *Binder) InstantiateType(orig types.Type, targs []types.Type) (types.Type, error) {
    86  	if b.tctx == nil {
    87  		b.tctx = types.NewContext()
    88  	}
    89  
    90  	return types.Instantiate(b.tctx, orig, targs, false)
    91  }
    92  
    93  var (
    94  	MapType       = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
    95  	InterfaceType = types.NewInterfaceType(nil, nil)
    96  )
    97  
    98  func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
    99  	models := b.cfg.Models[name].Model
   100  	if len(models) == 0 {
   101  		return nil, fmt.Errorf(name + " not found in typemap")
   102  	}
   103  
   104  	if models[0] == "map[string]interface{}" {
   105  		return MapType, nil
   106  	}
   107  
   108  	if models[0] == "interface{}" {
   109  		return InterfaceType, nil
   110  	}
   111  
   112  	pkgName, typeName := code.PkgAndType(models[0])
   113  	if pkgName == "" {
   114  		return nil, fmt.Errorf("missing package name for %s", name)
   115  	}
   116  
   117  	obj, err := b.FindObject(pkgName, typeName)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	return obj.Type(), nil
   123  }
   124  
   125  func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
   126  	if pkgName == "" {
   127  		return nil, fmt.Errorf("package cannot be nil")
   128  	}
   129  
   130  	pkg := b.pkgs.LoadWithTypes(pkgName)
   131  	if pkg == nil {
   132  		err := b.pkgs.Errors()
   133  		if err != nil {
   134  			return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err)
   135  		}
   136  		return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName)
   137  	}
   138  
   139  	if b.objectCache == nil {
   140  		b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count())
   141  	}
   142  
   143  	defsIndex, ok := b.objectCache[pkgName]
   144  	if !ok {
   145  		defsIndex = indexDefs(pkg)
   146  		b.objectCache[pkgName] = defsIndex
   147  	}
   148  
   149  	// function based marshalers take precedence
   150  	if val, ok := defsIndex["Marshal"+typeName]; ok {
   151  		return val, nil
   152  	}
   153  
   154  	if val, ok := defsIndex[typeName]; ok {
   155  		return val, nil
   156  	}
   157  
   158  	return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName)
   159  }
   160  
   161  func indexDefs(pkg *packages.Package) map[string]types.Object {
   162  	res := make(map[string]types.Object)
   163  
   164  	scope := pkg.Types.Scope()
   165  	for astNode, def := range pkg.TypesInfo.Defs {
   166  		// only look at defs in the top scope
   167  		if def == nil {
   168  			continue
   169  		}
   170  		parent := def.Parent()
   171  		if parent == nil || parent != scope {
   172  			continue
   173  		}
   174  
   175  		if _, ok := res[astNode.Name]; !ok {
   176  			// The above check may not be really needed, it is only here to have a consistent behavior with
   177  			// previous implementation of FindObject() function which only honored the first inclusion of a def.
   178  			// If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups.
   179  			res[astNode.Name] = def
   180  		}
   181  	}
   182  
   183  	return res
   184  }
   185  
   186  func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
   187  	newRef := *ref
   188  	newRef.GO = types.NewPointer(ref.GO)
   189  	b.References = append(b.References, &newRef)
   190  	return &newRef
   191  }
   192  
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
// It pairs a GraphQL type (GQL and its schema Definition) with the Go type bound
// to it (GO) and records how values are marshaled. References are produced by
// Binder.TypeReference and derived with Elem and Binder.PointerTo.
type TypeReference struct {
	Definition              *ast.Definition // Schema definition of GQL's named type; MarshalFunc/UnmarshalFunc panic if it is nil.
	GQL                     *ast.Type       // The GraphQL type, including its list (Elem) and NonNull modifiers.
	GO                      types.Type  // Type of the field being bound. Could be a pointer or a value type of Target.
	Target                  types.Type  // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
	CastType                types.Type  // Base type to cast from/to before calling the marshaling functions (set for named types over string).
	Marshaler               *types.Func // When using external marshaling functions this points to the Marshal function.
	Unmarshaler             *types.Func // When using external marshaling functions this points to the Unmarshal function (Binder.TypeReference may create it by name only, with a nil type).
	IsMarshaler             bool        // Does the type implement graphql.Marshaler and graphql.Unmarshaler (MarshalGQL/UnmarshalGQL or their context variants).
	IsOmittable             bool        // Was the bound Go type wrapped with graphql.Omittable.
	IsContext               bool        // Is the Marshaler/Unmarshaler the context version; applies to either the method or interface variety.
	PointersInUmarshalInput bool        // Copied from Config.ReturnPointersInUmarshalInput. NOTE(review): presumably makes generated unmarshal code return pointers instead of values; confirm.
	IsRoot                  bool        // Is the type a root level definition such as Query, Mutation or Subscription.
}
   208  
   209  func (ref *TypeReference) Elem() *TypeReference {
   210  	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
   211  		newRef := *ref
   212  		newRef.GO = p.Elem()
   213  		return &newRef
   214  	}
   215  
   216  	if ref.IsSlice() {
   217  		newRef := *ref
   218  		newRef.GO = ref.GO.(*types.Slice).Elem()
   219  		newRef.GQL = ref.GQL.Elem
   220  		return &newRef
   221  	}
   222  	return nil
   223  }
   224  
   225  func (ref *TypeReference) IsPtr() bool {
   226  	_, isPtr := ref.GO.(*types.Pointer)
   227  	return isPtr
   228  }
   229  
   230  // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful)
   231  func (ref *TypeReference) IsPtrToPtr() bool {
   232  	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
   233  		_, isPtr := p.Elem().(*types.Pointer)
   234  		return isPtr
   235  	}
   236  	return false
   237  }
   238  
   239  func (ref *TypeReference) IsNilable() bool {
   240  	return IsNilable(ref.GO)
   241  }
   242  
   243  func (ref *TypeReference) IsSlice() bool {
   244  	_, isSlice := ref.GO.(*types.Slice)
   245  	return ref.GQL.Elem != nil && isSlice
   246  }
   247  
   248  func (ref *TypeReference) IsPtrToSlice() bool {
   249  	if ref.IsPtr() {
   250  		_, isPointerToSlice := ref.GO.(*types.Pointer).Elem().(*types.Slice)
   251  		return isPointerToSlice
   252  	}
   253  	return false
   254  }
   255  
   256  func (ref *TypeReference) IsPtrToIntf() bool {
   257  	if ref.IsPtr() {
   258  		_, isPointerToInterface := ref.GO.(*types.Pointer).Elem().(*types.Interface)
   259  		return isPointerToInterface
   260  	}
   261  	return false
   262  }
   263  
   264  func (ref *TypeReference) IsNamed() bool {
   265  	_, isSlice := ref.GO.(*types.Named)
   266  	return isSlice
   267  }
   268  
   269  func (ref *TypeReference) IsStruct() bool {
   270  	_, isStruct := ref.GO.Underlying().(*types.Struct)
   271  	return isStruct
   272  }
   273  
   274  func (ref *TypeReference) IsScalar() bool {
   275  	return ref.Definition.Kind == ast.Scalar
   276  }
   277  
   278  func (ref *TypeReference) IsMap() bool {
   279  	return ref.GO == MapType
   280  }
   281  
   282  func (ref *TypeReference) UniquenessKey() string {
   283  	nullability := "O"
   284  	if ref.GQL.NonNull {
   285  		nullability = "N"
   286  	}
   287  
   288  	elemNullability := ""
   289  	if ref.GQL.Elem != nil && ref.GQL.Elem.NonNull {
   290  		// Fix for #896
   291  		elemNullability = "áš„"
   292  	}
   293  	return nullability + ref.Definition.Name + "2" + templates.TypeIdentifier(ref.GO) + elemNullability
   294  }
   295  
   296  func (ref *TypeReference) MarshalFunc() string {
   297  	if ref.Definition == nil {
   298  		panic(errors.New("Definition missing for " + ref.GQL.Name()))
   299  	}
   300  
   301  	if ref.Definition.Kind == ast.InputObject {
   302  		return ""
   303  	}
   304  
   305  	return "marshal" + ref.UniquenessKey()
   306  }
   307  
   308  func (ref *TypeReference) UnmarshalFunc() string {
   309  	if ref.Definition == nil {
   310  		panic(errors.New("Definition missing for " + ref.GQL.Name()))
   311  	}
   312  
   313  	if !ref.Definition.IsInputType() {
   314  		return ""
   315  	}
   316  
   317  	return "unmarshal" + ref.UniquenessKey()
   318  }
   319  
   320  func (ref *TypeReference) IsTargetNilable() bool {
   321  	return IsNilable(ref.Target)
   322  }
   323  
   324  func (b *Binder) PushRef(ret *TypeReference) {
   325  	b.References = append(b.References, ret)
   326  }
   327  
   328  func isMap(t types.Type) bool {
   329  	if t == nil {
   330  		return true
   331  	}
   332  	_, ok := t.(*types.Map)
   333  	return ok
   334  }
   335  
   336  func isIntf(t types.Type) bool {
   337  	if t == nil {
   338  		return true
   339  	}
   340  	_, ok := t.(*types.Interface)
   341  	return ok
   342  }
   343  
   344  func unwrapOmittable(t types.Type) (types.Type, bool) {
   345  	if t == nil {
   346  		return t, false
   347  	}
   348  	named, ok := t.(*types.Named)
   349  	if !ok {
   350  		return t, false
   351  	}
   352  	if named.Origin().String() != "github.com/niko0xdev/gqlgen/graphql.Omittable[T any]" {
   353  		return t, false
   354  	}
   355  	return named.TypeArgs().At(0), true
   356  }
   357  
// TypeReference resolves the Go type to use for schemaType. It tries each
// model configured for the schema type's name (b.cfg.Models) in order and
// returns the first one that can be bound.
//
// When bindTarget is non-nil, a model only matches if code.CompatibleTypes
// accepts it, and the returned reference's GO is then replaced by bindTarget
// itself. NOTE(review): bindTarget is presumably an existing Go field/argument
// type; confirm against callers. A nil bindTarget accepts the first model.
//
// Targets wrapped in graphql.Omittable are unwrapped first. The schema type
// must then be nullable, and the resulting reference is marked IsOmittable.
//
// Every successful result is recorded in b.References via PushRef, from the
// deferred call below.
//
// Errors are returned for:
//   - an invalid bindTarget (this also sets b.SawInvalid);
//   - a schema type with no configured models;
//   - a model without a package path;
//   - a FindObject failure;
//   - no model compatible with bindTarget.
func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
	// Omittable[T]: bind against T and flag the result. This branch returns
	// before the deferred PushRef below is registered. The recursive call
	// records the reference itself; it is the same pointer, so the recorded
	// entry also sees IsOmittable = true.
	if innerType, ok := unwrapOmittable(bindTarget); ok {
		if schemaType.NonNull {
			return nil, fmt.Errorf("%s is wrapped with Omittable but non-null", schemaType.Name())
		}

		ref, err := b.TypeReference(schemaType, innerType)
		if err != nil {
			return nil, err
		}

		ref.IsOmittable = true
		return ref, err
	}

	if !isValid(bindTarget) {
		b.SawInvalid = true
		return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
	}

	var pkgName, typeName string
	// NOTE(review): def is nil if the schema has no type with this name.
	// def.IsLeafType() below would then panic; presumably prevented by
	// config validation, so confirm.
	def := b.schema.Types[schemaType.Name()]
	// Record every successful result. This relies on the named results ret/err.
	defer func() {
		if err == nil && ret != nil {
			b.PushRef(ret)
		}
	}()

	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
		return nil, fmt.Errorf("%s was not found", schemaType.Name())
	}

	for _, model := range b.cfg.Models[schemaType.Name()].Model {
		// The two package-less special models bind to the shared MapType /
		// InterfaceType. isMap/isIntf also accept a nil bindTarget.
		if model == "map[string]interface{}" {
			if !isMap(bindTarget) {
				continue
			}
			return &TypeReference{
				Definition: def,
				GQL:        schemaType,
				GO:         MapType,
				IsRoot:     b.cfg.IsRoot(def),
			}, nil
		}

		if model == "interface{}" {
			if !isIntf(bindTarget) {
				continue
			}
			return &TypeReference{
				Definition: def,
				GQL:        schemaType,
				GO:         InterfaceType,
				IsRoot:     b.cfg.IsRoot(def),
			}, nil
		}

		pkgName, typeName = code.PkgAndType(model)
		if pkgName == "" {
			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
		}

		ref := &TypeReference{
			Definition: def,
			GQL:        schemaType,
			IsRoot:     b.cfg.IsRoot(def),
		}

		// err here is a new loop-scoped variable that shadows the named result.
		obj, err := b.FindObject(pkgName, typeName)
		if err != nil {
			return nil, err
		}

		// Decide how values are (un)marshaled, in priority order:
		//  1. a Marshal<Type> function (FindObject prefers it over the type);
		//  2. MarshalGQLContext/UnmarshalGQLContext methods;
		//  3. MarshalGQL/UnmarshalGQL methods;
		//  4. a named type over string bound to a leaf GraphQL type, which
		//     reuses the String scalar's marshalers with a cast;
		//  5. otherwise the object's type as-is.
		if fun, isFunc := obj.(*types.Func); isFunc {
			// NOTE(review): assumes the function has at least one parameter and
			// one result. Params()/Results().At(0) would panic otherwise.
			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
			ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/niko0xdev/gqlgen/graphql.ContextMarshaler"
			ref.Marshaler = fun
			// Name-only placeholder for the matching Unmarshal<Type> function
			// (its type is nil). Presumably resolved by name downstream.
			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
		} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
			ref.GO = obj.Type()
			ref.IsContext = true
			ref.IsMarshaler = true
		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
			ref.GO = obj.Type()
			ref.IsMarshaler = true
		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
			// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)

			ref.GO = obj.Type()
			ref.CastType = underlying

			// This lookup of the String scalar also records its own reference.
			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
			if err != nil {
				return nil, err
			}

			ref.Marshaler = underlyingRef.Marshaler
			ref.Unmarshaler = underlyingRef.Unmarshaler
		} else {
			ref.GO = obj.Type()
		}

		// Target keeps the bare bound type. GO gets the list/pointer
		// modifiers implied by the GraphQL type.
		ref.Target = ref.GO
		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)

		if bindTarget != nil {
			// Assigns the loop-scoped err, not the named result. An
			// incompatible model just moves on to the next one.
			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
				continue
			}
			ref.GO = bindTarget
		}

		ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput

		return ref, nil
	}

	// Only reachable with a non-nil bindTarget: with a nil target the first
	// model always returns (a result or an error) above.
	return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
}
   477  
   478  func isValid(t types.Type) bool {
   479  	basic, isBasic := t.(*types.Basic)
   480  	if !isBasic {
   481  		return true
   482  	}
   483  	return basic.Kind() != types.Invalid
   484  }
   485  
   486  func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
   487  	if t.Elem != nil {
   488  		child := b.CopyModifiersFromAst(t.Elem, base)
   489  		if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
   490  			child = types.NewPointer(child)
   491  		}
   492  		return types.NewSlice(child)
   493  	}
   494  
   495  	var isInterface bool
   496  	if named, ok := base.(*types.Named); ok {
   497  		_, isInterface = named.Underlying().(*types.Interface)
   498  	}
   499  
   500  	if !isInterface && !IsNilable(base) && !t.NonNull {
   501  		return types.NewPointer(base)
   502  	}
   503  
   504  	return base
   505  }
   506  
   507  func IsNilable(t types.Type) bool {
   508  	if namedType, isNamed := t.(*types.Named); isNamed {
   509  		return IsNilable(namedType.Underlying())
   510  	}
   511  	_, isPtr := t.(*types.Pointer)
   512  	_, isMap := t.(*types.Map)
   513  	_, isInterface := t.(*types.Interface)
   514  	_, isSlice := t.(*types.Slice)
   515  	_, isChan := t.(*types.Chan)
   516  	return isPtr || isMap || isInterface || isSlice || isChan
   517  }
   518  
   519  func hasMethod(it types.Type, name string) bool {
   520  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   521  		it = ptr.Elem()
   522  	}
   523  	namedType, ok := it.(*types.Named)
   524  	if !ok {
   525  		return false
   526  	}
   527  
   528  	for i := 0; i < namedType.NumMethods(); i++ {
   529  		if namedType.Method(i).Name() == name {
   530  			return true
   531  		}
   532  	}
   533  	return false
   534  }
   535  
   536  func basicUnderlying(it types.Type) *types.Basic {
   537  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   538  		it = ptr.Elem()
   539  	}
   540  	namedType, ok := it.(*types.Named)
   541  	if !ok {
   542  		return nil
   543  	}
   544  
   545  	if basic, ok := namedType.Underlying().(*types.Basic); ok {
   546  		return basic
   547  	}
   548  
   549  	return nil
   550  }