git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/codegen/config/binder.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"go/token"
     6  	"go/types"
     7  
     8  	"git.sr.ht/~sircmpwn/gqlgen/codegen/templates"
     9  	"git.sr.ht/~sircmpwn/gqlgen/internal/code"
    10  	"github.com/pkg/errors"
    11  	"github.com/vektah/gqlparser/v2/ast"
    12  )
    13  
// Binder connects graphql types to golang types using static analysis
//
// It resolves the model names configured for GraphQL types against loaded Go
// packages and accumulates every TypeReference it hands out in References.
type Binder struct {
	pkgs       *code.Packages   // package loader used to find model objects
	schema     *ast.Schema      // parsed GraphQL schema; Types is keyed by type name
	cfg        *Config          // source of model bindings (Models) and codegen options
	References []*TypeReference // references produced by TypeReference, PointerTo and PushRef
	SawInvalid bool             // set when TypeReference is asked to bind an invalid Go type
}
    22  
    23  func (c *Config) NewBinder() *Binder {
    24  	return &Binder{
    25  		pkgs:   c.Packages,
    26  		schema: c.Schema,
    27  		cfg:    c,
    28  	}
    29  }
    30  
    31  func (b *Binder) TypePosition(typ types.Type) token.Position {
    32  	named, isNamed := typ.(*types.Named)
    33  	if !isNamed {
    34  		return token.Position{
    35  			Filename: "unknown",
    36  		}
    37  	}
    38  
    39  	return b.ObjectPosition(named.Obj())
    40  }
    41  
    42  func (b *Binder) ObjectPosition(typ types.Object) token.Position {
    43  	if typ == nil {
    44  		return token.Position{
    45  			Filename: "unknown",
    46  		}
    47  	}
    48  	pkg := b.pkgs.Load(typ.Pkg().Path())
    49  	return pkg.Fset.Position(typ.Pos())
    50  }
    51  
    52  func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
    53  	pkgName, typeName := code.PkgAndType(name)
    54  	return b.FindType(pkgName, typeName)
    55  }
    56  
    57  func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
    58  	if pkgName == "" {
    59  		if typeName == "map[string]interface{}" {
    60  			return MapType, nil
    61  		}
    62  
    63  		if typeName == "interface{}" {
    64  			return InterfaceType, nil
    65  		}
    66  	}
    67  
    68  	obj, err := b.FindObject(pkgName, typeName)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	if fun, isFunc := obj.(*types.Func); isFunc {
    74  		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
    75  	}
    76  	return obj.Type(), nil
    77  }
    78  
    79  var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
    80  var InterfaceType = types.NewInterfaceType(nil, nil)
    81  
    82  func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
    83  	models := b.cfg.Models[name].Model
    84  	if len(models) == 0 {
    85  		return nil, fmt.Errorf(name + " not found in typemap")
    86  	}
    87  
    88  	if models[0] == "map[string]interface{}" {
    89  		return MapType, nil
    90  	}
    91  
    92  	if models[0] == "interface{}" {
    93  		return InterfaceType, nil
    94  	}
    95  
    96  	pkgName, typeName := code.PkgAndType(models[0])
    97  	if pkgName == "" {
    98  		return nil, fmt.Errorf("missing package name for %s", name)
    99  	}
   100  
   101  	obj, err := b.FindObject(pkgName, typeName)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	return obj.Type(), nil
   107  }
   108  
   109  func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
   110  	if pkgName == "" {
   111  		return nil, fmt.Errorf("package cannot be nil")
   112  	}
   113  	fullName := typeName
   114  	if pkgName != "" {
   115  		fullName = pkgName + "." + typeName
   116  	}
   117  
   118  	pkg := b.pkgs.LoadWithTypes(pkgName)
   119  	if pkg == nil {
   120  		return nil, errors.Errorf("required package was not loaded: %s", fullName)
   121  	}
   122  
   123  	// function based marshalers take precedence
   124  	for astNode, def := range pkg.TypesInfo.Defs {
   125  		// only look at defs in the top scope
   126  		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
   127  			continue
   128  		}
   129  
   130  		if astNode.Name == "Marshal"+typeName {
   131  			return def, nil
   132  		}
   133  	}
   134  
   135  	// then look for types directly
   136  	for astNode, def := range pkg.TypesInfo.Defs {
   137  		// only look at defs in the top scope
   138  		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
   139  			continue
   140  		}
   141  
   142  		if astNode.Name == typeName {
   143  			return def, nil
   144  		}
   145  	}
   146  
   147  	return nil, errors.Errorf("unable to find type %s\n", fullName)
   148  }
   149  
   150  func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
   151  	newRef := &TypeReference{
   152  		GO:          types.NewPointer(ref.GO),
   153  		GQL:         ref.GQL,
   154  		CastType:    ref.CastType,
   155  		Definition:  ref.Definition,
   156  		Unmarshaler: ref.Unmarshaler,
   157  		Marshaler:   ref.Marshaler,
   158  		IsMarshaler: ref.IsMarshaler,
   159  	}
   160  
   161  	b.References = append(b.References, newRef)
   162  	return newRef
   163  }
   164  
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type with the Go type it is bound to and records how
// values are converted between the two.
type TypeReference struct {
	Definition  *ast.Definition // schema definition of the named GraphQL type (may be nil if absent from the schema)
	GQL         *ast.Type       // GraphQL type, including list (Elem) and NonNull modifiers
	GO          types.Type      // Go type bound to GQL
	CastType    types.Type      // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func     // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func     // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool            // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
   175  
   176  func (ref *TypeReference) Elem() *TypeReference {
   177  	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
   178  		return &TypeReference{
   179  			GO:          p.Elem(),
   180  			GQL:         ref.GQL,
   181  			CastType:    ref.CastType,
   182  			Definition:  ref.Definition,
   183  			Unmarshaler: ref.Unmarshaler,
   184  			Marshaler:   ref.Marshaler,
   185  			IsMarshaler: ref.IsMarshaler,
   186  		}
   187  	}
   188  
   189  	if ref.IsSlice() {
   190  		return &TypeReference{
   191  			GO:          ref.GO.(*types.Slice).Elem(),
   192  			GQL:         ref.GQL.Elem,
   193  			CastType:    ref.CastType,
   194  			Definition:  ref.Definition,
   195  			Unmarshaler: ref.Unmarshaler,
   196  			Marshaler:   ref.Marshaler,
   197  			IsMarshaler: ref.IsMarshaler,
   198  		}
   199  	}
   200  	return nil
   201  }
   202  
   203  func (t *TypeReference) IsPtr() bool {
   204  	_, isPtr := t.GO.(*types.Pointer)
   205  	return isPtr
   206  }
   207  
