git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"log"
     7  	"reflect"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"git.sr.ht/~sircmpwn/gqlgen/codegen/config"
    12  	"git.sr.ht/~sircmpwn/gqlgen/codegen/templates"
    13  	"github.com/pkg/errors"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  )
    16  
    17  type Field struct {
    18  	*ast.FieldDefinition
    19  
    20  	TypeReference    *config.TypeReference
    21  	GoFieldType      GoFieldType      // The field type in go, if any
    22  	GoReceiverName   string           // The name of method & var receiver in go, if any
    23  	GoFieldName      string           // The name of the method or var in go, if any
    24  	IsResolver       bool             // Does this field need a resolver
    25  	Args             []*FieldArgument // A list of arguments to be passed to this field
    26  	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
    27  	NoErr            bool             // If this is bound to a go method, does that method have an error as the second argument
    28  	Object           *Object          // A link back to the parent object
    29  	Default          interface{}      // The default value
    30  	Stream           bool             // does this field return a channel?
    31  	Directives       []*Directive
    32  }
    33  
    34  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    35  	dirs, err := b.getDirectives(field.Directives)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	f := Field{
    41  		FieldDefinition: field,
    42  		Object:          obj,
    43  		Directives:      dirs,
    44  		GoFieldName:     templates.ToGo(field.Name),
    45  		GoFieldType:     GoFieldVariable,
    46  		GoReceiverName:  "obj",
    47  	}
    48  
    49  	if field.DefaultValue != nil {
    50  		var err error
    51  		f.Default, err = field.DefaultValue.Value(nil)
    52  		if err != nil {
    53  			return nil, errors.Errorf("default value %s is not valid: %s", field.Name, err.Error())
    54  		}
    55  	}
    56  
    57  	for _, arg := range field.Arguments {
    58  		newArg, err := b.buildArg(obj, arg)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  		f.Args = append(f.Args, newArg)
    63  	}
    64  
    65  	if err = b.bindField(obj, &f); err != nil {
    66  		f.IsResolver = true
    67  		log.Println(err.Error())
    68  	}
    69  
    70  	if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    71  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    72  	}
    73  
    74  	return &f, nil
    75  }
    76  
    77  func (b *builder) bindField(obj *Object, f *Field) (errret error) {
    78  	defer func() {
    79  		if f.TypeReference == nil {
    80  			tr, err := b.Binder.TypeReference(f.Type, nil)
    81  			if err != nil {
    82  				errret = err
    83  			}
    84  			f.TypeReference = tr
    85  		}
    86  	}()
    87  
    88  	f.Stream = obj.Stream
    89  
    90  	switch {
    91  	case f.Name == "__schema":
    92  		f.GoFieldType = GoFieldMethod
    93  		f.GoReceiverName = "ec"
    94  		f.GoFieldName = "introspectSchema"
    95  		return nil
    96  	case f.Name == "__type":
    97  		f.GoFieldType = GoFieldMethod
    98  		f.GoReceiverName = "ec"
    99  		f.GoFieldName = "introspectType"
   100  		return nil
   101  	case f.Name == "_entities":
   102  		f.GoFieldType = GoFieldMethod
   103  		f.GoReceiverName = "ec"
   104  		f.GoFieldName = "__resolve_entities"
   105  		f.MethodHasContext = true
   106  		return nil
   107  	case f.Name == "_service":
   108  		f.GoFieldType = GoFieldMethod
   109  		f.GoReceiverName = "ec"
   110  		f.GoFieldName = "__resolve__service"
   111  		f.MethodHasContext = true
   112  		return nil
   113  	case obj.Root:
   114  		f.IsResolver = true
   115  		return nil
   116  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   117  		f.IsResolver = true
   118  		return nil
   119  	case obj.Type == config.MapType:
   120  		f.GoFieldType = GoFieldMap
   121  		return nil
   122  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   123  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   124  	}
   125  
   126  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	pos := b.Binder.ObjectPosition(target)
   132  
   133  	switch target := target.(type) {
   134  	case nil:
   135  		objPos := b.Binder.TypePosition(obj.Type)
   136  		return fmt.Errorf(
   137  			"%s:%d adding resolver method for %s.%s, nothing matched",
   138  			objPos.Filename,
   139  			objPos.Line,
   140  			obj.Name,
   141  			f.Name,
   142  		)
   143  
   144  	case *types.Func:
   145  		sig := target.Type().(*types.Signature)
   146  		if sig.Results().Len() == 1 {
   147  			f.NoErr = true
   148  		} else if sig.Results().Len() != 2 {
   149  			return fmt.Errorf("method has wrong number of args")
   150  		}
   151  		params := sig.Params()
   152  		// If the first argument is the context, remove it from the comparison and set
   153  		// the MethodHasContext flag so that the context will be passed to this model's method
   154  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   155  			f.MethodHasContext = true
   156  			vars := make([]*types.Var, params.Len()-1)
   157  			for i := 1; i < params.Len(); i++ {
   158  				vars[i-1] = params.At(i)
   159  			}
   160  			params = types.NewTuple(vars...)
   161  		}
   162  
   163  		if err = b.bindArgs(f, params); err != nil {
   164  			return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line)
   165  		}
   166  
   167  		result := sig.Results().At(0)
   168  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   169  		if err != nil {
   170  			return err
   171  		}
   172  
   173  		// success, args and return type match. Bind to method
   174  		f.GoFieldType = GoFieldMethod
   175  		f.GoReceiverName = "obj"
   176  		f.GoFieldName = target.Name()
   177  		f.TypeReference = tr
   178  
   179  		return nil
   180  
   181  	case *types.Var:
   182  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   183  		if err != nil {
   184  			return err
   185  		}
   186  
   187  		// success, bind to var
   188  		f.GoFieldType = GoFieldVariable
   189  		f.GoReceiverName = "obj"
   190  		f.GoFieldName = target.Name()
   191  		f.TypeReference = tr
   192  
   193  		return nil
   194  	default:
   195  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   196  	}
   197  }
   198  
   199  // findBindTarget attempts to match the name to a field or method on a Type
   200  // with the following priorites:
   201  // 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
   202  // 2. Any method or field with a matching name. Errors if more than one match is found
   203  // 3. Same logic again for embedded fields
   204  func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
   205  	// NOTE: a struct tag will override both methods and fields
   206  	// Bind to struct tag
   207  	found, err := b.findBindStructTagTarget(t, name)
   208  	if found != nil || err != nil {
   209  		return found, err
   210  	}
   211  
   212  	// Search for a method to bind to
   213  	foundMethod, err := b.findBindMethodTarget(t, name)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  
   218  	// Search for a field to bind to
   219  	foundField, err := b.findBindFieldTarget(t, name)
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	switch {
   225  	case foundField == nil && foundMethod != nil:
   226  		// Bind to method
   227  		return foundMethod, nil
   228  	case foundField != nil && foundMethod == nil:
   229  		// Bind to field
   230  		return foundField, nil
   231  	case foundField != nil && foundMethod != nil:
   232  		// Error
   233  		return nil, errors.Errorf("found more than one way to bind for %s", name)
   234  	}
   235  
   236  	// Search embeds
   237  	return b.findBindEmbedsTarget(t, name)
   238  }
   239  
   240  func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
   241  	if b.Config.StructTag == "" {
   242  		return nil, nil
   243  	}
   244  
   245  	switch t := in.(type) {
   246  	case *types.Named:
   247  		return b.findBindStructTagTarget(t.Underlying(), name)
   248  	case *types.Struct:
   249  		var found types.Object
   250  		for i := 0; i < t.NumFields(); i++ {
   251  			field := t.Field(i)
   252  			if !field.Exported() || field.Embedded() {
   253  				continue
   254  			}
   255  			tags := reflect.StructTag(t.Tag(i))
   256  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   257  				if found != nil {
   258  					return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   259  				}
   260  
   261  				found = field
   262  			}
   263  		}
   264  
   265  		return found, nil
   266  	}
   267  
   268  	return nil, nil
   269  }
   270  
   271  func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
   272  	switch t := in.(type) {
   273  	case *types.Named:
   274  		if _, ok := t.Underlying().(*types.Interface); ok {
   275  			return b.findBindMethodTarget(t.Underlying(), name)
   276  		}
   277  
   278  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   279  	case *types.Interface:
   280  		// FIX-ME: Should use ExplicitMethod here? What's the difference?
   281  		return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
   282  	}
   283  
   284  	return nil, nil
   285  }
   286  
   287  func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
   288  	var found types.Object
   289  	for i := 0; i < methodCount; i++ {
   290  		method := methodFunc(i)
   291  		if !method.Exported() || !strings.EqualFold(method.Name(), name) {
   292  			continue
   293  		}
   294  
   295  		if found != nil {
   296  			return nil, errors.Errorf("found more than one matching method to bind for %s", name)
   297  		}
   298  
   299  		found = method
   300  	}
   301  
   302  	return found, nil
   303  }
   304  
   305  func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
   306  	switch t := in.(type) {
   307  	case *types.Named:
   308  		return b.findBindFieldTarget(t.Underlying(), name)
   309  	case *types.Struct:
   310  		var found types.Object
   311  		for i := 0; i < t.NumFields(); i++ {
   312  			field := t.Field(i)
   313  			if !field.Exported() || !equalFieldName(field.Name(), name) {
   314  				continue
   315  			}
   316  
   317  			if found != nil {
   318  				return nil, errors.Errorf("found more than one matching field to bind for %s", name)
   319  			}
   320  
   321  			found = field
   322  		}
   323  
   324  		return found, nil
   325  	}
   326  
   327  	return nil, nil
   328  }
   329  
   330  func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
   331  	switch t := in.(type) {
   332  	case *types.Named:
   333  		return b.findBindEmbedsTarget(t.Underlying(), name)
   334  	case *types.Struct:
   335  		return b.findBindStructEmbedsTarget(t, name)
   336  	case *types.Interface:
   337  		return b.findBindInterfaceEmbedsTarget(t, name)
   338  	}
   339  
   340  	return nil, nil
   341  }
   342  
   343  func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
   344  	var found types.Object
   345  	for i := 0; i < strukt.NumFields(); i++ {
   346  		field := strukt.Field(i)
   347  		if !field.Embedded() {
   348  			continue
   349  		}
   350  
   351  		fieldType := field.Type()
   352  		if ptr, ok := fieldType.(*types.Pointer); ok {
   353  			fieldType = ptr.Elem()
   354  		}
   355  
   356  		f, err := b.findBindTarget(fieldType, name)
   357  		if err != nil {
   358  			return nil, err
   359  		}
   360  
   361  		if f != nil && found != nil {
   362  			return nil, errors.Errorf("found more than one way to bind for %s", name)
   363  		}
   364  
   365  		if f != nil {
   366  			found = f
   367  		}
   368  	}
   369  
   370  	return found, nil
   371  }
   372  
   373  func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
   374  	var found types.Object
   375  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   376  		embeddedType := iface.EmbeddedType(i)
   377  
   378  		f, err := b.findBindTarget(embeddedType, name)
   379  		if err != nil {
   380  			return nil, err
   381  		}
   382  
   383  		if f != nil && found != nil {
   384  			return nil, errors.Errorf("found more than one way to bind for %s", name)
   385  		}
   386  
   387  		if f != nil {
   388  			found = f
   389  		}
   390  	}
   391  
   392  	return found, nil
   393  }
   394  
   395  func (f *Field) HasDirectives() bool {
   396  	return len(f.ImplDirectives()) > 0
   397  }
   398  
   399  func (f *Field) DirectiveObjName() string {
   400  	if f.Object.Root {
   401  		return "nil"
   402  	}
   403  	return f.GoReceiverName
   404  }
   405  
   406  func (f *Field) ImplDirectives() []*Directive {
   407  	var d []*Directive
   408  	loc := ast.LocationFieldDefinition
   409  	if f.Object.IsInputType() {
   410  		loc = ast.LocationInputFieldDefinition
   411  	}
   412  	for i := range f.Directives {
   413  		if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc) {
   414  			d = append(d, f.Directives[i])
   415  		}
   416  	}
   417  	return d
   418  }
   419  
   420  func (f *Field) IsReserved() bool {
   421  	return strings.HasPrefix(f.Name, "__")
   422  }
   423  
   424  func (f *Field) IsMethod() bool {
   425  	return f.GoFieldType == GoFieldMethod
   426  }
   427  
   428  func (f *Field) IsVariable() bool {
   429  	return f.GoFieldType == GoFieldVariable
   430  }
   431  
   432  func (f *Field) IsMap() bool {
   433  	return f.GoFieldType == GoFieldMap
   434  }
   435  
   436  func (f *Field) IsConcurrent() bool {
   437  	if f.Object.DisableConcurrency {
   438  		return false
   439  	}
   440  	return f.MethodHasContext || f.IsResolver
   441  }
   442  
   443  func (f *Field) GoNameUnexported() string {
   444  	return templates.ToGoPrivate(f.Name)
   445  }
   446  
   447  func (f *Field) ShortInvocation() string {
   448  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   449  }
   450  
   451  func (f *Field) ArgsFunc() string {
   452  	if len(f.Args) == 0 {
   453  		return ""
   454  	}
   455  
   456  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   457  }
   458  
   459  func (f *Field) ResolverType() string {
   460  	if !f.IsResolver {
   461  		return ""
   462  	}
   463  
   464  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   465  }
   466  
   467  func (f *Field) ShortResolverDeclaration() string {
   468  	res := "(ctx context.Context"
   469  
   470  	if !f.Object.Root {
   471  		res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
   472  	}
   473  	for _, arg := range f.Args {
   474  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   475  	}
   476  
   477  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   478  	if f.Object.Stream {
   479  		result = "<-chan " + result
   480  	}
   481  
   482  	res += fmt.Sprintf(") (%s, error)", result)
   483  	return res
   484  }
   485  
   486  func (f *Field) ComplexitySignature() string {
   487  	res := fmt.Sprintf("func(childComplexity int")
   488  	for _, arg := range f.Args {
   489  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   490  	}
   491  	res += ") int"
   492  	return res
   493  }
   494  
   495  func (f *Field) ComplexityArgs() string {
   496  	args := make([]string, len(f.Args))
   497  	for i, arg := range f.Args {
   498  		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
   499  	}
   500  
   501  	return strings.Join(args, ", ")
   502  }
   503  
   504  func (f *Field) CallArgs() string {
   505  	args := make([]string, 0, len(f.Args)+2)
   506  
   507  	if f.IsResolver {
   508  		args = append(args, "rctx")
   509  
   510  		if !f.Object.Root {
   511  			args = append(args, "obj")
   512  		}
   513  	} else if f.MethodHasContext {
   514  		args = append(args, "ctx")
   515  	}
   516  
   517  	for _, arg := range f.Args {
   518  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   519  	}
   520  
   521  	return strings.Join(args, ", ")
   522  }