github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/codegen/config/binder.go (about)

     1  package config
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/token"
     7  	"go/types"
     8  
     9  	"golang.org/x/tools/go/packages"
    10  
    11  	"github.com/mstephano/gqlgen-schemagen/codegen/templates"
    12  	"github.com/mstephano/gqlgen-schemagen/internal/code"
    13  	"github.com/vektah/gqlparser/v2/ast"
    14  )
    15  
// ErrTypeNotFound is returned by FindObject, wrapped together with the package
// and type name, when neither the type nor a Marshal<Type> definition exists
// at the top level of the requested package. Because it is wrapped with %w,
// callers can detect it with errors.Is.
var ErrTypeNotFound = errors.New("unable to find type")
    17  
// Binder connects graphql types to golang types using static analysis.
//
// It looks model types up in Go packages obtained through pkgs, guided by the
// GraphQL schema and the model mappings in cfg, and collects every
// TypeReference it creates in References.
type Binder struct {
	// pkgs is the package loader used to look model types up (Load,
	// LoadWithTypes, Errors, Count).
	pkgs *code.Packages
	// schema supplies the GraphQL definitions attached to each TypeReference.
	schema *ast.Schema
	// cfg supplies the model mappings (Models) and binding options.
	cfg *Config
	// References accumulates every reference built by TypeReference,
	// PointerTo, or PushRef.
	References []*TypeReference
	// SawInvalid is set when TypeReference is asked to bind an invalid Go type.
	SawInvalid bool
	// objectCache maps a package path to an index of its top-level
	// definitions; it is allocated and filled lazily by FindObject.
	objectCache map[string]map[string]types.Object
}
    27  
    28  func (c *Config) NewBinder() *Binder {
    29  	return &Binder{
    30  		pkgs:   c.Packages,
    31  		schema: c.Schema,
    32  		cfg:    c,
    33  	}
    34  }
    35  
    36  func (b *Binder) TypePosition(typ types.Type) token.Position {
    37  	named, isNamed := typ.(*types.Named)
    38  	if !isNamed {
    39  		return token.Position{
    40  			Filename: "unknown",
    41  		}
    42  	}
    43  
    44  	return b.ObjectPosition(named.Obj())
    45  }
    46  
    47  func (b *Binder) ObjectPosition(typ types.Object) token.Position {
    48  	if typ == nil {
    49  		return token.Position{
    50  			Filename: "unknown",
    51  		}
    52  	}
    53  	pkg := b.pkgs.Load(typ.Pkg().Path())
    54  	return pkg.Fset.Position(typ.Pos())
    55  }
    56  
    57  func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
    58  	pkgName, typeName := code.PkgAndType(name)
    59  	return b.FindType(pkgName, typeName)
    60  }
    61  
    62  func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
    63  	if pkgName == "" {
    64  		if typeName == "map[string]interface{}" {
    65  			return MapType, nil
    66  		}
    67  
    68  		if typeName == "interface{}" {
    69  			return InterfaceType, nil
    70  		}
    71  	}
    72  
    73  	obj, err := b.FindObject(pkgName, typeName)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	if fun, isFunc := obj.(*types.Func); isFunc {
    79  		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
    80  	}
    81  	return obj.Type(), nil
    82  }
    83  
    84  var (
    85  	MapType       = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
    86  	InterfaceType = types.NewInterfaceType(nil, nil)
    87  )
    88  
    89  func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
    90  	models := b.cfg.Models[name].Model
    91  	if len(models) == 0 {
    92  		return nil, fmt.Errorf(name + " not found in typemap")
    93  	}
    94  
    95  	if models[0] == "map[string]interface{}" {
    96  		return MapType, nil
    97  	}
    98  
    99  	if models[0] == "interface{}" {
   100  		return InterfaceType, nil
   101  	}
   102  
   103  	pkgName, typeName := code.PkgAndType(models[0])
   104  	if pkgName == "" {
   105  		return nil, fmt.Errorf("missing package name for %s", name)
   106  	}
   107  
   108  	obj, err := b.FindObject(pkgName, typeName)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	return obj.Type(), nil
   114  }
   115  
   116  func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
   117  	if pkgName == "" {
   118  		return nil, fmt.Errorf("package cannot be nil")
   119  	}
   120  
   121  	pkg := b.pkgs.LoadWithTypes(pkgName)
   122  	if pkg == nil {
   123  		err := b.pkgs.Errors()
   124  		if err != nil {
   125  			return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err)
   126  		}
   127  		return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName)
   128  	}
   129  
   130  	if b.objectCache == nil {
   131  		b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count())
   132  	}
   133  
   134  	defsIndex, ok := b.objectCache[pkgName]
   135  	if !ok {
   136  		defsIndex = indexDefs(pkg)
   137  		b.objectCache[pkgName] = defsIndex
   138  	}
   139  
   140  	// function based marshalers take precedence
   141  	if val, ok := defsIndex["Marshal"+typeName]; ok {
   142  		return val, nil
   143  	}
   144  
   145  	if val, ok := defsIndex[typeName]; ok {
   146  		return val, nil
   147  	}
   148  
   149  	return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName)
   150  }
   151  
   152  func indexDefs(pkg *packages.Package) map[string]types.Object {
   153  	res := make(map[string]types.Object)
   154  
   155  	scope := pkg.Types.Scope()
   156  	for astNode, def := range pkg.TypesInfo.Defs {
   157  		// only look at defs in the top scope
   158  		if def == nil {
   159  			continue
   160  		}
   161  		parent := def.Parent()
   162  		if parent == nil || parent != scope {
   163  			continue
   164  		}
   165  
   166  		if _, ok := res[astNode.Name]; !ok {
   167  			// The above check may not be really needed, it is only here to have a consistent behavior with
   168  			// previous implementation of FindObject() function which only honored the first inclusion of a def.
   169  			// If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups.
   170  			res[astNode.Name] = def
   171  		}
   172  	}
   173  
   174  	return res
   175  }
   176  
   177  func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
   178  	newRef := *ref
   179  	newRef.GO = types.NewPointer(ref.GO)
   180  	b.References = append(b.References, &newRef)
   181  	return &newRef
   182  }
   183  
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// Definition is the GraphQL definition looked up by name in the schema. It is
// nil when the schema has no such definition; MarshalFunc and UnmarshalFunc
// panic in that case. GQL is the GraphQL type including its list and non-null
// modifiers. The remaining fields describe the Go side of the binding.
type TypeReference struct {
	Definition              *ast.Definition
	GQL                     *ast.Type
	GO                      types.Type  // Type of the field being bound. Could be a pointer or a value type of Target.
	Target                  types.Type  // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
	CastType                types.Type  // Underlying basic type to cast from/to before calling marshalling functions (set for named string leaf types).
	Marshaler               *types.Func // When using external marshalling functions this will point to the Marshal function
	Unmarshaler             *types.Func // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler             bool        // Does the type implement graphql.Marshaler and graphql.Unmarshaler
	IsContext               bool        // Is the Marshaler/Unmarshaler the context version; applies to either the method or interface variety.
	PointersInUmarshalInput bool        // Copied from Config.ReturnPointersInUmarshalInput; presumably swaps values and pointers in generated unmarshal-input results (TODO confirm).
}
   197  
   198  func (ref *TypeReference) Elem() *TypeReference {
   199  	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
   200  		newRef := *ref
   201  		newRef.GO = p.Elem()
   202  		return &newRef
   203  	}
   204  
   205  	if ref.IsSlice() {
   206  		newRef := *ref
   207  		newRef.GO = ref.GO.(*types.Slice).Elem()
   208  		newRef.GQL = ref.GQL.Elem
   209  		return &newRef
   210  	}
   211  	return nil
   212  }
   213  
   214  func (t *TypeReference) IsPtr() bool {
   215  	_, isPtr := t.GO.(*types.Pointer)
   216  	return isPtr
   217  }
   218  
   219  // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful)
   220  func (t *TypeReference) IsPtrToPtr() bool {
   221  	if p, isPtr := t.GO.(*types.Pointer); isPtr {
   222  		_, isPtr := p.Elem().(*types.Pointer)
   223  		return isPtr
   224  	}
   225  	return false
   226  }
   227  
   228  func (t *TypeReference) IsNilable() bool {
   229  	return IsNilable(t.GO)
   230  }
   231  
   232  func (t *TypeReference) IsSlice() bool {
   233  	_, isSlice := t.GO.(*types.Slice)
   234  	return t.GQL.Elem != nil && isSlice
   235  }
   236  
   237  func (t *TypeReference) IsPtrToSlice() bool {
   238  	if t.IsPtr() {
   239  		_, isPointerToSlice := t.GO.(*types.Pointer).Elem().(*types.Slice)
   240  		return isPointerToSlice
   241  	}
   242  	return false
   243  }
   244  
   245  func (t *TypeReference) IsNamed() bool {
   246  	_, isSlice := t.GO.(*types.Named)
   247  	return isSlice
   248  }
   249  
   250  func (t *TypeReference) IsStruct() bool {
   251  	_, isStruct := t.GO.Underlying().(*types.Struct)
   252  	return isStruct
   253  }
   254  
   255  func (t *TypeReference) IsUnderlyingBasic() bool {
   256  	_, isUnderlyingBasic := t.GO.Underlying().(*types.Basic)
   257  	return isUnderlyingBasic
   258  }
   259  
   260  func (t *TypeReference) IsScalarID() bool {
   261  	return t.Definition.Kind == ast.Scalar && t.Marshaler.Name() == "MarshalID"
   262  }
   263  
   264  func (t *TypeReference) IsScalar() bool {
   265  	return t.Definition.Kind == ast.Scalar
   266  }
   267  
   268  func (t *TypeReference) UniquenessKey() string {
   269  	nullability := "O"
   270  	if t.GQL.NonNull {
   271  		nullability = "N"
   272  	}
   273  
   274  	elemNullability := ""
   275  	if t.GQL.Elem != nil && t.GQL.Elem.NonNull {
   276  		// Fix for #896
   277  		elemNullability = "áš„"
   278  	}
   279  	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability
   280  }
   281  
   282  func (t *TypeReference) MarshalFunc() string {
   283  	if t.Definition == nil {
   284  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   285  	}
   286  
   287  	if t.Definition.Kind == ast.InputObject {
   288  		return ""
   289  	}
   290  
   291  	return "marshal" + t.UniquenessKey()
   292  }
   293  
   294  func (t *TypeReference) UnmarshalFunc() string {
   295  	if t.Definition == nil {
   296  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   297  	}
   298  
   299  	if !t.Definition.IsInputType() {
   300  		return ""
   301  	}
   302  
   303  	return "unmarshal" + t.UniquenessKey()
   304  }
   305  
   306  func (t *TypeReference) IsTargetNilable() bool {
   307  	return IsNilable(t.Target)
   308  }
   309  
   310  func (b *Binder) PushRef(ret *TypeReference) {
   311  	b.References = append(b.References, ret)
   312  }
   313  
   314  func isMap(t types.Type) bool {
   315  	if t == nil {
   316  		return true
   317  	}
   318  	_, ok := t.(*types.Map)
   319  	return ok
   320  }
   321  
   322  func isIntf(t types.Type) bool {
   323  	if t == nil {
   324  		return true
   325  	}
   326  	_, ok := t.(*types.Interface)
   327  	return ok
   328  }
   329  
   330  func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
   331  	if !isValid(bindTarget) {
   332  		b.SawInvalid = true
   333  		return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
   334  	}
   335  
   336  	var pkgName, typeName string
   337  	def := b.schema.Types[schemaType.Name()]
   338  	defer func() {
   339  		if err == nil && ret != nil {
   340  			b.PushRef(ret)
   341  		}
   342  	}()
   343  
   344  	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
   345  		return nil, fmt.Errorf("%s was not found", schemaType.Name())
   346  	}
   347  
   348  	for _, model := range b.cfg.Models[schemaType.Name()].Model {
   349  		if model == "map[string]interface{}" {
   350  			if !isMap(bindTarget) {
   351  				continue
   352  			}
   353  			return &TypeReference{
   354  				Definition: def,
   355  				GQL:        schemaType,
   356  				GO:         MapType,
   357  			}, nil
   358  		}
   359  
   360  		if model == "interface{}" {
   361  			if !isIntf(bindTarget) {
   362  				continue
   363  			}
   364  			return &TypeReference{
   365  				Definition: def,
   366  				GQL:        schemaType,
   367  				GO:         InterfaceType,
   368  			}, nil
   369  		}
   370  
   371  		pkgName, typeName = code.PkgAndType(model)
   372  		if pkgName == "" {
   373  			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
   374  		}
   375  
   376  		ref := &TypeReference{
   377  			Definition: def,
   378  			GQL:        schemaType,
   379  		}
   380  
   381  		obj, err := b.FindObject(pkgName, typeName)
   382  		if err != nil {
   383  			return nil, err
   384  		}
   385  
   386  		if fun, isFunc := obj.(*types.Func); isFunc {
   387  			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
   388  			ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/mstephano/gqlgen-schemagen/graphql.ContextMarshaler"
   389  			ref.Marshaler = fun
   390  			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
   391  		} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
   392  			ref.GO = obj.Type()
   393  			ref.IsContext = true
   394  			ref.IsMarshaler = true
   395  		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
   396  			ref.GO = obj.Type()
   397  			ref.IsMarshaler = true
   398  		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
   399  			// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)
   400  
   401  			ref.GO = obj.Type()
   402  			ref.CastType = underlying
   403  
   404  			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
   405  			if err != nil {
   406  				return nil, err
   407  			}
   408  
   409  			ref.Marshaler = underlyingRef.Marshaler
   410  			ref.Unmarshaler = underlyingRef.Unmarshaler
   411  		} else {
   412  			ref.GO = obj.Type()
   413  		}
   414  
   415  		ref.Target = ref.GO
   416  		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
   417  
   418  		if bindTarget != nil {
   419  			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
   420  				continue
   421  			}
   422  			ref.GO = bindTarget
   423  		}
   424  
   425  		ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput
   426  
   427  		return ref, nil
   428  	}
   429  
   430  	return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
   431  }
   432  
   433  func isValid(t types.Type) bool {
   434  	basic, isBasic := t.(*types.Basic)
   435  	if !isBasic {
   436  		return true
   437  	}
   438  	return basic.Kind() != types.Invalid
   439  }
   440  
   441  func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
   442  	if t.Elem != nil {
   443  		child := b.CopyModifiersFromAst(t.Elem, base)
   444  		if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
   445  			child = types.NewPointer(child)
   446  		}
   447  		return types.NewSlice(child)
   448  	}
   449  
   450  	var isInterface bool
   451  	if named, ok := base.(*types.Named); ok {
   452  		_, isInterface = named.Underlying().(*types.Interface)
   453  	}
   454  
   455  	if !isInterface && !IsNilable(base) && !t.NonNull {
   456  		return types.NewPointer(base)
   457  	}
   458  
   459  	return base
   460  }
   461  
   462  func IsNilable(t types.Type) bool {
   463  	if namedType, isNamed := t.(*types.Named); isNamed {
   464  		return IsNilable(namedType.Underlying())
   465  	}
   466  	_, isPtr := t.(*types.Pointer)
   467  	_, isMap := t.(*types.Map)
   468  	_, isInterface := t.(*types.Interface)
   469  	_, isSlice := t.(*types.Slice)
   470  	_, isChan := t.(*types.Chan)
   471  	return isPtr || isMap || isInterface || isSlice || isChan
   472  }
   473  
   474  func hasMethod(it types.Type, name string) bool {
   475  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   476  		it = ptr.Elem()
   477  	}
   478  	namedType, ok := it.(*types.Named)
   479  	if !ok {
   480  		return false
   481  	}
   482  
   483  	for i := 0; i < namedType.NumMethods(); i++ {
   484  		if namedType.Method(i).Name() == name {
   485  			return true
   486  		}
   487  	}
   488  	return false
   489  }
   490  
   491  func basicUnderlying(it types.Type) *types.Basic {
   492  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   493  		it = ptr.Elem()
   494  	}
   495  	namedType, ok := it.(*types.Named)
   496  	if !ok {
   497  		return nil
   498  	}
   499  
   500  	if basic, ok := namedType.Underlying().(*types.Basic); ok {
   501  		return basic
   502  	}
   503  
   504  	return nil
   505  }