// IsNilable reports whether the bound Go type can hold nil, as decided by the
// package-level IsNilable function.
func (t *TypeReference) IsNilable() bool {
	return IsNilable(t.GO)
}
   211  
   212  func (t *TypeReference) IsSlice() bool {
   213  	_, isSlice := t.GO.(*types.Slice)
   214  	return t.GQL.Elem != nil && isSlice
   215  }
   216  
   217  func (t *TypeReference) IsNamed() bool {
   218  	_, isSlice := t.GO.(*types.Named)
   219  	return isSlice
   220  }
   221  
   222  func (t *TypeReference) IsStruct() bool {
   223  	_, isStruct := t.GO.Underlying().(*types.Struct)
   224  	return isStruct
   225  }
   226  
   227  func (t *TypeReference) IsScalar() bool {
   228  	return t.Definition.Kind == ast.Scalar
   229  }
   230  
   231  func (t *TypeReference) UniquenessKey() string {
   232  	var nullability = "O"
   233  	if t.GQL.NonNull {
   234  		nullability = "N"
   235  	}
   236  
   237  	var elemNullability = ""
   238  	if t.GQL.Elem != nil && t.GQL.Elem.NonNull {
   239  		// Fix for #896
   240  		elemNullability = "áš„"
   241  	}
   242  	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability
   243  }
   244  
   245  func (t *TypeReference) MarshalFunc() string {
   246  	if t.Definition == nil {
   247  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   248  	}
   249  
   250  	if t.Definition.Kind == ast.InputObject {
   251  		return ""
   252  	}
   253  
   254  	return "marshal" + t.UniquenessKey()
   255  }
   256  
   257  func (t *TypeReference) UnmarshalFunc() string {
   258  	if t.Definition == nil {
   259  		panic(errors.New("Definition missing for " + t.GQL.Name()))
   260  	}
   261  
   262  	if !t.Definition.IsInputType() {
   263  		return ""
   264  	}
   265  
   266  	return "unmarshal" + t.UniquenessKey()
   267  }
   268  
