github.com/kerryoscer/gqlgen@v0.17.29/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	goast "go/ast"
     7  	"go/types"
     8  	"log"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/kerryoscer/gqlgen/codegen/config"
    14  	"github.com/kerryoscer/gqlgen/codegen/templates"
    15  	"github.com/vektah/gqlparser/v2/ast"
    16  	"golang.org/x/text/cases"
    17  	"golang.org/x/text/language"
    18  )
    19  
// Field is the code-generation model for one GraphQL field of an Object. It
// embeds the parsed field definition and records how the generated code reads
// the field's value: from a struct field, a method, a map entry, or a resolver.
//
// TypeReference links the field's GraphQL type to its Go type. bindField sets
// it when it binds to a method or struct field, and otherwise looks it up for
// the plain GraphQL type. For resolvers, buildField may then swap it for a
// pointer type when ResolversAlwaysReturnPointers is configured.
type Field struct {
	*ast.FieldDefinition

	TypeReference    *config.TypeReference
	GoFieldType      GoFieldType      // How the value is read in Go: variable, method or map entry
	GoReceiverName   string           // Receiver in generated code: "obj" for the model, "ec" for built-in helpers
	GoFieldName      string           // The name of the method or var in go, if any
	IsResolver       bool             // Does this field need a resolver
	Args             []*FieldArgument // A list of arguments to be passed to this field
	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
	NoErr            bool             // If this is bound to a go method, true when it returns no error value (e.g. a single result)
	VOkFunc          bool             // If this is bound to a go method, is it of shape (value, bool)
	Object           *Object          // A link back to the parent object
	Default          interface{}      // The field definition's default value, if one is declared
	Stream           bool             // Copied from the parent object's Stream flag; streaming resolvers return a channel
	Directives       []*Directive     // Directives applied to this field, including ones inherited during binding
}
    37  
    38  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    39  	dirs, err := b.getDirectives(field.Directives)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	f := Field{
    45  		FieldDefinition: field,
    46  		Object:          obj,
    47  		Directives:      dirs,
    48  		GoFieldName:     templates.ToGo(field.Name),
    49  		GoFieldType:     GoFieldVariable,
    50  		GoReceiverName:  "obj",
    51  	}
    52  
    53  	if field.DefaultValue != nil {
    54  		var err error
    55  		f.Default, err = field.DefaultValue.Value(nil)
    56  		if err != nil {
    57  			return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err)
    58  		}
    59  	}
    60  
    61  	for _, arg := range field.Arguments {
    62  		newArg, err := b.buildArg(obj, arg)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		f.Args = append(f.Args, newArg)
    67  	}
    68  
    69  	if err = b.bindField(obj, &f); err != nil {
    70  		f.IsResolver = true
    71  		if errors.Is(err, config.ErrTypeNotFound) {
    72  			return nil, err
    73  		}
    74  		log.Println(err.Error())
    75  	}
    76  
    77  	if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    78  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    79  	}
    80  
    81  	return &f, nil
    82  }
    83  
    84  func (b *builder) bindField(obj *Object, f *Field) (errret error) {
    85  	defer func() {
    86  		if f.TypeReference == nil {
    87  			tr, err := b.Binder.TypeReference(f.Type, nil)
    88  			if err != nil {
    89  				errret = err
    90  			}
    91  			f.TypeReference = tr
    92  		}
    93  		if f.TypeReference != nil {
    94  			dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
    95  			if err != nil {
    96  				errret = err
    97  			}
    98  			for _, dir := range obj.Directives {
    99  				if dir.IsLocation(ast.LocationInputObject) {
   100  					dirs = append(dirs, dir)
   101  				}
   102  			}
   103  			f.Directives = append(dirs, f.Directives...)
   104  		}
   105  	}()
   106  
   107  	f.Stream = obj.Stream
   108  
   109  	switch {
   110  	case f.Name == "__schema":
   111  		f.GoFieldType = GoFieldMethod
   112  		f.GoReceiverName = "ec"
   113  		f.GoFieldName = "introspectSchema"
   114  		return nil
   115  	case f.Name == "__type":
   116  		f.GoFieldType = GoFieldMethod
   117  		f.GoReceiverName = "ec"
   118  		f.GoFieldName = "introspectType"
   119  		return nil
   120  	case f.Name == "_entities":
   121  		f.GoFieldType = GoFieldMethod
   122  		f.GoReceiverName = "ec"
   123  		f.GoFieldName = "__resolve_entities"
   124  		f.MethodHasContext = true
   125  		f.NoErr = true
   126  		return nil
   127  	case f.Name == "_service":
   128  		f.GoFieldType = GoFieldMethod
   129  		f.GoReceiverName = "ec"
   130  		f.GoFieldName = "__resolve__service"
   131  		f.MethodHasContext = true
   132  		return nil
   133  	case obj.Root:
   134  		f.IsResolver = true
   135  		return nil
   136  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   137  		f.IsResolver = true
   138  		return nil
   139  	case obj.Type == config.MapType:
   140  		f.GoFieldType = GoFieldMap
   141  		return nil
   142  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   143  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   144  	}
   145  
   146  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	pos := b.Binder.ObjectPosition(target)
   152  
   153  	switch target := target.(type) {
   154  	case nil:
   155  		objPos := b.Binder.TypePosition(obj.Type)
   156  		return fmt.Errorf(
   157  			"%s:%d adding resolver method for %s.%s, nothing matched",
   158  			objPos.Filename,
   159  			objPos.Line,
   160  			obj.Name,
   161  			f.Name,
   162  		)
   163  
   164  	case *types.Func:
   165  		sig := target.Type().(*types.Signature)
   166  		if sig.Results().Len() == 1 {
   167  			f.NoErr = true
   168  		} else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" {
   169  			f.VOkFunc = true
   170  		} else if sig.Results().Len() != 2 {
   171  			return fmt.Errorf("method has wrong number of args")
   172  		}
   173  		params := sig.Params()
   174  		// If the first argument is the context, remove it from the comparison and set
   175  		// the MethodHasContext flag so that the context will be passed to this model's method
   176  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   177  			f.MethodHasContext = true
   178  			vars := make([]*types.Var, params.Len()-1)
   179  			for i := 1; i < params.Len(); i++ {
   180  				vars[i-1] = params.At(i)
   181  			}
   182  			params = types.NewTuple(vars...)
   183  		}
   184  
   185  		// Try to match target function's arguments with GraphQL field arguments.
   186  		newArgs, err := b.bindArgs(f, sig, params)
   187  		if err != nil {
   188  			return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
   189  		}
   190  
   191  		// Try to match target function's return types with GraphQL field return type
   192  		result := sig.Results().At(0)
   193  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   194  		if err != nil {
   195  			return err
   196  		}
   197  
   198  		// success, args and return type match. Bind to method
   199  		f.GoFieldType = GoFieldMethod
   200  		f.GoReceiverName = "obj"
   201  		f.GoFieldName = target.Name()
   202  		f.Args = newArgs
   203  		f.TypeReference = tr
   204  
   205  		return nil
   206  
   207  	case *types.Var:
   208  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   209  		if err != nil {
   210  			return err
   211  		}
   212  
   213  		// success, bind to var
   214  		f.GoFieldType = GoFieldVariable
   215  		f.GoReceiverName = "obj"
   216  		f.GoFieldName = target.Name()
   217  		f.TypeReference = tr
   218  
   219  		return nil
   220  	default:
   221  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   222  	}
   223  }
   224  
   225  // findBindTarget attempts to match the name to a field or method on a Type
   226  // with the following priorites:
   227  // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
   228  // 2. Any method or field with a matching name. Errors if more than one match is found
   229  // 3. Same logic again for embedded fields
   230  func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
   231  	// NOTE: a struct tag will override both methods and fields
   232  	// Bind to struct tag
   233  	found, err := b.findBindStructTagTarget(t, name)
   234  	if found != nil || err != nil {
   235  		return found, err
   236  	}
   237  
   238  	// Search for a method to bind to
   239  	foundMethod, err := b.findBindMethodTarget(t, name)
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  
   244  	// Search for a field to bind to
   245  	foundField, err := b.findBindFieldTarget(t, name)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	switch {
   251  	case foundField == nil && foundMethod != nil:
   252  		// Bind to method
   253  		return foundMethod, nil
   254  	case foundField != nil && foundMethod == nil:
   255  		// Bind to field
   256  		return foundField, nil
   257  	case foundField != nil && foundMethod != nil:
   258  		// Error
   259  		return nil, fmt.Errorf("found more than one way to bind for %s", name)
   260  	}
   261  
   262  	// Search embeds
   263  	return b.findBindEmbedsTarget(t, name)
   264  }
   265  
   266  func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
   267  	if b.Config.StructTag == "" {
   268  		return nil, nil
   269  	}
   270  
   271  	switch t := in.(type) {
   272  	case *types.Named:
   273  		return b.findBindStructTagTarget(t.Underlying(), name)
   274  	case *types.Struct:
   275  		var found types.Object
   276  		for i := 0; i < t.NumFields(); i++ {
   277  			field := t.Field(i)
   278  			if !field.Exported() || field.Embedded() {
   279  				continue
   280  			}
   281  			tags := reflect.StructTag(t.Tag(i))
   282  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   283  				if found != nil {
   284  					return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   285  				}
   286  
   287  				found = field
   288  			}
   289  		}
   290  
   291  		return found, nil
   292  	}
   293  
   294  	return nil, nil
   295  }
   296  
   297  func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
   298  	switch t := in.(type) {
   299  	case *types.Named:
   300  		if _, ok := t.Underlying().(*types.Interface); ok {
   301  			return b.findBindMethodTarget(t.Underlying(), name)
   302  		}
   303  
   304  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   305  	case *types.Interface:
   306  		// FIX-ME: Should use ExplicitMethod here? What's the difference?
   307  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   308  	}
   309  
   310  	return nil, nil
   311  }
   312  
   313  func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
   314  	var found types.Object
   315  	for i := 0; i < methodCount; i++ {
   316  		method := methodFunc(i)
   317  		if !method.Exported() || !strings.EqualFold(method.Name(), name) {
   318  			continue
   319  		}
   320  
   321  		if found != nil {
   322  			return nil, fmt.Errorf("found more than one matching method to bind for %s", name)
   323  		}
   324  
   325  		found = method
   326  	}
   327  
   328  	return found, nil
   329  }
   330  
   331  func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
   332  	switch t := in.(type) {
   333  	case *types.Named:
   334  		return b.findBindFieldTarget(t.Underlying(), name)
   335  	case *types.Struct:
   336  		var found types.Object
   337  		for i := 0; i < t.NumFields(); i++ {
   338  			field := t.Field(i)
   339  			if !field.Exported() || !equalFieldName(field.Name(), name) {
   340  				continue
   341  			}
   342  
   343  			if found != nil {
   344  				return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
   345  			}
   346  
   347  			found = field
   348  		}
   349  
   350  		return found, nil
   351  	}
   352  
   353  	return nil, nil
   354  }
   355  
   356  func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
   357  	switch t := in.(type) {
   358  	case *types.Named:
   359  		return b.findBindEmbedsTarget(t.Underlying(), name)
   360  	case *types.Struct:
   361  		return b.findBindStructEmbedsTarget(t, name)
   362  	case *types.Interface:
   363  		return b.findBindInterfaceEmbedsTarget(t, name)
   364  	}
   365  
   366  	return nil, nil
   367  }
   368  
   369  func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
   370  	var found types.Object
   371  	for i := 0; i < strukt.NumFields(); i++ {
   372  		field := strukt.Field(i)
   373  		if !field.Embedded() {
   374  			continue
   375  		}
   376  
   377  		fieldType := field.Type()
   378  		if ptr, ok := fieldType.(*types.Pointer); ok {
   379  			fieldType = ptr.Elem()
   380  		}
   381  
   382  		f, err := b.findBindTarget(fieldType, name)
   383  		if err != nil {
   384  			return nil, err
   385  		}
   386  
   387  		if f != nil && found != nil {
   388  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   389  		}
   390  
   391  		if f != nil {
   392  			found = f
   393  		}
   394  	}
   395  
   396  	return found, nil
   397  }
   398  
   399  func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
   400  	var found types.Object
   401  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   402  		embeddedType := iface.EmbeddedType(i)
   403  
   404  		f, err := b.findBindTarget(embeddedType, name)
   405  		if err != nil {
   406  			return nil, err
   407  		}
   408  
   409  		if f != nil && found != nil {
   410  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   411  		}
   412  
   413  		if f != nil {
   414  			found = f
   415  		}
   416  	}
   417  
   418  	return found, nil
   419  }
   420  
   421  func (f *Field) HasDirectives() bool {
   422  	return len(f.ImplDirectives()) > 0
   423  }
   424  
   425  func (f *Field) DirectiveObjName() string {
   426  	if f.Object.Root {
   427  		return "nil"
   428  	}
   429  	return f.GoReceiverName
   430  }
   431  
   432  func (f *Field) ImplDirectives() []*Directive {
   433  	var d []*Directive
   434  	loc := ast.LocationFieldDefinition
   435  	if f.Object.IsInputType() {
   436  		loc = ast.LocationInputFieldDefinition
   437  	}
   438  	for i := range f.Directives {
   439  		if !f.Directives[i].Builtin &&
   440  			(f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) {
   441  			d = append(d, f.Directives[i])
   442  		}
   443  	}
   444  	return d
   445  }
   446  
   447  func (f *Field) IsReserved() bool {
   448  	return strings.HasPrefix(f.Name, "__")
   449  }
   450  
   451  func (f *Field) IsMethod() bool {
   452  	return f.GoFieldType == GoFieldMethod
   453  }
   454  
   455  func (f *Field) IsVariable() bool {
   456  	return f.GoFieldType == GoFieldVariable
   457  }
   458  
   459  func (f *Field) IsMap() bool {
   460  	return f.GoFieldType == GoFieldMap
   461  }
   462  
   463  func (f *Field) IsConcurrent() bool {
   464  	if f.Object.DisableConcurrency {
   465  		return false
   466  	}
   467  	return f.MethodHasContext || f.IsResolver
   468  }
   469  
   470  func (f *Field) GoNameUnexported() string {
   471  	return templates.ToGoPrivate(f.Name)
   472  }
   473  
   474  func (f *Field) ShortInvocation() string {
   475  	caser := cases.Title(language.English, cases.NoLower)
   476  	if f.Object.Kind == ast.InputObject {
   477  		return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName)
   478  	}
   479  	return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
   480  }
   481  
   482  func (f *Field) ArgsFunc() string {
   483  	if len(f.Args) == 0 {
   484  		return ""
   485  	}
   486  
   487  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   488  }
   489  
   490  func (f *Field) FieldContextFunc() string {
   491  	return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name
   492  }
   493  
   494  func (f *Field) ChildFieldContextFunc(name string) string {
   495  	return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name
   496  }
   497  
   498  func (f *Field) ResolverType() string {
   499  	if !f.IsResolver {
   500  		return ""
   501  	}
   502  
   503  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   504  }
   505  
   506  func (f *Field) IsInputObject() bool {
   507  	return f.Object.Kind == ast.InputObject
   508  }
   509  
   510  func (f *Field) IsRoot() bool {
   511  	return f.Object.Root
   512  }
   513  
   514  func (f *Field) ShortResolverDeclaration() string {
   515  	return f.ShortResolverSignature(nil)
   516  }
   517  
   518  // ShortResolverSignature is identical to ShortResolverDeclaration,
   519  // but respects previous naming (return) conventions, if any.
   520  func (f *Field) ShortResolverSignature(ft *goast.FuncType) string {
   521  	if f.Object.Kind == ast.InputObject {
   522  		return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
   523  			templates.CurrentImports.LookupType(f.Object.Reference()),
   524  			templates.CurrentImports.LookupType(f.TypeReference.GO),
   525  		)
   526  	}
   527  
   528  	res := "(ctx context.Context"
   529  
   530  	if !f.Object.Root {
   531  		res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
   532  	}
   533  	for _, arg := range f.Args {
   534  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   535  	}
   536  
   537  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   538  	if f.Object.Stream {
   539  		result = "<-chan " + result
   540  	}
   541  	// Named return.
   542  	var namedV, namedE string
   543  	if ft != nil {
   544  		if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 {
   545  			namedV = ft.Results.List[0].Names[0].Name
   546  		}
   547  		if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 {
   548  			namedE = ft.Results.List[1].Names[0].Name
   549  		}
   550  	}
   551  	res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE)
   552  	return res
   553  }
   554  
   555  func (f *Field) GoResultName() (string, bool) {
   556  	name := fmt.Sprintf("%v", f.TypeReference.GO)
   557  	splits := strings.Split(name, "/")
   558  
   559  	return splits[len(splits)-1], strings.HasPrefix(name, "[]")
   560  }
   561  
   562  func (f *Field) ComplexitySignature() string {
   563  	res := "func(childComplexity int"
   564  	for _, arg := range f.Args {
   565  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   566  	}
   567  	res += ") int"
   568  	return res
   569  }
   570  
   571  func (f *Field) ComplexityArgs() string {
   572  	args := make([]string, len(f.Args))
   573  	for i, arg := range f.Args {
   574  		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   575  	}
   576  
   577  	return strings.Join(args, ", ")
   578  }
   579  
   580  func (f *Field) CallArgs() string {
   581  	args := make([]string, 0, len(f.Args)+2)
   582  
   583  	if f.IsResolver {
   584  		args = append(args, "rctx")
   585  
   586  		if !f.Object.Root {
   587  			args = append(args, "obj")
   588  		}
   589  	} else if f.MethodHasContext {
   590  		args = append(args, "ctx")
   591  	}
   592  
   593  	for _, arg := range f.Args {
   594  		tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   595  
   596  		if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() {
   597  			tmp = fmt.Sprintf(`
   598  				func () interface{} {
   599  					if fc.Args["%s"] == nil {
   600  						return nil
   601  					}
   602  					return fc.Args["%s"].(interface{})
   603  				}()`, arg.Name, arg.Name,
   604  			)
   605  		}
   606  
   607  		args = append(args, tmp)
   608  	}
   609  
   610  	return strings.Join(args, ", ")
   611  }