github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/codegen/field.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"log"
     7  	"reflect"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/HaswinVidanage/gqlgen/codegen/config"
    12  	"github.com/HaswinVidanage/gqlgen/codegen/templates"
    13  	"github.com/pkg/errors"
    14  	"github.com/vektah/gqlparser/ast"
    15  )
    16  
    17  type Field struct {
    18  	*ast.FieldDefinition
    19  
    20  	TypeReference    *config.TypeReference
    21  	GoFieldType      GoFieldType      // The field type in go, if any
    22  	GoReceiverName   string           // The name of method & var receiver in go, if any
    23  	GoFieldName      string           // The name of the method or var in go, if any
    24  	IsResolver       bool             // Does this field need a resolver
    25  	Args             []*FieldArgument // A list of arguments to be passed to this field
    26  	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
    27  	NoErr            bool             // If this is bound to a go method, does that method have an error as the second argument
    28  	Object           *Object          // A link back to the parent object
    29  	Default          interface{}      // The default value
    30  	Directives       []*Directive
    31  }
    32  
    33  func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
    34  	dirs, err := b.getDirectives(field.Directives)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	f := Field{
    40  		FieldDefinition: field,
    41  		Object:          obj,
    42  		Directives:      dirs,
    43  		GoFieldName:     templates.ToGo(field.Name),
    44  		GoFieldType:     GoFieldVariable,
    45  		GoReceiverName:  "obj",
    46  	}
    47  
    48  	if field.DefaultValue != nil {
    49  		var err error
    50  		f.Default, err = field.DefaultValue.Value(nil)
    51  		if err != nil {
    52  			return nil, errors.Errorf("default value %s is not valid: %s", field.Name, err.Error())
    53  		}
    54  	}
    55  
    56  	for _, arg := range field.Arguments {
    57  		newArg, err := b.buildArg(obj, arg)
    58  		if err != nil {
    59  			return nil, err
    60  		}
    61  		f.Args = append(f.Args, newArg)
    62  	}
    63  
    64  	if err = b.bindField(obj, &f); err != nil {
    65  		f.IsResolver = true
    66  		log.Println(err.Error())
    67  	}
    68  
    69  	if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
    70  		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
    71  	}
    72  
    73  	return &f, nil
    74  }
    75  
    76  func (b *builder) bindField(obj *Object, f *Field) error {
    77  	defer func() {
    78  		if f.TypeReference == nil {
    79  			tr, err := b.Binder.TypeReference(f.Type, nil)
    80  			if err != nil {
    81  				panic(err)
    82  			}
    83  			f.TypeReference = tr
    84  		}
    85  	}()
    86  
    87  	switch {
    88  	case f.Name == "__schema":
    89  		f.GoFieldType = GoFieldMethod
    90  		f.GoReceiverName = "ec"
    91  		f.GoFieldName = "introspectSchema"
    92  		return nil
    93  	case f.Name == "__type":
    94  		f.GoFieldType = GoFieldMethod
    95  		f.GoReceiverName = "ec"
    96  		f.GoFieldName = "introspectType"
    97  		return nil
    98  	case obj.Root:
    99  		f.IsResolver = true
   100  		return nil
   101  	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
   102  		f.IsResolver = true
   103  		return nil
   104  	case obj.Type == config.MapType:
   105  		return nil
   106  	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
   107  		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
   108  	}
   109  
   110  	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	pos := b.Binder.ObjectPosition(target)
   116  
   117  	switch target := target.(type) {
   118  	case nil:
   119  		objPos := b.Binder.TypePosition(obj.Type)
   120  		return fmt.Errorf(
   121  			"%s:%d adding resolver method for %s.%s, nothing matched",
   122  			objPos.Filename,
   123  			objPos.Line,
   124  			obj.Name,
   125  			f.Name,
   126  		)
   127  
   128  	case *types.Func:
   129  		sig := target.Type().(*types.Signature)
   130  		if sig.Results().Len() == 1 {
   131  			f.NoErr = true
   132  		} else if sig.Results().Len() != 2 {
   133  			return fmt.Errorf("method has wrong number of args")
   134  		}
   135  		params := sig.Params()
   136  		// If the first argument is the context, remove it from the comparison and set
   137  		// the MethodHasContext flag so that the context will be passed to this model's method
   138  		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   139  			f.MethodHasContext = true
   140  			vars := make([]*types.Var, params.Len()-1)
   141  			for i := 1; i < params.Len(); i++ {
   142  				vars[i-1] = params.At(i)
   143  			}
   144  			params = types.NewTuple(vars...)
   145  		}
   146  
   147  		if err = b.bindArgs(f, params); err != nil {
   148  			return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line)
   149  		}
   150  
   151  		result := sig.Results().At(0)
   152  		tr, err := b.Binder.TypeReference(f.Type, result.Type())
   153  		if err != nil {
   154  			return err
   155  		}
   156  
   157  		// success, args and return type match. Bind to method
   158  		f.GoFieldType = GoFieldMethod
   159  		f.GoReceiverName = "obj"
   160  		f.GoFieldName = target.Name()
   161  		f.TypeReference = tr
   162  
   163  		return nil
   164  
   165  	case *types.Var:
   166  		tr, err := b.Binder.TypeReference(f.Type, target.Type())
   167  		if err != nil {
   168  			return err
   169  		}
   170  
   171  		// success, bind to var
   172  		f.GoFieldType = GoFieldVariable
   173  		f.GoReceiverName = "obj"
   174  		f.GoFieldName = target.Name()
   175  		f.TypeReference = tr
   176  
   177  		return nil
   178  	default:
   179  		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
   180  	}
   181  }
   182  
   183  // findField attempts to match the name to a struct field with the following
   184  // priorites:
   185  // 1. Any method with a matching name
   186  // 2. Any Fields with a struct tag (see config.StructTag)
   187  // 3. Any fields with a matching name
   188  // 4. Same logic again for embedded fields
   189  func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) {
   190  	for i := 0; i < named.NumMethods(); i++ {
   191  		method := named.Method(i)
   192  		if !method.Exported() {
   193  			continue
   194  		}
   195  
   196  		if !strings.EqualFold(method.Name(), name) {
   197  			continue
   198  		}
   199  
   200  		return method, nil
   201  	}
   202  
   203  	strukt, ok := named.Underlying().(*types.Struct)
   204  	if !ok {
   205  		return nil, fmt.Errorf("not a struct")
   206  	}
   207  	return b.findBindStructTarget(strukt, name)
   208  }
   209  
   210  func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) {
   211  	// struct tags have the highest priority
   212  	if b.Config.StructTag != "" {
   213  		var foundField *types.Var
   214  		for i := 0; i < strukt.NumFields(); i++ {
   215  			field := strukt.Field(i)
   216  			if !field.Exported() {
   217  				continue
   218  			}
   219  			tags := reflect.StructTag(strukt.Tag(i))
   220  			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
   221  				if foundField != nil {
   222  					return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
   223  				}
   224  
   225  				foundField = field
   226  			}
   227  		}
   228  		if foundField != nil {
   229  			return foundField, nil
   230  		}
   231  	}
   232  
   233  	// Then matching field names
   234  	for i := 0; i < strukt.NumFields(); i++ {
   235  		field := strukt.Field(i)
   236  		if !field.Exported() {
   237  			continue
   238  		}
   239  		if equalFieldName(field.Name(), name) { // aqui!
   240  			return field, nil
   241  		}
   242  	}
   243  
   244  	// Then look in embedded structs
   245  	for i := 0; i < strukt.NumFields(); i++ {
   246  		field := strukt.Field(i)
   247  		if !field.Exported() {
   248  			continue
   249  		}
   250  
   251  		if !field.Anonymous() {
   252  			continue
   253  		}
   254  
   255  		fieldType := field.Type()
   256  		if ptr, ok := fieldType.(*types.Pointer); ok {
   257  			fieldType = ptr.Elem()
   258  		}
   259  
   260  		switch fieldType := fieldType.(type) {
   261  		case *types.Named:
   262  			f, err := b.findBindTarget(fieldType, name)
   263  			if err != nil {
   264  				return nil, err
   265  			}
   266  			if f != nil {
   267  				return f, nil
   268  			}
   269  		case *types.Struct:
   270  			f, err := b.findBindStructTarget(fieldType, name)
   271  			if err != nil {
   272  				return nil, err
   273  			}
   274  			if f != nil {
   275  				return f, nil
   276  			}
   277  		default:
   278  			panic(fmt.Errorf("unknown embedded field type %T", field.Type()))
   279  		}
   280  	}
   281  
   282  	return nil, nil
   283  }
   284  
   285  func (f *Field) HasDirectives() bool {
   286  	return len(f.Directives) > 0
   287  }
   288  
   289  func (f *Field) IsReserved() bool {
   290  	return strings.HasPrefix(f.Name, "__")
   291  }
   292  
   293  func (f *Field) IsMethod() bool {
   294  	return f.GoFieldType == GoFieldMethod
   295  }
   296  
   297  func (f *Field) IsVariable() bool {
   298  	return f.GoFieldType == GoFieldVariable
   299  }
   300  
   301  func (f *Field) IsConcurrent() bool {
   302  	if f.Object.DisableConcurrency {
   303  		return false
   304  	}
   305  	return f.MethodHasContext || f.IsResolver
   306  }
   307  
   308  func (f *Field) GoNameUnexported() string {
   309  	return templates.ToGoPrivate(f.Name)
   310  }
   311  
   312  func (f *Field) ShortInvocation() string {
   313  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   314  }
   315  
   316  func (f *Field) ArgsFunc() string {
   317  	if len(f.Args) == 0 {
   318  		return ""
   319  	}
   320  
   321  	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
   322  }
   323  
   324  func (f *Field) ResolverType() string {
   325  	if !f.IsResolver {
   326  		return ""
   327  	}
   328  
   329  	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
   330  }
   331  
   332  func (f *Field) ShortResolverDeclaration() string {
   333  	res := "(ctx context.Context"
   334  
   335  	if !f.Object.Root {
   336  		res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Type))
   337  	}
   338  	for _, arg := range f.Args {
   339  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   340  	}
   341  
   342  	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
   343  	if f.Object.Stream {
   344  		result = "<-chan " + result
   345  	}
   346  
   347  	res += fmt.Sprintf(") (%s, error)", result)
   348  	return res
   349  }
   350  
   351  func (f *Field) ComplexitySignature() string {
   352  	res := fmt.Sprintf("func(childComplexity int")
   353  	for _, arg := range f.Args {
   354  		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
   355  	}
   356  	res += ") int"
   357  	return res
   358  }
   359  
   360  func (f *Field) ComplexityArgs() string {
   361  	var args []string
   362  	for _, arg := range f.Args {
   363  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   364  	}
   365  
   366  	return strings.Join(args, ", ")
   367  }
   368  
   369  func (f *Field) CallArgs() string {
   370  	var args []string
   371  
   372  	if f.IsResolver {
   373  		args = append(args, "rctx")
   374  
   375  		if !f.Object.Root {
   376  			args = append(args, "obj")
   377  		}
   378  	} else {
   379  		if f.MethodHasContext {
   380  			args = append(args, "ctx")
   381  		}
   382  	}
   383  
   384  	for _, arg := range f.Args {
   385  		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
   386  	}
   387  
   388  	return strings.Join(args, ", ")
   389  }