// PushRef appends ret to b.References. TypeReference calls it (via a deferred
// closure) for every reference it returns successfully.
func (b *Binder) PushRef(ret *TypeReference) {
	b.References = append(b.References, ret)
}
   272  
   273  func isMap(t types.Type) bool {
   274  	if t == nil {
   275  		return true
   276  	}
   277  	_, ok := t.(*types.Map)
   278  	return ok
   279  }
   280  
   281  func isIntf(t types.Type) bool {
   282  	if t == nil {
   283  		return true
   284  	}
   285  	_, ok := t.(*types.Interface)
   286  	return ok
   287  }
   288  
   289  func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
   290  	if !isValid(bindTarget) {
   291  		b.SawInvalid = true
   292  		return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
   293  	}
   294  
   295  	var pkgName, typeName string
   296  	def := b.schema.Types[schemaType.Name()]
   297  	defer func() {
   298  		if err == nil && ret != nil {
   299  			b.PushRef(ret)
   300  		}
   301  	}()
   302  
   303  	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
   304  		return nil, fmt.Errorf("%s was not found", schemaType.Name())
   305  	}
   306  
   307  	for _, model := range b.cfg.Models[schemaType.Name()].Model {
   308  		if model == "map[string]interface{}" {
   309  			if !isMap(bindTarget) {
   310  				continue
   311  			}
   312  			return &TypeReference{
   313  				Definition: def,
   314  				GQL:        schemaType,
   315  				GO:         MapType,
   316  			}, nil
   317  		}
   318  
   319  		if model == "interface{}" {
   320  			if !isIntf(bindTarget) {
   321  				continue
   322  			}
   323  			return &TypeReference{
   324  				Definition: def,
   325  				GQL:        schemaType,
   326  				GO:         InterfaceType,
   327  			}, nil
   328  		}
   329  
   330  		pkgName, typeName = code.PkgAndType(model)
   331  		if pkgName == "" {
   332  			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
   333  		}
   334  
   335  		ref := &TypeReference{
   336  			Definition: def,
   337  			GQL:        schemaType,
   338  		}
   339  
   340  		obj, err := b.FindObject(pkgName, typeName)
   341  		if err != nil {
   342  			return nil, err
   343  		}
   344  
   345  		if fun, isFunc := obj.(*types.Func); isFunc {
   346  			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
   347  			ref.Marshaler = fun
   348  			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
   349  		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
   350  			ref.GO = obj.Type()
   351  			ref.IsMarshaler = true
   352  		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
   353  			// Special case for named types wrapping strings. Used by default enum implementations.
   354  
   355  			ref.GO = obj.Type()
   356  			ref.CastType = underlying
   357  
   358  			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
   359  			if err != nil {
   360  				return nil, err
   361  			}
   362  
   363  			ref.Marshaler = underlyingRef.Marshaler
   364  			ref.Unmarshaler = underlyingRef.Unmarshaler
   365  		} else {
   366  			ref.GO = obj.Type()
   367  		}
   368  
   369  		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
   370  
   371  		if bindTarget != nil {
   372  			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
   373  				continue
   374  			}
   375  			ref.GO = bindTarget
   376  		}
   377  
   378  		return ref, nil
   379  	}
   380  
   381  	return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
   382  }
   383  
   384  func isValid(t types.Type) bool {
   385  	basic, isBasic := t.(*types.Basic)
   386  	if !isBasic {
   387  		return true
   388  	}
   389  	return basic.Kind() != types.Invalid
   390  }
   391  
   392  func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
   393  	if t.Elem != nil {
   394  		child := b.CopyModifiersFromAst(t.Elem, base)
   395  		if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
   396  			child = types.NewPointer(child)
   397  		}
   398  		return types.NewSlice(child)
   399  	}
   400  
   401  	var isInterface bool
   402  	if named, ok := base.(*types.Named); ok {
   403  		_, isInterface = named.Underlying().(*types.Interface)
   404  	}
   405  
   406  	if !isInterface && !IsNilable(base) && !t.NonNull {
   407  		return types.NewPointer(base)
   408  	}
   409  
   410  	return base
   411  }
   412  
   413  func IsNilable(t types.Type) bool {
   414  	if namedType, isNamed := t.(*types.Named); isNamed {
   415  		t = namedType.Underlying()
   416  	}
   417  	_, isPtr := t.(*types.Pointer)
   418  	_, isMap := t.(*types.Map)
   419  	_, isInterface := t.(*types.Interface)
   420  	return isPtr || isMap || isInterface
   421  }
   422  
   423  func hasMethod(it types.Type, name string) bool {
   424  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   425  		it = ptr.Elem()
   426  	}
   427  	namedType, ok := it.(*types.Named)
   428  	if !ok {
   429  		return false
   430  	}
   431  
   432  	for i := 0; i < namedType.NumMethods(); i++ {
   433  		if namedType.Method(i).Name() == name {
   434  			return true
   435  		}
   436  	}
   437  	return false
   438  }
   439  
   440  func basicUnderlying(it types.Type) *types.Basic {
   441  	if ptr, isPtr := it.(*types.Pointer); isPtr {
   442  		it = ptr.Elem()
   443  	}
   444  	namedType, ok := it.(*types.Named)
   445  	if !ok {
   446  		return nil
   447  	}
   448  
   449  	if basic, ok := namedType.Underlying().(*types.Basic); ok {
   450  		return basic
   451  	}
   452  
   453  	return nil
   454  }