github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/codegen/config/binder.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"go/token"
     6  	"go/types"
     7  
     8  	"github.com/99designs/gqlgen/codegen/templates"
     9  	"github.com/99designs/gqlgen/internal/code"
    10  	"github.com/pkg/errors"
    11  	"github.com/vektah/gqlparser/v2/ast"
    12  )
    13  
// Binder connects graphql types to golang types using static analysis
//
// A Binder is built from a Config (see NewBinder) and resolves the Go types
// named in the models configuration by inspecting the loaded Go packages.
type Binder struct {
	// pkgs loads Go packages by import path; used for position lookups and
	// for finding package-level objects.
	pkgs *code.Packages

	// schema is the GraphQL schema whose type definitions are being bound.
	schema *ast.Schema

	// cfg supplies the models configuration and binding options.
	cfg *Config

	// References collects every TypeReference produced by this Binder.
	References []*TypeReference

	// SawInvalid is set when TypeReference is asked to bind against a Go type
	// of kind types.Invalid.
	SawInvalid bool
}
    22  
    23  func (c *Config) NewBinder() *Binder {
    24  	return &Binder{
    25  		pkgs:   c.Packages,
    26  		schema: c.Schema,
    27  		cfg:    c,
    28  	}
    29  }
    30  
    31  func (b *Binder) TypePosition(typ types.Type) token.Position {
    32  	named, isNamed := typ.(*types.Named)
    33  	if !isNamed {
    34  		return token.Position{
    35  			Filename: "unknown",
    36  		}
    37  	}
    38  
    39  	return b.ObjectPosition(named.Obj())
    40  }
    41  
    42  func (b *Binder) ObjectPosition(typ types.Object) token.Position {
    43  	if typ == nil {
    44  		return token.Position{
    45  			Filename: "unknown",
    46  		}
    47  	}
    48  	pkg := b.pkgs.Load(typ.Pkg().Path())
    49  	return pkg.Fset.Position(typ.Pos())
    50  }
    51  
    52  func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
    53  	pkgName, typeName := code.PkgAndType(name)
    54  	return b.FindType(pkgName, typeName)
    55  }
    56  
    57  func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
    58  	if pkgName == "" {
    59  		if typeName == "map[string]interface{}" {
    60  			return MapType, nil
    61  		}
    62  
    63  		if typeName == "interface{}" {
    64  			return InterfaceType, nil
    65  		}
    66  	}
    67  
    68  	obj, err := b.FindObject(pkgName, typeName)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	if fun, isFunc := obj.(*types.Func); isFunc {
    74  		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
    75  	}
    76  	return obj.Type(), nil
    77  }
    78  
    79  var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
    80  var InterfaceType = types.NewInterfaceType(nil, nil)
    81  
    82  func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
    83  	models := b.cfg.Models[name].Model
    84  	if len(models) == 0 {
    85  		return nil, fmt.Errorf(name + " not found in typemap")
    86  	}
    87  
    88  	if models[0] == "map[string]interface{}" {
    89  		return MapType, nil
    90  	}
    91  
    92  	if models[0] == "interface{}" {
    93  		return InterfaceType, nil
    94  	}
    95  
    96  	pkgName, typeName := code.PkgAndType(models[0])
    97  	if pkgName == "" {
    98  		return nil, fmt.Errorf("missing package name for %s", name)
    99  	}
   100  
   101  	obj, err := b.FindObject(pkgName, typeName)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	return obj.Type(), nil
   107  }
   108  
   109  func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
   110  	if pkgName == "" {
   111  		return nil, fmt.Errorf("package cannot be nil")
   112  	}
   113  	fullName := typeName
   114  	if pkgName != "" {
   115  		fullName = pkgName + "." + typeName
   116  	}
   117  
   118  	pkg := b.pkgs.LoadWithTypes(pkgName)
   119  	if pkg == nil {
   120  		return nil, errors.Errorf("required package was not loaded: %s", fullName)
   121  	}
   122  
   123  	// function based marshalers take precedence
   124  	for astNode, def := range pkg.TypesInfo.Defs {
   125  		// only look at defs in the top scope
   126  		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
   127  			continue
   128  		}
   129  
   130  		if astNode.Name == "Marshal"+typeName {
   131  			return def, nil
   132  		}
   133  	}
   134  
   135  	// then look for types directly
   136  	for astNode, def := range pkg.TypesInfo.Defs {
   137  		// only look at defs in the top scope
   138  		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
   139  			continue
   140  		}
   141  
   142  		if astNode.Name == typeName {
   143  			return def, nil
   144  		}
   145  	}
   146  
   147  	return nil, errors.Errorf("unable to find type %s\n", fullName)
   148  }
   149  
   150  func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
   151  	newRef := &TypeReference{
   152  		GO:          types.NewPointer(ref.GO),
   153  		GQL:         ref.GQL,
   154  		CastType:    ref.CastType,
   155  		Definition:  ref.Definition,
   156  		Unmarshaler: ref.Unmarshaler,
   157  		Marshaler:   ref.Marshaler,
   158  		IsMarshaler: ref.IsMarshaler,
   159  	}
   160  
   161  	b.References = append(b.References, newRef)
   162  	return newRef
   163  }
   164  
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type (GQL, with its schema Definition) with the Go type
// used to represent it (GO). It also records how values are marshalled: via
// external Marshaler/Unmarshaler functions, via the MarshalGQL/UnmarshalGQL
// methods (IsMarshaler), or by casting through CastType.
type TypeReference struct {
	Definition  *ast.Definition // Schema definition of the named GraphQL type; MarshalFunc/UnmarshalFunc panic if nil
	GQL         *ast.Type       // The GraphQL type, including list and non-null modifiers
	GO          types.Type      // Type of the field being bound. Could be a pointer or a value type of Target.
	Target      types.Type      // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
	CastType    types.Type      // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func     // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func     // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool            // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
   176  
   177  func (ref *TypeReference) Elem() *TypeReference {
   178  	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
   179  		return &TypeReference{
   180  			GO:          p.Elem(),
   181  			Target:      ref.Target,
   182  			GQL:         ref.GQL,
   183  			CastType:    ref.CastType,
   184  			Definition:  ref.Definition,
   185  			Unmarshaler: ref.Unmarshaler,
   186  			Marshaler:   ref.Marshaler,
   187  			IsMarshaler: ref.IsMarshaler,
   188  		}
   189  	}
   190  
   191  	if ref.IsSlice() {
   192  		return &TypeReference{
   193  			GO:          ref.GO.(*types.Slice).Elem(),
   194  			Target:      ref.Target,
   195  			GQL:         ref.GQL.Elem,
   196  			CastType:    ref.CastType,
   197  			Definition:  ref.Definition,
   198  			Unmarshaler: ref.Unmarshaler,
   199  			Marshaler:   ref.Marshaler,
   200  			IsMarshaler: ref.IsMarshaler,
   201  		}
   202  	}
   203  	return nil
   204  }
   205  
   206  func (t *TypeReference) IsPtr() bool {
   207  	_, isPtr := t.GO.(*types.Pointer)
   208  	return isPtr
   209  }
   210  
   211  func (t *TypeReference) IsNilable() bool {
   212  	return IsNilable(t.GO)
   213  }
   214  
   215  func (t *TypeReference) IsSlice() bool {
   216  	_, isSlice := t.GO.(*types.Slice)
   217  	return t.GQL.Elem != nil && isSlice
   218  }
   219  
   220  func (t *TypeReference) IsNamed() bool {
   221  	_, isSlice := t.GO.(*types.Named)
   222  	return isSlice
   223  }
   224  
   225  func (t *TypeReference) IsStruct() bool {
   226  	_, isStruct := t.GO.Underlying().(*types.Struct)
   227  	return isStruct
   228  }
   229  
   230  func (t *TypeReference) IsScalar() bool {
   231  	return t.Definition.Kind == ast.Scalar
   232  }
   233  
   234  func (t *TypeReference) UniquenessKey() string {
   235  	var nullability = "O"
   236  	if t.GQL.NonNull {
   237  		nullability = "N"
   238  	}
   239  
   240  	var elemNullability = ""
   241  	if t.GQL.Elem != nil && t.GQL.Elem.NonNull {
   242  		// Fix for #896
   243  		elemNullability = "áš„"
   244  	}
   245  	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability
   246  }
   247  
   248  func (t *TypeReference) MarshalFunc() string {
   249  	if t.Definition == nil {
   250  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   251  	}
   252  
   253  	if t.Definition.Kind == ast.InputObject {
   254  		return ""
   255  	}
   256  
   257  	return "marshal" + t.UniquenessKey()
   258  }
   259  
   260  func (t *TypeReference) UnmarshalFunc() string {
   261  	if t.Definition == nil {
   262  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   263  	}
   264  
   265  	if !t.Definition.IsInputType() {
   266  		return ""
   267  	}
   268  
   269  	return "unmarshal" + t.UniquenessKey()
   270  }
   271  
   272  func (t *TypeReference) IsTargetNilable() bool {
   273  	return IsNilable(t.Target)
   274  }
   275  
   276  func (b *Binder) PushRef(ret *TypeReference) {
   277  	b.References = append(b.References, ret)
   278  }
   279  
   280  func isMap(t types.Type) bool {
   281  	if t == nil {
   282  		return true
   283  	}
   284  	_, ok := t.(*types.Map)
   285  	return ok
   286  }
   287  
   288  func isIntf(t types.Type) bool {
   289  	if t == nil {
   290  		return true
   291  	}
   292  	_, ok := t.(*types.Interface)
   293  	return ok
   294  }
   295  
