github.com/shippio/gqlgen@v0.0.0-20220912092219-633ea699ef07/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/types"
     7  	"log"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/99designs/gqlgen/codegen/config"
    13  	"github.com/99designs/gqlgen/codegen/templates"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  	"golang.org/x/text/cases"
    16  	"golang.org/x/text/language"
    17  )
    18  
    19  type Field struct {
    20  	*ast.FieldDefinition
    21  
    22  	TypeReference    *config.TypeReference
    23  	GoFieldType      GoFieldType      // The field type in go, if any
    24  	GoReceiverName   string           // The name of method & var receiver in go, if any
    25  	GoFieldName      string           // The name of the method or var in go, if any
    26  	IsResolver       bool             // Does this field need a resolver
    27  	Args             []*FieldArgument // A list of arguments to be passed to this field
    28  	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
    29  	NoErr            bool             // If this is bound to a go method, does that method have an error as the second argument
    30  	VOkFunc          bool             // If this is bound to a go method, is it of shape (interface{}, bool)
    31  	Object           *Object          // A link back to the parent object
    32  	Default          interface{}      // The default value
    33  	Stream           bool             // does this field return a channel?
    34  	Directives       []*Directive
    35  }
    36  
    37  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    38  	dirs, err := b.getDirectives(field.Directives)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	f := Field{
    44  		FieldDefinition: field,
    45  		Object:          obj,
    46  		Directives:      dirs,
    47  		GoFieldName:     templates.ToGo(field.Name),
    48  		GoFieldType:     GoFieldVariable,
    49  		GoReceiverName:  "obj",
    50  	}
    51  
    52  	if field.DefaultValue != nil {
    53  		var err error
    54  		f.Default, err = field.DefaultValue.Value(nil)
    55  		if err != nil {
    56  			return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err)
    57  		}
    58  	}
    59  
    60  	for _, arg := range field.Arguments {
    61  		newArg, err := b.buildArg(obj, arg)
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		f.Args = append(f.Args, newArg)
    66  	}
    67  
    68  	if err = b.bindField(obj, &f); err != nil {
    69  		f.IsResolver = true
    70  		if errors.Is(err, config.ErrTypeNotFound) {
    71  			return nil, err
    72  		}
    73  		log.Println(err.Error())
    74  	}
    75  
    76  	if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    77  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    78  	}
    79  
    80  	return &f, nil
    81  }
    82  
    83  func (b *builder) bindField(obj *Object, f *Field) (errret error) {
    84  	defer func() {
    85  		if f.TypeReference == nil {
    86  			tr, err := b.Binder.TypeReference(f.Type, nil)
    87  			if err != nil {
    88  				errret = err
    89  			}
    90  			f.TypeReference = tr
    91  		}
    92  		if f.TypeReference != nil {
    93  			dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
    94  			if err != nil {
    95  				errret = err
    96  			}
    97  			for _, dir := range obj.Directives {
    98  				if dir.IsLocation(ast.LocationInputObject) {
    99  					dirs = append(dirs, dir)
   100  				}
   101  			}
   102  			f.Directives = append(dirs, f.Directives...)
   103  		}
   104  	}()
   105  
   106  	f.Stream = obj.Stream
   107  
   108  	switch {
   109  	case f.Name == "__schema":
   110  		f.GoFieldType = GoFieldMethod
   111  		f.GoReceiverName = "ec"
   112  		f.GoFieldName = "introspectSchema"
   113  		return nil
   114  	case f.Name == "__type":
   115  		f.GoFieldType = GoFieldMethod
   116  		f.GoReceiverName = "ec"
   117  		f.GoFieldName = "introspectType"
   118  		return nil
   119  	case f.Name == "_entities":
   120  		f.GoFieldType = GoFieldMethod
   121  		f.GoReceiverName = "ec"
   122  		f.GoFieldName = "__resolve_entities"
   123  		f.MethodHasContext = true
   124  		f.NoErr = true
   125  		return nil
   126  	case f.Name == "_service":
   127  		f.GoFieldType = GoFieldMethod
   128  		f.GoReceiverName = "ec"
   129  		f.GoFieldName = "__resolve__service"
   130  		f.MethodHasContext = true
   131  		return nil
   132  	case obj.Root:
   133  		f.IsResolver = true
   134  		return nil
   135  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   136  		f.IsResolver = true
   137  		return nil
   138  	case obj.Type == config.MapType:
   139  		f.GoFieldType = GoFieldMap
   140  		return nil
   141  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   142  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   143  	}
   144  
   145  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   146  	if err != nil {
   147  		return err
   148  	}
   149  
   150  	pos := b.Binder.ObjectPosition(target)
   151  
   152  	switch target := target.(type) {
   153  	case nil:
   154  		objPos := b.Binder.TypePosition(obj.Type)
   155  		return fmt.Errorf(
   156  			"%s:%d adding resolver method for %s.%s, nothing matched",
   157  			objPos.Filename,
   158  			objPos.Line,
   159  			obj.Name,
   160  			f.Name,
   161  		)
   162  
   163  	case *types.Func:
   164  		sig := target.Type().(*types.Signature)
   165  		if sig.Results().Len() == 1 {
   166  			f.NoErr = true
   167  		} else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" {
   168  			f.VOkFunc = true
   169  		} else if sig.Results().Len() != 2 {
   170  			return fmt.Errorf("method has wrong number of args")
   171  		}
   172  		params := sig.Params()
   173  		// If the first argument is the context, remove it from the comparison and set
   174  		// the MethodHasContext flag so that the context will be passed to this model's method
   175  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   176  			f.MethodHasContext = true
   177  			vars := make([]*types.Var, params.Len()-1)
   178  			for i := 1; i < params.Len(); i++ {
   179  				vars[i-1] = params.At(i)
   180  			}
   181  			params = types.NewTuple(vars...)
   182  		}
   183  
   184  		// Try to match target function's arguments with GraphQL field arguments.
   185  		newArgs, err := b.bindArgs(f, sig, params)
   186  		if err != nil {
   187  			return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
   188  		}
   189  
   190  		// Try to match target function's return types with GraphQL field return type
   191  		result := sig.Results().At(0)
   192  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   193  		if err != nil {
   194  			return err
   195  		}
   196  
   197  		// success, args and return type match. Bind to method
   198  		f.GoFieldType = GoFieldMethod
   199  		f.GoReceiverName = "obj"
   200  		f.GoFieldName = target.Name()
   201  		f.Args = newArgs
   202  		f.TypeReference = tr
   203  
   204  		return nil
   205  
   206  	case *types.Var:
   207  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   208  		if err != nil {
   209  			return err
   210  		}
   211  
   212  		// success, bind to var
   213  		f.GoFieldType = GoFieldVariable
   214  		f.GoReceiverName = "obj"
   215  		f.GoFieldName = target.Name()
   216  		f.TypeReference = tr
   217  
   218  		return nil
   219  	default:
   220  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   221  	}
   222  }
   223  
   224  // findBindTarget attempts to match the name to a field or method on a Type
   225  // with the following priorites:
   226  // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
   227  // 2. Any method or field with a matching name. Errors if more than one match is found
   228  // 3. Same logic again for embedded fields
   229  func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
   230  	// NOTE: a struct tag will override both methods and fields
   231  	// Bind to struct tag
   232  	found, err := b.findBindStructTagTarget(t, name)
   233  	if found != nil || err != nil {
   234  		return found, err
   235  	}
   236  
   237  	// Search for a method to bind to
   238  	foundMethod, err := b.findBindMethodTarget(t, name)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	// Search for a field to bind to
   244  	foundField, err := b.findBindFieldTarget(t, name)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	switch {
   250  	case foundField == nil && foundMethod != nil:
   251  		// Bind to method
   252  		return foundMethod, nil
   253  	case foundField != nil && foundMethod == nil:
   254  		// Bind to field
   255  		return foundField, nil
   256  	case foundField != nil && foundMethod != nil:
   257  		// Error
   258  		return nil, fmt.Errorf("found more than one way to bind for %s", name)
   259  	}
   260  
   261  	// Search embeds
   262  	return b.findBindEmbedsTarget(t, name)
   263  }
   264  
   265  func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
   266  	if b.Config.StructTag == "" {
   267  		return nil, nil
   268  	}
   269  
   270  	switch t := in.(type) {
   271  	case *types.Named:
   272  		return b.findBindStructTagTarget(t.Underlying(), name)
   273  	case *types.Struct:
   274  		var found types.Object
   275  		for i := 0; i < t.NumFields(); i++ {
   276  			field := t.Field(i)
   277  			if !field.Exported() || field.Embedded() {
   278  				continue
   279  			}
   280  			tags := reflect.StructTag(t.Tag(i))
   281  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   282  				if found != nil {
   283  					return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   284  				}
   285  
   286  				found = field
   287  			}
   288  		}
   289  
   290  		return found, nil
   291  	}
   292  
   293  	return nil, nil
   294  }
   295  
   296  func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
   297  	switch t := in.(type) {
   298  	case *types.Named:
   299  		if _, ok := t.Underlying().(*types.Interface); ok {
   300  			return b.findBindMethodTarget(t.Underlying(), name)
   301  		}
   302  
   303  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   304  	case *types.Interface:
   305  		// FIX-ME: Should use ExplicitMethod here? What's the difference?
   306  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   307  	}
   308  
   309  	return nil, nil
   310  }
   311  
   312  func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
   313  	var found types.Object
   314  	for i := 0; i < methodCount; i++ {
   315  		method := methodFunc(i)
   316  		if !method.Exported() || !strings.EqualFold(method.Name(), name) {
   317  			continue
   318  		}
   319  
   320  		if found != nil {
   321  			return nil, fmt.Errorf("found more than one matching method to bind for %s", name)
   322  		}
   323  
   324  		found = method
   325  	}
   326  
   327  	return found, nil
   328  }
   329  
   330  func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
   331  	switch t := in.(type) {
   332  	case *types.Named:
   333  		return b.findBindFieldTarget(t.Underlying(), name)
   334  	case *types.Struct:
   335  		var found types.Object
   336  		for i := 0; i < t.NumFields(); i++ {
   337  			field := t.Field(i)
   338  			if !field.Exported() || !equalFieldName(field.Name(), name) {
   339  				continue
   340  			}
   341  
   342  			if found != nil {
   343  				return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
   344  			}
   345  
   346  			found = field
   347  		}
   348  
   349  		return found, nil
   350  	}
   351  
   352  	return nil, nil
   353  }
   354  
   355  func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
   356  	switch t := in.(type) {
   357  	case *types.Named:
   358  		return b.findBindEmbedsTarget(t.Underlying(), name)
   359  	case *types.Struct:
   360  		return b.findBindStructEmbedsTarget(t, name)
   361  	case *types.Interface:
   362  		return b.findBindInterfaceEmbedsTarget(t, name)
   363  	}
   364  
   365  	return nil, nil
   366  }
   367  
   368  func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
   369  	var found types.Object
   370  	for i := 0; i < strukt.NumFields(); i++ {
   371  		field := strukt.Field(i)
   372  		if !field.Embedded() {
   373  			continue
   374  		}
   375  
   376  		fieldType := field.Type()
   377  		if ptr, ok := fieldType.(*types.Pointer); ok {
   378  			fieldType = ptr.Elem()
   379  		}
   380  
   381  		f, err := b.findBindTarget(fieldType, name)
   382  		if err != nil {
   383  			return nil, err
   384  		}
   385  
   386  		if f != nil && found != nil {
   387  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   388  		}
   389  
   390  		if f != nil {
   391  			found = f
   392  		}
   393  	}
   394  
   395  	return found, nil
   396  }
   397  
   398  func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
   399  	var found types.Object
   400  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   401  		embeddedType := iface.EmbeddedType(i)
   402  
   403  		f, err := b.findBindTarget(embeddedType, name)
   404  		if err != nil {
   405  			return nil, err
   406  		}
   407  
   408  		if f != nil && found != nil {
   409  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   410  		}
   411  
   412  		if f != nil {
   413  			found = f
   414  		}
   415  	}
   416  
   417  	return found, nil
   418  }
   419  
   420  func (f *Field) HasDirectives() bool {
   421  	return len(f.ImplDirectives()) > 0
   422  }
   423  
   424  func (f *Field) DirectiveObjName() string {
   425  	if f.Object.Root {
   426  		return "nil"
   427  	}
   428  	return f.GoReceiverName
   429  }
   430  
   431  func (f *Field) ImplDirectives() []*Directive {
   432  	var d []*Directive
   433  	loc := ast.LocationFieldDefinition
   434  	if f.Object.IsInputType() {
   435  		loc = ast.LocationInputFieldDefinition
   436  	}
   437  	for i := range f.Directives {
   438  		if !f.Directives[i].Builtin &&
   439  			(f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) {
   440  			d = append(d, f.Directives[i])
   441  		}
   442  	}
   443  	return d
   444  }
   445  
   446  func (f *Field) IsReserved() bool {
   447  	return strings.HasPrefix(f.Name, "__")
   448  }
   449  
   450  func (f *Field) IsMethod() bool {
   451  	return f.GoFieldType == GoFieldMethod
   452  }
   453  
   454  func (f *Field) IsVariable() bool {
   455  	return f.GoFieldType == GoFieldVariable
   456  }
   457  
   458  func (f *Field) IsMap() bool {
   459  	return f.GoFieldType == GoFieldMap
   460  }
   461  
   462  func (f *Field) IsConcurrent() bool {
   463  	if f.Object.DisableConcurrency {
   464  		return false
   465  	}
   466  	return f.MethodHasContext || f.IsResolver
   467  }
   468  
   469  func (f *Field) GoNameUnexported() string {
   470  	return templates.ToGoPrivate(f.Name)
   471  }
   472  
   473  func (f *Field) ShortInvocation() string {
   474  	caser := cases.Title(language.English, cases.NoLower)
   475  	if f.Object.Kind == ast.InputObject {
   476  		return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName)
   477  	}
   478  	return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
   479  }
   480  
   481  func (f *Field) ArgsFunc() string {
   482  	if len(f.Args) == 0 {
   483  		return ""
   484  	}
   485  
   486  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   487  }
   488  
   489  func (f *Field) FieldContextFunc() string {
   490  	return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name
   491  }
   492  
   493  func (f *Field) ChildFieldContextFunc(name string) string {
   494  	return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name
   495  }
   496  
   497  func (f *Field) ResolverType() string {
   498  	if !f.IsResolver {
   499  		return ""
   500  	}
   501  
   502  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   503  }
   504  
   505  func (f *Field) ShortResolverDeclaration() string {
   506  	if f.Object.Kind == ast.InputObject {
   507  		return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
   508  			templates.CurrentImports.LookupType(f.Object.Reference()),
   509  			templates.CurrentImports.LookupType(f.TypeReference.GO),
   510  		)
   511  	}
   512  
   513  	res := "(ctx context.Context"
   514  
   515  	if !f.Object.Root {
   516  		res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
   517  	}
   518  	for _, arg := range f.Args {
   519  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   520  	}
   521  
   522  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   523  	if f.Object.Stream {
   524  		result = "<-chan " + result
   525  	}
   526  
   527  	res += fmt.Sprintf(") (%s, error)", result)
   528  	return res
   529  }
   530  
   531  func (f *Field) ComplexitySignature() string {
   532  	res := "func(childComplexity int"
   533  	for _, arg := range f.Args {
   534  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   535  	}
   536  	res += ") int"
   537  	return res
   538  }
   539  
   540  func (f *Field) ComplexityArgs() string {
   541  	args := make([]string, len(f.Args))
   542  	for i, arg := range f.Args {
   543  		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   544  	}
   545  
   546  	return strings.Join(args, ", ")
   547  }
   548  
   549  func (f *Field) CallArgs() string {
   550  	args := make([]string, 0, len(f.Args)+2)
   551  
   552  	if f.IsResolver {
   553  		args = append(args, "rctx")
   554  
   555  		if !f.Object.Root {
   556  			args = append(args, "obj")
   557  		}
   558  	} else if f.MethodHasContext {
   559  		args = append(args, "ctx")
   560  	}
   561  
   562  	for _, arg := range f.Args {
   563  		tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   564  
   565  		if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() {
   566  			tmp = fmt.Sprintf(`
   567  				func () interface{} {
   568  					if fc.Args["%s"] == nil {
   569  						return nil
   570  					}
   571  					return fc.Args["%s"].(interface{})
   572  				}()`, arg.Name, arg.Name,
   573  			)
   574  		}
   575  
   576  		args = append(args, tmp)
   577  	}
   578  
   579  	return strings.Join(args, ", ")
   580  }