github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"log"
     7  	"reflect"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/animeshon/gqlgen/codegen/config"
    12  	"github.com/animeshon/gqlgen/codegen/templates"
    13  	"github.com/pkg/errors"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  )
    16  
// Field is the code generation model for a single GraphQL field of an Object.
// It couples the schema definition the field was built from with a description
// of how generated Go code obtains the field's value: through a Go method, a
// struct field, a map entry, or a resolver. TypeReference pairs the field's
// GraphQL type with the Go type it is bound to.
type Field struct {
	*ast.FieldDefinition

	TypeReference    *config.TypeReference
	GoFieldType      GoFieldType      // How the value is read in go: method, variable (struct field) or map entry
	GoReceiverName   string           // The name of method & var receiver in go, if any
	GoFieldName      string           // The name of the method or var in go, if any
	IsResolver       bool             // Does this field need a resolver
	Args             []*FieldArgument // A list of arguments to be passed to this field
	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
	NoErr            bool             // If this is bound to a go method, true when that method returns no error (it has a single result)
	Object           *Object          // A link back to the parent object
	Default          interface{}      // The default value
	Stream           bool             // does this field return a channel?
	Directives       []*Directive     // Directives on the field, preceded by those declared on the field's type
}
    33  
    34  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    35  	dirs, err := b.getDirectives(field.Directives)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	f := Field{
    41  		FieldDefinition: field,
    42  		Object:          obj,
    43  		Directives:      dirs,
    44  		GoFieldName:     templates.ToGo(field.Name),
    45  		GoFieldType:     GoFieldVariable,
    46  		GoReceiverName:  "obj",
    47  	}
    48  
    49  	if field.DefaultValue != nil {
    50  		var err error
    51  		f.Default, err = field.DefaultValue.Value(nil)
    52  		if err != nil {
    53  			return nil, errors.Errorf("default value %s is not valid: %s", field.Name, err.Error())
    54  		}
    55  	}
    56  
    57  	for _, arg := range field.Arguments {
    58  		newArg, err := b.buildArg(obj, arg)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  		f.Args = append(f.Args, newArg)
    63  	}
    64  
    65  	if err = b.bindField(obj, &f); err != nil {
    66  		f.IsResolver = true
    67  		log.Println(err.Error())
    68  	}
    69  
    70  	if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    71  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    72  	}
    73  
    74  	return &f, nil
    75  }
    76  
    77  func (b *builder) bindField(obj *Object, f *Field) (errret error) {
    78  	defer func() {
    79  		if f.TypeReference == nil {
    80  			tr, err := b.Binder.TypeReference(f.Type, nil)
    81  			if err != nil {
    82  				errret = err
    83  			}
    84  			f.TypeReference = tr
    85  		}
    86  		if f.TypeReference != nil {
    87  			dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
    88  			if err != nil {
    89  				errret = err
    90  			}
    91  			f.Directives = append(dirs, f.Directives...)
    92  		}
    93  	}()
    94  
    95  	f.Stream = obj.Stream
    96  
    97  	switch {
    98  	case f.Name == "__schema":
    99  		f.GoFieldType = GoFieldMethod
   100  		f.GoReceiverName = "ec"
   101  		f.GoFieldName = "introspectSchema"
   102  		return nil
   103  	case f.Name == "__type":
   104  		f.GoFieldType = GoFieldMethod
   105  		f.GoReceiverName = "ec"
   106  		f.GoFieldName = "introspectType"
   107  		return nil
   108  	case f.Name == "_entities":
   109  		f.GoFieldType = GoFieldMethod
   110  		f.GoReceiverName = "ec"
   111  		f.GoFieldName = "__resolve_entities"
   112  		f.MethodHasContext = true
   113  		return nil
   114  	case f.Name == "_service":
   115  		f.GoFieldType = GoFieldMethod
   116  		f.GoReceiverName = "ec"
   117  		f.GoFieldName = "__resolve__service"
   118  		f.MethodHasContext = true
   119  		return nil
   120  	case obj.Root:
   121  		f.IsResolver = true
   122  		return nil
   123  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   124  		f.IsResolver = true
   125  		return nil
   126  	case obj.Type == config.MapType:
   127  		f.GoFieldType = GoFieldMap
   128  		return nil
   129  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   130  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   131  	}
   132  
   133  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	pos := b.Binder.ObjectPosition(target)
   139  
   140  	switch target := target.(type) {
   141  	case nil:
   142  		objPos := b.Binder.TypePosition(obj.Type)
   143  		return fmt.Errorf(
   144  			"%s:%d adding resolver method for %s.%s, nothing matched",
   145  			objPos.Filename,
   146  			objPos.Line,
   147  			obj.Name,
   148  			f.Name,
   149  		)
   150  
   151  	case *types.Func:
   152  		sig := target.Type().(*types.Signature)
   153  		if sig.Results().Len() == 1 {
   154  			f.NoErr = true
   155  		} else if sig.Results().Len() != 2 {
   156  			return fmt.Errorf("method has wrong number of args")
   157  		}
   158  		params := sig.Params()
   159  		// If the first argument is the context, remove it from the comparison and set
   160  		// the MethodHasContext flag so that the context will be passed to this model's method
   161  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   162  			f.MethodHasContext = true
   163  			vars := make([]*types.Var, params.Len()-1)
   164  			for i := 1; i < params.Len(); i++ {
   165  				vars[i-1] = params.At(i)
   166  			}
   167  			params = types.NewTuple(vars...)
   168  		}
   169  
   170  		if err = b.bindArgs(f, params); err != nil {
   171  			return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line)
   172  		}
   173  
   174  		result := sig.Results().At(0)
   175  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   176  		if err != nil {
   177  			return err
   178  		}
   179  
   180  		// success, args and return type match. Bind to method
   181  		f.GoFieldType = GoFieldMethod
   182  		f.GoReceiverName = "obj"
   183  		f.GoFieldName = target.Name()
   184  		f.TypeReference = tr
   185  
   186  		return nil
   187  
   188  	case *types.Var:
   189  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   190  		if err != nil {
   191  			return err
   192  		}
   193  
   194  		// success, bind to var
   195  		f.GoFieldType = GoFieldVariable
   196  		f.GoReceiverName = "obj"
   197  		f.GoFieldName = target.Name()
   198  		f.TypeReference = tr
   199  
   200  		return nil
   201  	default:
   202  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   203  	}
   204  }
   205  
   206  // findBindTarget attempts to match the name to a field or method on a Type
   207  // with the following priorites:
   208  // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
   209  // 2. Any method or field with a matching name. Errors if more than one match is found
   210  // 3. Same logic again for embedded fields
   211  func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
   212  	// NOTE: a struct tag will override both methods and fields
   213  	// Bind to struct tag
   214  	found, err := b.findBindStructTagTarget(t, name)
   215  	if found != nil || err != nil {
   216  		return found, err
   217  	}
   218  
   219  	// Search for a method to bind to
   220  	foundMethod, err := b.findBindMethodTarget(t, name)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  
   225  	// Search for a field to bind to
   226  	foundField, err := b.findBindFieldTarget(t, name)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	switch {
   232  	case foundField == nil && foundMethod != nil:
   233  		// Bind to method
   234  		return foundMethod, nil
   235  	case foundField != nil && foundMethod == nil:
   236  		// Bind to field
   237  		return foundField, nil
   238  	case foundField != nil && foundMethod != nil:
   239  		// Error
   240  		return nil, errors.Errorf("found more than one way to bind for %s", name)
   241  	}
   242  
   243  	// Search embeds
   244  	return b.findBindEmbedsTarget(t, name)
   245  }
   246  
   247  func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
   248  	if b.Config.StructTag == "" {
   249  		return nil, nil
   250  	}
   251  
   252  	switch t := in.(type) {
   253  	case *types.Named:
   254  		return b.findBindStructTagTarget(t.Underlying(), name)
   255  	case *types.Struct:
   256  		var found types.Object
   257  		for i := 0; i < t.NumFields(); i++ {
   258  			field := t.Field(i)
   259  			if !field.Exported() || field.Embedded() {
   260  				continue
   261  			}
   262  			tags := reflect.StructTag(t.Tag(i))
   263  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   264  				if found != nil {
   265  					return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   266  				}
   267  
   268  				found = field
   269  			}
   270  		}
   271  
   272  		return found, nil
   273  	}
   274  
   275  	return nil, nil
   276  }
   277  
   278  func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
   279  	switch t := in.(type) {
   280  	case *types.Named:
   281  		if _, ok := t.Underlying().(*types.Interface); ok {
   282  			return b.findBindMethodTarget(t.Underlying(), name)
   283  		}
   284  
   285  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   286  	case *types.Interface:
   287  		// FIX-ME: Should use ExplicitMethod here? What's the difference?
   288  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   289  	}
   290  
   291  	return nil, nil
   292  }
   293  
   294  func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
   295  	var found types.Object
   296  	for i := 0; i < methodCount; i++ {
   297  		method := methodFunc(i)
   298  		if !method.Exported() || !strings.EqualFold(method.Name(), name) {
   299  			continue
   300  		}
   301  
   302  		if found != nil {
   303  			return nil, errors.Errorf("found more than one matching method to bind for %s", name)
   304  		}
   305  
   306  		found = method
   307  	}
   308  
   309  	return found, nil
   310  }
   311  
   312  func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
   313  	switch t := in.(type) {
   314  	case *types.Named:
   315  		return b.findBindFieldTarget(t.Underlying(), name)
   316  	case *types.Struct:
   317  		var found types.Object
   318  		for i := 0; i < t.NumFields(); i++ {
   319  			field := t.Field(i)
   320  			if !field.Exported() || !equalFieldName(field.Name(), name) {
   321  				continue
   322  			}
   323  
   324  			if found != nil {
   325  				return nil, errors.Errorf("found more than one matching field to bind for %s", name)
   326  			}
   327  
   328  			found = field
   329  		}
   330  
   331  		return found, nil
   332  	}
   333  
   334  	return nil, nil
   335  }
   336  
   337  func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
   338  	switch t := in.(type) {
   339  	case *types.Named:
   340  		return b.findBindEmbedsTarget(t.Underlying(), name)
   341  	case *types.Struct:
   342  		return b.findBindStructEmbedsTarget(t, name)
   343  	case *types.Interface:
   344  		return b.findBindInterfaceEmbedsTarget(t, name)
   345  	}
   346  
   347  	return nil, nil
   348  }
   349  
   350  func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
   351  	var found types.Object
   352  	for i := 0; i < strukt.NumFields(); i++ {
   353  		field := strukt.Field(i)
   354  		if !field.Embedded() {
   355  			continue
   356  		}
   357  
   358  		fieldType := field.Type()
   359  		if ptr, ok := fieldType.(*types.Pointer); ok {
   360  			fieldType = ptr.Elem()
   361  		}
   362  
   363  		f, err := b.findBindTarget(fieldType, name)
   364  		if err != nil {
   365  			return nil, err
   366  		}
   367  
   368  		if f != nil && found != nil {
   369  			return nil, errors.Errorf("found more than one way to bind for %s", name)
   370  		}
   371  
   372  		if f != nil {
   373  			found = f
   374  		}
   375  	}
   376  
   377  	return found, nil
   378  }
   379  
   380  func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
   381  	var found types.Object
   382  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   383  		embeddedType := iface.EmbeddedType(i)
   384  
   385  		f, err := b.findBindTarget(embeddedType, name)
   386  		if err != nil {
   387  			return nil, err
   388  		}
   389  
   390  		if f != nil && found != nil {
   391  			return nil, errors.Errorf("found more than one way to bind for %s", name)
   392  		}
   393  
   394  		if f != nil {
   395  			found = f
   396  		}
   397  	}
   398  
   399  	return found, nil
   400  }
   401  
   402  func (f *Field) HasDirectives() bool {
   403  	return len(f.ImplDirectives()) > 0
   404  }
   405  
   406  func (f *Field) DirectiveObjName() string {
   407  	if f.Object.Root {
   408  		return "nil"
   409  	}
   410  	return f.GoReceiverName
   411  }
   412  
   413  func (f *Field) ImplDirectives() []*Directive {
   414  	var d []*Directive
   415  	loc := ast.LocationFieldDefinition
   416  	if f.Object.IsInputType() {
   417  		loc = ast.LocationInputFieldDefinition
   418  	}
   419  	for i := range f.Directives {
   420  		if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc, ast.LocationObject) {
   421  			d = append(d, f.Directives[i])
   422  		}
   423  	}
   424  	return d
   425  }
   426  
   427  func (f *Field) IsReserved() bool {
   428  	return strings.HasPrefix(f.Name, "__")
   429  }
   430  
   431  func (f *Field) IsMethod() bool {
   432  	return f.GoFieldType == GoFieldMethod
   433  }
   434  
   435  func (f *Field) IsVariable() bool {
   436  	return f.GoFieldType == GoFieldVariable
   437  }
   438  
   439  func (f *Field) IsMap() bool {
   440  	return f.GoFieldType == GoFieldMap
   441  }
   442  
   443  func (f *Field) IsConcurrent() bool {
   444  	if f.Object.DisableConcurrency {
   445  		return false
   446  	}
   447  	return f.MethodHasContext || f.IsResolver
   448  }
   449  
   450  func (f *Field) GoNameUnexported() string {
   451  	return templates.ToGoPrivate(f.Name)
   452  }
   453  
   454  func (f *Field) ShortInvocation() string {
   455  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   456  }
   457  
   458  func (f *Field) ArgsFunc() string {
   459  	if len(f.Args) == 0 {
   460  		return ""
   461  	}
   462  
   463  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   464  }
   465  
   466  func (f *Field) ResolverType() string {
   467  	if !f.IsResolver {
   468  		return ""
   469  	}
   470  
   471  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   472  }
   473  
   474  func (f *Field) ShortResolverDeclaration() string {
   475  	res := "(ctx context.Context"
   476  
   477  	if !f.Object.Root {
   478  		res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
   479  	}
   480  	for _, arg := range f.Args {
   481  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   482  	}
   483  
   484  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   485  	if f.Object.Stream {
   486  		result = "<-chan " + result
   487  	}
   488  
   489  	res += fmt.Sprintf(") (%s, error)", result)
   490  	return res
   491  }
   492  
   493  func (f *Field) ComplexitySignature() string {
   494  	res := "func(childComplexity int"
   495  	for _, arg := range f.Args {
   496  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   497  	}
   498  	res += ") int"
   499  	return res
   500  }
   501  
   502  func (f *Field) ComplexityArgs() string {
   503  	args := make([]string, len(f.Args))
   504  	for i, arg := range f.Args {
   505  		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   506  	}
   507  
   508  	return strings.Join(args, ", ")
   509  }
   510  
   511  func (f *Field) CallArgs() string {
   512  	args := make([]string, 0, len(f.Args)+2)
   513  
   514  	if f.IsResolver {
   515  		args = append(args, "rctx")
   516  
   517  		if !f.Object.Root {
   518  			args = append(args, "obj")
   519  		}
   520  	} else if f.MethodHasContext {
   521  		args = append(args, "ctx")
   522  	}
   523  
   524  	for _, arg := range f.Args {
   525  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   526  	}
   527  
   528  	return strings.Join(args, ", ")
   529  }