// TypeReference resolves the Go representation of the GraphQL type schemaType
// by trying each model configured for schemaType.Name(), in order.
//
// bindTarget is the Go type the result must be compatible with. When it is
// nil, the first configured model is used. When it is non-nil, models whose
// Go type is not compatible with it are skipped, and the first compatible one
// is returned with GO replaced by bindTarget itself.
//
// Every reference returned without error is also appended to b.References by
// the deferred PushRef below.
func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
	// A types.Invalid target means the Go side did not type-check; remember
	// that in SawInvalid before failing.
	if !isValid(bindTarget) {
		b.SawInvalid = true
		return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
	}

	var pkgName, typeName string
	def := b.schema.Types[schemaType.Name()]
	// Relies on the named results: only successful returns are recorded.
	defer func() {
		if err == nil && ret != nil {
			b.PushRef(ret)
		}
	}()

	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
		return nil, fmt.Errorf("%s was not found", schemaType.Name())
	}

	for _, model := range b.cfg.Models[schemaType.Name()].Model {
		// Special model strings that do not name a package-level Go object.
		if model == "map[string]interface{}" {
			if !isMap(bindTarget) {
				continue
			}
			return &TypeReference{
				Definition: def,
				GQL:        schemaType,
				GO:         MapType,
			}, nil
		}

		if model == "interface{}" {
			if !isIntf(bindTarget) {
				continue
			}
			return &TypeReference{
				Definition: def,
				GQL:        schemaType,
				GO:         InterfaceType,
			}, nil
		}

		pkgName, typeName = code.PkgAndType(model)
		if pkgName == "" {
			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
		}

		ref := &TypeReference{
			Definition: def,
			GQL:        schemaType,
		}

		// Note: this `err` shadows the named result for the rest of the loop body.
		obj, err := b.FindObject(pkgName, typeName)
		if err != nil {
			return nil, err
		}

		if fun, isFunc := obj.(*types.Func); isFunc {
			// Function-based marshaller: the Go type is the function's first
			// parameter. NOTE(review): panics if the function takes no parameters.
			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
			ref.Marshaler = fun
			// Synthesized object with a nil signature; presumably only its name
			// and package are used by code generation. TODO confirm.
			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
			// The type marshals itself via MarshalGQL/UnmarshalGQL methods.
			ref.GO = obj.Type()
			ref.IsMarshaler = true
		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
			// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)

			ref.GO = obj.Type()
			ref.CastType = underlying

			// Reuse the String scalar's marshallers; this recursive call also
			// records its own reference in b.References.
			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
			if err != nil {
				return nil, err
			}

			ref.Marshaler = underlyingRef.Marshaler
			ref.Unmarshaler = underlyingRef.Unmarshaler
		} else {
			ref.GO = obj.Type()
		}

		// Target keeps the bare bound type; GO gains the pointer/slice
		// modifiers implied by the GraphQL type.
		ref.Target = ref.GO
		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)

		if bindTarget != nil {
			// Assigns the loop-scoped err, so a mismatch just tries the next model.
			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
				continue
			}
			ref.GO = bindTarget
		}

		return ref, nil
	}

	// Only reachable with a non-nil bindTarget: with a nil target every branch
	// above returns on the first model, so bindTarget.String() is safe here.
	return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
}
   391  
   392  func isValid(t types.Type) bool {
   393  	basic, isBasic := t.(*types.Basic)
   394  	if !isBasic {
   395  		return true
   396  	}
   397  	return basic.Kind() != types.Invalid
   398  }
   399  
   400  func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
   401  	if t.Elem != nil {
   402  		child := b.CopyModifiersFromAst(t.Elem, base)
   403  		if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
   404  			child = types.NewPointer(child)
   405  		}
   406  		return types.NewSlice(child)
   407  	}
   408  
   409  	var isInterface bool
   410  	if named, ok := base.(*types.Named); ok {
   411  		_, isInterface = named.Underlying().(*types.Interface)
   412  	}
   413  
   414  	if !isInterface && !IsNilable(base) && !t.NonNull {
   415  		return types.NewPointer(base)
   416  	}
   417  
   418  	return base
   419  }
   420  
   421  func IsNilable(t types.Type) bool {
   422  	if namedType, isNamed := t.(*types.Named); isNamed {
   423  		return IsNilable(namedType.Underlying())
   424  	}
   425  	_, isPtr := t.(*types.Pointer)
   426  	_, isMap := t.(*types.Map)
   427  	_, isInterface := t.(*types.Interface)
   428  	_, isSlice := t.(*types.Slice)
   429  	_, isChan := t.(*types.Chan)
   430  	return isPtr || isMap || isInterface || isSlice || isChan
   431  }
   432  
   433  func hasMethod(it types.Type, name string) bool {
   434  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   435  		it = ptr.Elem()
   436  	}
   437  	namedType, ok := it.(*types.Named)
   438  	if !ok {
   439  		return false
   440  	}
   441  
   442  	for i := 0; i < namedType.NumMethods(); i++ {
   443  		if namedType.Method(i).Name() == name {
   444  			return true
   445  		}
   446  	}
   447  	return false
   448  }
   449  
   450  func basicUnderlying(it types.Type) *types.Basic {
   451  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   452  		it = ptr.Elem()
   453  	}
   454  	namedType, ok := it.(*types.Named)
   455  	if !ok {
   456  		return nil
   457  	}
   458  
   459  	if basic, ok := namedType.Underlying().(*types.Basic); ok {
   460  		return basic
   461  	}
   462  
   463  	return nil
   464  }