github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	goast "go/ast"
     7  	"go/types"
     8  	"log"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/vektah/gqlparser/v2/ast"
    14  	"golang.org/x/text/cases"
    15  	"golang.org/x/text/language"
    16  
    17  	"github.com/niko0xdev/gqlgen/codegen/config"
    18  	"github.com/niko0xdev/gqlgen/codegen/templates"
    19  )
    20  
// Field describes one GraphQL field of an Object for code generation: the
// schema definition it came from and how generated code reaches it in Go
// (a struct field, a method, a map lookup, or a generated resolver).
type Field struct {
	*ast.FieldDefinition

	// TypeReference links the field's GraphQL type to a Go type. bindField sets
	// it from the bound Go method/field when binding succeeds, and otherwise
	// derives it from the GraphQL type alone (it may stay nil if that fails).
	TypeReference    *config.TypeReference
	GoFieldType      GoFieldType      // How generated code reads the field: variable (the default), method, or map lookup
	GoReceiverName   string           // The receiver generated code reads from: "obj" for model bindings, "ec" for built-in executor methods
	GoFieldName      string           // The name of the method or var in go, if any
	IsResolver       bool             // Does this field need a resolver
	Args             []*FieldArgument // A list of arguments to be passed to this field
	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
	NoErr            bool             // If this is bound to a go method, true when that method returns a single value and no error
	VOkFunc          bool             // If this is bound to a go method, is it of shape (value, bool)
	Object           *Object          // A link back to the parent object
	Default          interface{}      // The default value parsed from the schema definition, or nil when none is declared
	Stream           bool             // Copied from the parent Object's Stream flag: does this field return a channel?
	Directives       []*Directive     // Directives on the field; bindField prepends those declared on the field's type and some of the parent's
}
    38  
    39  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    40  	dirs, err := b.getDirectives(field.Directives)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	f := Field{
    46  		FieldDefinition: field,
    47  		Object:          obj,
    48  		Directives:      dirs,
    49  		GoFieldName:     templates.ToGo(field.Name),
    50  		GoFieldType:     GoFieldVariable,
    51  		GoReceiverName:  "obj",
    52  	}
    53  
    54  	if field.DefaultValue != nil {
    55  		var err error
    56  		f.Default, err = field.DefaultValue.Value(nil)
    57  		if err != nil {
    58  			return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err)
    59  		}
    60  	}
    61  
    62  	for _, arg := range field.Arguments {
    63  		newArg, err := b.buildArg(obj, arg)
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  		f.Args = append(f.Args, newArg)
    68  	}
    69  
    70  	if err = b.bindField(obj, &f); err != nil {
    71  		f.IsResolver = true
    72  		if errors.Is(err, config.ErrTypeNotFound) {
    73  			return nil, err
    74  		}
    75  		log.Println(err.Error())
    76  	}
    77  
    78  	if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    79  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    80  	}
    81  
    82  	return &f, nil
    83  }
    84  
    85  func (b *builder) bindField(obj *Object, f *Field) (errret error) {
    86  	defer func() {
    87  		if f.TypeReference == nil {
    88  			tr, err := b.Binder.TypeReference(f.Type, nil)
    89  			if err != nil {
    90  				errret = err
    91  			}
    92  			f.TypeReference = tr
    93  		}
    94  		if f.TypeReference != nil {
    95  			dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
    96  			if err != nil {
    97  				errret = err
    98  			}
    99  			for _, dir := range obj.Directives {
   100  				if dir.IsLocation(ast.LocationInputObject) {
   101  					dirs = append(dirs, dir)
   102  				}
   103  			}
   104  			f.Directives = append(dirs, f.Directives...)
   105  		}
   106  	}()
   107  
   108  	f.Stream = obj.Stream
   109  
   110  	switch {
   111  	case f.Name == "__schema":
   112  		f.GoFieldType = GoFieldMethod
   113  		f.GoReceiverName = "ec"
   114  		f.GoFieldName = "introspectSchema"
   115  		return nil
   116  	case f.Name == "__type":
   117  		f.GoFieldType = GoFieldMethod
   118  		f.GoReceiverName = "ec"
   119  		f.GoFieldName = "introspectType"
   120  		return nil
   121  	case f.Name == "_entities":
   122  		f.GoFieldType = GoFieldMethod
   123  		f.GoReceiverName = "ec"
   124  		f.GoFieldName = "__resolve_entities"
   125  		f.MethodHasContext = true
   126  		f.NoErr = true
   127  		return nil
   128  	case f.Name == "_service":
   129  		f.GoFieldType = GoFieldMethod
   130  		f.GoReceiverName = "ec"
   131  		f.GoFieldName = "__resolve__service"
   132  		f.MethodHasContext = true
   133  		return nil
   134  	case obj.Root:
   135  		f.IsResolver = true
   136  		return nil
   137  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   138  		f.IsResolver = true
   139  		return nil
   140  	case obj.Type == config.MapType:
   141  		f.GoFieldType = GoFieldMap
   142  		return nil
   143  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   144  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   145  	}
   146  
   147  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	pos := b.Binder.ObjectPosition(target)
   153  
   154  	switch target := target.(type) {
   155  	case nil:
   156  		// Skips creating a resolver for any root types
   157  		if b.Config.IsRoot(b.Schema.Types[f.Type.Name()]) {
   158  			return nil
   159  		}
   160  
   161  		objPos := b.Binder.TypePosition(obj.Type)
   162  		return fmt.Errorf(
   163  			"%s:%d adding resolver method for %s.%s, nothing matched",
   164  			objPos.Filename,
   165  			objPos.Line,
   166  			obj.Name,
   167  			f.Name,
   168  		)
   169  
   170  	case *types.Func:
   171  		sig := target.Type().(*types.Signature)
   172  		if sig.Results().Len() == 1 {
   173  			f.NoErr = true
   174  		} else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" {
   175  			f.VOkFunc = true
   176  		} else if sig.Results().Len() != 2 {
   177  			return fmt.Errorf("method has wrong number of args")
   178  		}
   179  		params := sig.Params()
   180  		// If the first argument is the context, remove it from the comparison and set
   181  		// the MethodHasContext flag so that the context will be passed to this model's method
   182  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   183  			f.MethodHasContext = true
   184  			vars := make([]*types.Var, params.Len()-1)
   185  			for i := 1; i < params.Len(); i++ {
   186  				vars[i-1] = params.At(i)
   187  			}
   188  			params = types.NewTuple(vars...)
   189  		}
   190  
   191  		// Try to match target function's arguments with GraphQL field arguments.
   192  		newArgs, err := b.bindArgs(f, sig, params)
   193  		if err != nil {
   194  			return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
   195  		}
   196  
   197  		// Try to match target function's return types with GraphQL field return type
   198  		result := sig.Results().At(0)
   199  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   200  		if err != nil {
   201  			return err
   202  		}
   203  
   204  		// success, args and return type match. Bind to method
   205  		f.GoFieldType = GoFieldMethod
   206  		f.GoReceiverName = "obj"
   207  		f.GoFieldName = target.Name()
   208  		f.Args = newArgs
   209  		f.TypeReference = tr
   210  
   211  		return nil
   212  
   213  	case *types.Var:
   214  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   215  		if err != nil {
   216  			return err
   217  		}
   218  
   219  		// success, bind to var
   220  		f.GoFieldType = GoFieldVariable
   221  		f.GoReceiverName = "obj"
   222  		f.GoFieldName = target.Name()
   223  		f.TypeReference = tr
   224  
   225  		return nil
   226  	default:
   227  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   228  	}
   229  }
   230  
   231  // findBindTarget attempts to match the name to a field or method on a Type
   232  // with the following priorites:
   233  // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
   234  // 2. Any method or field with a matching name. Errors if more than one match is found
   235  // 3. Same logic again for embedded fields
   236  func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
   237  	// NOTE: a struct tag will override both methods and fields
   238  	// Bind to struct tag
   239  	found, err := b.findBindStructTagTarget(t, name)
   240  	if found != nil || err != nil {
   241  		return found, err
   242  	}
   243  
   244  	// Search for a method to bind to
   245  	foundMethod, err := b.findBindMethodTarget(t, name)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	// Search for a field to bind to
   251  	foundField, err := b.findBindFieldTarget(t, name)
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  
   256  	switch {
   257  	case foundField == nil && foundMethod != nil:
   258  		// Bind to method
   259  		return foundMethod, nil
   260  	case foundField != nil && foundMethod == nil:
   261  		// Bind to field
   262  		return foundField, nil
   263  	case foundField != nil && foundMethod != nil:
   264  		// Error
   265  		return nil, fmt.Errorf("found more than one way to bind for %s", name)
   266  	}
   267  
   268  	// Search embeds
   269  	return b.findBindEmbedsTarget(t, name)
   270  }
   271  
   272  func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
   273  	if b.Config.StructTag == "" {
   274  		return nil, nil
   275  	}
   276  
   277  	switch t := in.(type) {
   278  	case *types.Named:
   279  		return b.findBindStructTagTarget(t.Underlying(), name)
   280  	case *types.Struct:
   281  		var found types.Object
   282  		for i := 0; i < t.NumFields(); i++ {
   283  			field := t.Field(i)
   284  			if !field.Exported() || field.Embedded() {
   285  				continue
   286  			}
   287  			tags := reflect.StructTag(t.Tag(i))
   288  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   289  				if found != nil {
   290  					return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   291  				}
   292  
   293  				found = field
   294  			}
   295  		}
   296  
   297  		return found, nil
   298  	}
   299  
   300  	return nil, nil
   301  }
   302  
   303  func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
   304  	switch t := in.(type) {
   305  	case *types.Named:
   306  		if _, ok := t.Underlying().(*types.Interface); ok {
   307  			return b.findBindMethodTarget(t.Underlying(), name)
   308  		}
   309  
   310  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   311  	case *types.Interface:
   312  		// FIX-ME: Should use ExplicitMethod here? What's the difference?
   313  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   314  	}
   315  
   316  	return nil, nil
   317  }
   318  
   319  func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
   320  	var found types.Object
   321  	for i := 0; i < methodCount; i++ {
   322  		method := methodFunc(i)
   323  		if !method.Exported() || !strings.EqualFold(method.Name(), name) {
   324  			continue
   325  		}
   326  
   327  		if found != nil {
   328  			return nil, fmt.Errorf("found more than one matching method to bind for %s", name)
   329  		}
   330  
   331  		found = method
   332  	}
   333  
   334  	return found, nil
   335  }
   336  
   337  func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
   338  	switch t := in.(type) {
   339  	case *types.Named:
   340  		return b.findBindFieldTarget(t.Underlying(), name)
   341  	case *types.Struct:
   342  		var found types.Object
   343  		for i := 0; i < t.NumFields(); i++ {
   344  			field := t.Field(i)
   345  			if !field.Exported() || !equalFieldName(field.Name(), name) {
   346  				continue
   347  			}
   348  
   349  			if found != nil {
   350  				return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
   351  			}
   352  
   353  			found = field
   354  		}
   355  
   356  		return found, nil
   357  	}
   358  
   359  	return nil, nil
   360  }
   361  
   362  func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
   363  	switch t := in.(type) {
   364  	case *types.Named:
   365  		return b.findBindEmbedsTarget(t.Underlying(), name)
   366  	case *types.Struct:
   367  		return b.findBindStructEmbedsTarget(t, name)
   368  	case *types.Interface:
   369  		return b.findBindInterfaceEmbedsTarget(t, name)
   370  	}
   371  
   372  	return nil, nil
   373  }
   374  
   375  func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
   376  	var found types.Object
   377  	for i := 0; i < strukt.NumFields(); i++ {
   378  		field := strukt.Field(i)
   379  		if !field.Embedded() {
   380  			continue
   381  		}
   382  
   383  		fieldType := field.Type()
   384  		if ptr, ok := fieldType.(*types.Pointer); ok {
   385  			fieldType = ptr.Elem()
   386  		}
   387  
   388  		f, err := b.findBindTarget(fieldType, name)
   389  		if err != nil {
   390  			return nil, err
   391  		}
   392  
   393  		if f != nil && found != nil {
   394  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   395  		}
   396  
   397  		if f != nil {
   398  			found = f
   399  		}
   400  	}
   401  
   402  	return found, nil
   403  }
   404  
   405  func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
   406  	var found types.Object
   407  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   408  		embeddedType := iface.EmbeddedType(i)
   409  
   410  		f, err := b.findBindTarget(embeddedType, name)
   411  		if err != nil {
   412  			return nil, err
   413  		}
   414  
   415  		if f != nil && found != nil {
   416  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   417  		}
   418  
   419  		if f != nil {
   420  			found = f
   421  		}
   422  	}
   423  
   424  	return found, nil
   425  }
   426  
   427  func (f *Field) HasDirectives() bool {
   428  	return len(f.ImplDirectives()) > 0
   429  }
   430  
   431  func (f *Field) DirectiveObjName() string {
   432  	if f.Object.Root {
   433  		return "nil"
   434  	}
   435  	return f.GoReceiverName
   436  }
   437  
   438  func (f *Field) ImplDirectives() []*Directive {
   439  	var d []*Directive
   440  	loc := ast.LocationFieldDefinition
   441  	if f.Object.IsInputType() {
   442  		loc = ast.LocationInputFieldDefinition
   443  	}
   444  	for i := range f.Directives {
   445  		if !f.Directives[i].Builtin &&
   446  			(f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) {
   447  			d = append(d, f.Directives[i])
   448  		}
   449  	}
   450  	return d
   451  }
   452  
   453  func (f *Field) IsReserved() bool {
   454  	return strings.HasPrefix(f.Name, "__")
   455  }
   456  
   457  func (f *Field) IsMethod() bool {
   458  	return f.GoFieldType == GoFieldMethod
   459  }
   460  
   461  func (f *Field) IsVariable() bool {
   462  	return f.GoFieldType == GoFieldVariable
   463  }
   464  
   465  func (f *Field) IsMap() bool {
   466  	return f.GoFieldType == GoFieldMap
   467  }
   468  
   469  func (f *Field) IsConcurrent() bool {
   470  	if f.Object.DisableConcurrency {
   471  		return false
   472  	}
   473  	return f.MethodHasContext || f.IsResolver
   474  }
   475  
   476  func (f *Field) GoNameUnexported() string {
   477  	return templates.ToGoPrivate(f.Name)
   478  }
   479  
   480  func (f *Field) ShortInvocation() string {
   481  	caser := cases.Title(language.English, cases.NoLower)
   482  	if f.Object.Kind == ast.InputObject {
   483  		return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName)
   484  	}
   485  	return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
   486  }
   487  
   488  func (f *Field) ArgsFunc() string {
   489  	if len(f.Args) == 0 {
   490  		return ""
   491  	}
   492  
   493  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   494  }
   495  
   496  func (f *Field) FieldContextFunc() string {
   497  	return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name
   498  }
   499  
   500  func (f *Field) ChildFieldContextFunc(name string) string {
   501  	return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name
   502  }
   503  
   504  func (f *Field) ResolverType() string {
   505  	if !f.IsResolver {
   506  		return ""
   507  	}
   508  
   509  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   510  }
   511  
   512  func (f *Field) IsInputObject() bool {
   513  	return f.Object.Kind == ast.InputObject
   514  }
   515  
   516  func (f *Field) IsRoot() bool {
   517  	return f.Object.Root
   518  }
   519  
   520  func (f *Field) ShortResolverDeclaration() string {
   521  	return f.ShortResolverSignature(nil)
   522  }
   523  
   524  // ShortResolverSignature is identical to ShortResolverDeclaration,
   525  // but respects previous naming (return) conventions, if any.
   526  func (f *Field) ShortResolverSignature(ft *goast.FuncType) string {
   527  	if f.Object.Kind == ast.InputObject {
   528  		return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
   529  			templates.CurrentImports.LookupType(f.Object.Reference()),
   530  			templates.CurrentImports.LookupType(f.TypeReference.GO),
   531  		)
   532  	}
   533  
   534  	res := "(ctx context.Context"
   535  
   536  	if !f.Object.Root {
   537  		res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
   538  	}
   539  	for _, arg := range f.Args {
   540  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   541  	}
   542  
   543  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   544  	if f.Object.Stream {
   545  		result = "<-chan " + result
   546  	}
   547  	// Named return.
   548  	var namedV, namedE string
   549  	if ft != nil {
   550  		if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 {
   551  			namedV = ft.Results.List[0].Names[0].Name
   552  		}
   553  		if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 {
   554  			namedE = ft.Results.List[1].Names[0].Name
   555  		}
   556  	}
   557  	res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE)
   558  	return res
   559  }
   560  
   561  func (f *Field) GoResultName() (string, bool) {
   562  	name := fmt.Sprintf("%v", f.TypeReference.GO)
   563  	splits := strings.Split(name, "/")
   564  
   565  	return splits[len(splits)-1], strings.HasPrefix(name, "[]")
   566  }
   567  
   568  func (f *Field) ComplexitySignature() string {
   569  	res := "func(childComplexity int"
   570  	for _, arg := range f.Args {
   571  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   572  	}
   573  	res += ") int"
   574  	return res
   575  }
   576  
   577  func (f *Field) ComplexityArgs() string {
   578  	args := make([]string, len(f.Args))
   579  	for i, arg := range f.Args {
   580  		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   581  	}
   582  
   583  	return strings.Join(args, ", ")
   584  }
   585  
   586  func (f *Field) CallArgs() string {
   587  	args := make([]string, 0, len(f.Args)+2)
   588  
   589  	if f.IsResolver {
   590  		args = append(args, "rctx")
   591  
   592  		if !f.Object.Root {
   593  			args = append(args, "obj")
   594  		}
   595  	} else if f.MethodHasContext {
   596  		args = append(args, "ctx")
   597  	}
   598  
   599  	for _, arg := range f.Args {
   600  		tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   601  
   602  		if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() {
   603  			tmp = fmt.Sprintf(`
   604  				func () interface{} {
   605  					if fc.Args["%s"] == nil {
   606  						return nil
   607  					}
   608  					return fc.Args["%s"].(interface{})
   609  				}()`, arg.Name, arg.Name,
   610  			)
   611  		}
   612  
   613  		args = append(args, tmp)
   614  	}
   615  
   616  	return strings.Join(args, ", ")
   617  }