github.com/operandinc/gqlgen@v0.16.1/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/types"
     7  	"log"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/operandinc/gqlgen/codegen/config"
    13  	"github.com/operandinc/gqlgen/codegen/templates"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  )
    16  
// Field is the code generation model for one GraphQL field of an Object: the
// GraphQL field definition (embedded) together with how the generated code
// obtains its value in Go.
//
// TypeReference links the field's GraphQL result type to its Go type; it is
// populated by bindField (derived from the GraphQL type alone when nothing
// more specific was bound) and may be replaced by a pointer type in buildField.
type Field struct {
	*ast.FieldDefinition

	TypeReference    *config.TypeReference
	GoFieldType      GoFieldType      // How the value is reached in go: GoFieldMethod, GoFieldVariable (struct field) or GoFieldMap
	GoReceiverName   string           // The receiver the method or var is read from: "obj", or "ec" for __schema/__type/_entities/_service
	GoFieldName      string           // The name of the method or var in go, if any
	IsResolver       bool             // Does this field need a resolver
	Args             []*FieldArgument // A list of arguments to be passed to this field
	MethodHasContext bool             // If this is bound to a go method, does the method take a context.Context as its first parameter
	NoErr            bool             // If this is bound to a go method, true when that method does NOT return an error (single return value)
	VOkFunc          bool             // If this is bound to a go method, is it of shape (interface{}, bool)
	Object           *Object          // A link back to the parent object
	Default          interface{}      // The field definition's default value as returned by DefaultValue.Value; nil if none
	Stream           bool             // Copied from Object.Stream: does this field return a channel?
	Directives       []*Directive     // Directives for this field; bindField prepends those inherited from its type and object
}
    34  
    35  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    36  	dirs, err := b.getDirectives(field.Directives)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	f := Field{
    42  		FieldDefinition: field,
    43  		Object:          obj,
    44  		Directives:      dirs,
    45  		GoFieldName:     templates.ToGo(field.Name),
    46  		GoFieldType:     GoFieldVariable,
    47  		GoReceiverName:  "obj",
    48  	}
    49  
    50  	if field.DefaultValue != nil {
    51  		var err error
    52  		f.Default, err = field.DefaultValue.Value(nil)
    53  		if err != nil {
    54  			return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err)
    55  		}
    56  	}
    57  
    58  	for _, arg := range field.Arguments {
    59  		newArg, err := b.buildArg(obj, arg)
    60  		if err != nil {
    61  			return nil, err
    62  		}
    63  		f.Args = append(f.Args, newArg)
    64  	}
    65  
    66  	if err = b.bindField(obj, &f); err != nil {
    67  		f.IsResolver = true
    68  		if errors.Is(err, config.ErrTypeNotFound) {
    69  			return nil, err
    70  		}
    71  		log.Println(err.Error())
    72  	}
    73  
    74  	if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    75  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    76  	}
    77  
    78  	return &f, nil
    79  }
    80  
    81  func (b *builder) bindField(obj *Object, f *Field) (errret error) {
    82  	defer func() {
    83  		if f.TypeReference == nil {
    84  			tr, err := b.Binder.TypeReference(f.Type, nil)
    85  			if err != nil {
    86  				errret = err
    87  			}
    88  			f.TypeReference = tr
    89  		}
    90  		if f.TypeReference != nil {
    91  			dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
    92  			if err != nil {
    93  				errret = err
    94  			}
    95  			for _, dir := range obj.Directives {
    96  				if dir.IsLocation(ast.LocationInputObject) {
    97  					dirs = append(dirs, dir)
    98  				}
    99  			}
   100  			f.Directives = append(dirs, f.Directives...)
   101  		}
   102  	}()
   103  
   104  	f.Stream = obj.Stream
   105  
   106  	switch {
   107  	case f.Name == "__schema":
   108  		f.GoFieldType = GoFieldMethod
   109  		f.GoReceiverName = "ec"
   110  		f.GoFieldName = "introspectSchema"
   111  		return nil
   112  	case f.Name == "__type":
   113  		f.GoFieldType = GoFieldMethod
   114  		f.GoReceiverName = "ec"
   115  		f.GoFieldName = "introspectType"
   116  		return nil
   117  	case f.Name == "_entities":
   118  		f.GoFieldType = GoFieldMethod
   119  		f.GoReceiverName = "ec"
   120  		f.GoFieldName = "__resolve_entities"
   121  		f.MethodHasContext = true
   122  		f.NoErr = true
   123  		return nil
   124  	case f.Name == "_service":
   125  		f.GoFieldType = GoFieldMethod
   126  		f.GoReceiverName = "ec"
   127  		f.GoFieldName = "__resolve__service"
   128  		f.MethodHasContext = true
   129  		return nil
   130  	case obj.Root:
   131  		f.IsResolver = true
   132  		return nil
   133  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   134  		f.IsResolver = true
   135  		return nil
   136  	case obj.Type == config.MapType:
   137  		f.GoFieldType = GoFieldMap
   138  		return nil
   139  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   140  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   141  	}
   142  
   143  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	pos := b.Binder.ObjectPosition(target)
   149  
   150  	switch target := target.(type) {
   151  	case nil:
   152  		objPos := b.Binder.TypePosition(obj.Type)
   153  		return fmt.Errorf(
   154  			"%s:%d adding resolver method for %s.%s, nothing matched",
   155  			objPos.Filename,
   156  			objPos.Line,
   157  			obj.Name,
   158  			f.Name,
   159  		)
   160  
   161  	case *types.Func:
   162  		sig := target.Type().(*types.Signature)
   163  		if sig.Results().Len() == 1 {
   164  			f.NoErr = true
   165  		} else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" {
   166  			f.VOkFunc = true
   167  		} else if sig.Results().Len() != 2 {
   168  			return fmt.Errorf("method has wrong number of args")
   169  		}
   170  		params := sig.Params()
   171  		// If the first argument is the context, remove it from the comparison and set
   172  		// the MethodHasContext flag so that the context will be passed to this model's method
   173  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   174  			f.MethodHasContext = true
   175  			vars := make([]*types.Var, params.Len()-1)
   176  			for i := 1; i < params.Len(); i++ {
   177  				vars[i-1] = params.At(i)
   178  			}
   179  			params = types.NewTuple(vars...)
   180  		}
   181  
   182  		// Try to match target function's arguments with GraphQL field arguments
   183  		newArgs, err := b.bindArgs(f, params)
   184  		if err != nil {
   185  			return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
   186  		}
   187  
   188  		// Try to match target function's return types with GraphQL field return type
   189  		result := sig.Results().At(0)
   190  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   191  		if err != nil {
   192  			return err
   193  		}
   194  
   195  		// success, args and return type match. Bind to method
   196  		f.GoFieldType = GoFieldMethod
   197  		f.GoReceiverName = "obj"
   198  		f.GoFieldName = target.Name()
   199  		f.Args = newArgs
   200  		f.TypeReference = tr
   201  
   202  		return nil
   203  
   204  	case *types.Var:
   205  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   206  		if err != nil {
   207  			return err
   208  		}
   209  
   210  		// success, bind to var
   211  		f.GoFieldType = GoFieldVariable
   212  		f.GoReceiverName = "obj"
   213  		f.GoFieldName = target.Name()
   214  		f.TypeReference = tr
   215  
   216  		return nil
   217  	default:
   218  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   219  	}
   220  }
   221  
   222  // findBindTarget attempts to match the name to a field or method on a Type
   223  // with the following priorites:
   224  // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
   225  // 2. Any method or field with a matching name. Errors if more than one match is found
   226  // 3. Same logic again for embedded fields
   227  func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
   228  	// NOTE: a struct tag will override both methods and fields
   229  	// Bind to struct tag
   230  	found, err := b.findBindStructTagTarget(t, name)
   231  	if found != nil || err != nil {
   232  		return found, err
   233  	}
   234  
   235  	// Search for a method to bind to
   236  	foundMethod, err := b.findBindMethodTarget(t, name)
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  
   241  	// Search for a field to bind to
   242  	foundField, err := b.findBindFieldTarget(t, name)
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  
   247  	switch {
   248  	case foundField == nil && foundMethod != nil:
   249  		// Bind to method
   250  		return foundMethod, nil
   251  	case foundField != nil && foundMethod == nil:
   252  		// Bind to field
   253  		return foundField, nil
   254  	case foundField != nil && foundMethod != nil:
   255  		// Error
   256  		return nil, fmt.Errorf("found more than one way to bind for %s", name)
   257  	}
   258  
   259  	// Search embeds
   260  	return b.findBindEmbedsTarget(t, name)
   261  }
   262  
   263  func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
   264  	if b.Config.StructTag == "" {
   265  		return nil, nil
   266  	}
   267  
   268  	switch t := in.(type) {
   269  	case *types.Named:
   270  		return b.findBindStructTagTarget(t.Underlying(), name)
   271  	case *types.Struct:
   272  		var found types.Object
   273  		for i := 0; i < t.NumFields(); i++ {
   274  			field := t.Field(i)
   275  			if !field.Exported() || field.Embedded() {
   276  				continue
   277  			}
   278  			tags := reflect.StructTag(t.Tag(i))
   279  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   280  				if found != nil {
   281  					return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   282  				}
   283  
   284  				found = field
   285  			}
   286  		}
   287  
   288  		return found, nil
   289  	}
   290  
   291  	return nil, nil
   292  }
   293  
   294  func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
   295  	switch t := in.(type) {
   296  	case *types.Named:
   297  		if _, ok := t.Underlying().(*types.Interface); ok {
   298  			return b.findBindMethodTarget(t.Underlying(), name)
   299  		}
   300  
   301  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   302  	case *types.Interface:
   303  		// FIX-ME: Should use ExplicitMethod here? What's the difference?
   304  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   305  	}
   306  
   307  	return nil, nil
   308  }
   309  
   310  func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
   311  	var found types.Object
   312  	for i := 0; i < methodCount; i++ {
   313  		method := methodFunc(i)
   314  		if !method.Exported() || !strings.EqualFold(method.Name(), name) {
   315  			continue
   316  		}
   317  
   318  		if found != nil {
   319  			return nil, fmt.Errorf("found more than one matching method to bind for %s", name)
   320  		}
   321  
   322  		found = method
   323  	}
   324  
   325  	return found, nil
   326  }
   327  
   328  func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
   329  	switch t := in.(type) {
   330  	case *types.Named:
   331  		return b.findBindFieldTarget(t.Underlying(), name)
   332  	case *types.Struct:
   333  		var found types.Object
   334  		for i := 0; i < t.NumFields(); i++ {
   335  			field := t.Field(i)
   336  			if !field.Exported() || !equalFieldName(field.Name(), name) {
   337  				continue
   338  			}
   339  
   340  			if found != nil {
   341  				return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
   342  			}
   343  
   344  			found = field
   345  		}
   346  
   347  		return found, nil
   348  	}
   349  
   350  	return nil, nil
   351  }
   352  
   353  func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
   354  	switch t := in.(type) {
   355  	case *types.Named:
   356  		return b.findBindEmbedsTarget(t.Underlying(), name)
   357  	case *types.Struct:
   358  		return b.findBindStructEmbedsTarget(t, name)
   359  	case *types.Interface:
   360  		return b.findBindInterfaceEmbedsTarget(t, name)
   361  	}
   362  
   363  	return nil, nil
   364  }
   365  
   366  func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
   367  	var found types.Object
   368  	for i := 0; i < strukt.NumFields(); i++ {
   369  		field := strukt.Field(i)
   370  		if !field.Embedded() {
   371  			continue
   372  		}
   373  
   374  		fieldType := field.Type()
   375  		if ptr, ok := fieldType.(*types.Pointer); ok {
   376  			fieldType = ptr.Elem()
   377  		}
   378  
   379  		f, err := b.findBindTarget(fieldType, name)
   380  		if err != nil {
   381  			return nil, err
   382  		}
   383  
   384  		if f != nil && found != nil {
   385  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   386  		}
   387  
   388  		if f != nil {
   389  			found = f
   390  		}
   391  	}
   392  
   393  	return found, nil
   394  }
   395  
   396  func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
   397  	var found types.Object
   398  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   399  		embeddedType := iface.EmbeddedType(i)
   400  
   401  		f, err := b.findBindTarget(embeddedType, name)
   402  		if err != nil {
   403  			return nil, err
   404  		}
   405  
   406  		if f != nil && found != nil {
   407  			return nil, fmt.Errorf("found more than one way to bind for %s", name)
   408  		}
   409  
   410  		if f != nil {
   411  			found = f
   412  		}
   413  	}
   414  
   415  	return found, nil
   416  }
   417  
   418  func (f *Field) HasDirectives() bool {
   419  	return len(f.ImplDirectives()) > 0
   420  }
   421  
   422  func (f *Field) DirectiveObjName() string {
   423  	if f.Object.Root {
   424  		return "nil"
   425  	}
   426  	return f.GoReceiverName
   427  }
   428  
   429  func (f *Field) ImplDirectives() []*Directive {
   430  	var d []*Directive
   431  	loc := ast.LocationFieldDefinition
   432  	if f.Object.IsInputType() {
   433  		loc = ast.LocationInputFieldDefinition
   434  	}
   435  	for i := range f.Directives {
   436  		if !f.Directives[i].Builtin &&
   437  			(f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) {
   438  			d = append(d, f.Directives[i])
   439  		}
   440  	}
   441  	return d
   442  }
   443  
   444  func (f *Field) IsReserved() bool {
   445  	return strings.HasPrefix(f.Name, "__")
   446  }
   447  
   448  func (f *Field) IsMethod() bool {
   449  	return f.GoFieldType == GoFieldMethod
   450  }
   451  
   452  func (f *Field) IsVariable() bool {
   453  	return f.GoFieldType == GoFieldVariable
   454  }
   455  
   456  func (f *Field) IsMap() bool {
   457  	return f.GoFieldType == GoFieldMap
   458  }
   459  
   460  func (f *Field) IsConcurrent() bool {
   461  	if f.Object.DisableConcurrency {
   462  		return false
   463  	}
   464  	return f.MethodHasContext || f.IsResolver
   465  }
   466  
   467  func (f *Field) GoNameUnexported() string {
   468  	return templates.ToGoPrivate(f.Name)
   469  }
   470  
   471  func (f *Field) ShortInvocation() string {
   472  	if f.Object.Kind == ast.InputObject {
   473  		return fmt.Sprintf("%s().%s(ctx, &it, data)", strings.Title(f.Object.Definition.Name), f.GoFieldName)
   474  	}
   475  	return fmt.Sprintf("%s().%s(%s)", strings.Title(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
   476  }
   477  
   478  func (f *Field) ArgsFunc() string {
   479  	if len(f.Args) == 0 {
   480  		return ""
   481  	}
   482  
   483  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   484  }
   485  
   486  func (f *Field) ResolverType() string {
   487  	if !f.IsResolver {
   488  		return ""
   489  	}
   490  
   491  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   492  }
   493  
   494  func (f *Field) ShortResolverDeclaration() string {
   495  	if f.Object.Kind == ast.InputObject {
   496  		return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
   497  			templates.CurrentImports.LookupType(f.Object.Reference()),
   498  			templates.CurrentImports.LookupType(f.TypeReference.GO),
   499  		)
   500  	}
   501  
   502  	res := "(ctx context.Context"
   503  
   504  	if !f.Object.Root {
   505  		res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
   506  	}
   507  	for _, arg := range f.Args {
   508  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   509  	}
   510  
   511  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   512  	if f.Object.Stream {
   513  		result = "<-chan " + result
   514  	}
   515  
   516  	res += fmt.Sprintf(") (%s, error)", result)
   517  	return res
   518  }
   519  
   520  func (f *Field) ComplexitySignature() string {
   521  	res := "func(childComplexity int"
   522  	for _, arg := range f.Args {
   523  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   524  	}
   525  	res += ") int"
   526  	return res
   527  }
   528  
   529  func (f *Field) ComplexityArgs() string {
   530  	args := make([]string, len(f.Args))
   531  	for i, arg := range f.Args {
   532  		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   533  	}
   534  
   535  	return strings.Join(args, ", ")
   536  }
   537  
   538  func (f *Field) CallArgs() string {
   539  	args := make([]string, 0, len(f.Args)+2)
   540  
   541  	if f.IsResolver {
   542  		args = append(args, "rctx")
   543  
   544  		if !f.Object.Root {
   545  			args = append(args, "obj")
   546  		}
   547  	} else if f.MethodHasContext {
   548  		args = append(args, "ctx")
   549  	}
   550  
   551  	for _, arg := range f.Args {
   552  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   553  	}
   554  
   555  	return strings.Join(args, ", ")
   556  }