github.com/mkusaka/gqlgen@v0.7.2/codegen/util.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"reflect"
     7  	"regexp"
     8  	"strings"
     9  
    10  	"github.com/pkg/errors"
    11  	"golang.org/x/tools/go/loader"
    12  )
    13  
    14  func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
    15  	if pkgName == "" {
    16  		return nil, nil
    17  	}
    18  	fullName := typeName
    19  	if pkgName != "" {
    20  		fullName = pkgName + "." + typeName
    21  	}
    22  
    23  	pkgName, err := resolvePkg(pkgName)
    24  	if err != nil {
    25  		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
    26  	}
    27  
    28  	pkg := prog.Imported[pkgName]
    29  	if pkg == nil {
    30  		return nil, errors.Errorf("required package was not loaded: %s", fullName)
    31  	}
    32  
    33  	for astNode, def := range pkg.Defs {
    34  		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
    35  			continue
    36  		}
    37  
    38  		return def, nil
    39  	}
    40  
    41  	return nil, errors.Errorf("unable to find type %s\n", fullName)
    42  }
    43  
    44  func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
    45  	def, err := findGoType(prog, pkgName, typeName)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	if def == nil {
    50  		return nil, nil
    51  	}
    52  
    53  	namedType, ok := def.Type().(*types.Named)
    54  	if !ok {
    55  		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
    56  	}
    57  
    58  	return namedType, nil
    59  }
    60  
    61  func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
    62  	namedType, err := findGoNamedType(prog, pkgName, typeName)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	if namedType == nil {
    67  		return nil, nil
    68  	}
    69  
    70  	underlying, ok := namedType.Underlying().(*types.Interface)
    71  	if !ok {
    72  		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
    73  	}
    74  
    75  	return underlying, nil
    76  }
    77  
    78  func findMethod(typ *types.Named, name string) *types.Func {
    79  	for i := 0; i < typ.NumMethods(); i++ {
    80  		method := typ.Method(i)
    81  		if !method.Exported() {
    82  			continue
    83  		}
    84  
    85  		if strings.EqualFold(method.Name(), name) {
    86  			return method
    87  		}
    88  	}
    89  
    90  	if s, ok := typ.Underlying().(*types.Struct); ok {
    91  		for i := 0; i < s.NumFields(); i++ {
    92  			field := s.Field(i)
    93  			if !field.Anonymous() {
    94  				continue
    95  			}
    96  
    97  			if named, ok := field.Type().(*types.Named); ok {
    98  				if f := findMethod(named, name); f != nil {
    99  					return f
   100  				}
   101  			}
   102  		}
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  func equalFieldName(source, target string) bool {
   109  	source = strings.Replace(source, "_", "", -1)
   110  	target = strings.Replace(target, "_", "", -1)
   111  	return strings.EqualFold(source, target)
   112  }
   113  
   114  // findField attempts to match the name to a struct field with the following
   115  // priorites:
   116  // 1. If struct tag is passed then struct tag has highest priority
   117  // 2. Field in an embedded struct
   118  // 3. Actual Field name
   119  func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
   120  	var foundField *types.Var
   121  	foundFieldWasTag := false
   122  
   123  	for i := 0; i < typ.NumFields(); i++ {
   124  		field := typ.Field(i)
   125  
   126  		if structTag != "" {
   127  			tags := reflect.StructTag(typ.Tag(i))
   128  			if val, ok := tags.Lookup(structTag); ok {
   129  				if equalFieldName(val, name) {
   130  					if foundField != nil && foundFieldWasTag {
   131  						return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
   132  					}
   133  
   134  					foundField = field
   135  					foundFieldWasTag = true
   136  				}
   137  			}
   138  		}
   139  
   140  		if field.Anonymous() {
   141  
   142  			fieldType := field.Type()
   143  
   144  			if ptr, ok := fieldType.(*types.Pointer); ok {
   145  				fieldType = ptr.Elem()
   146  			}
   147  
   148  			// Type.Underlying() returns itself for all types except types.Named, where it returns a struct type.
   149  			// It should be safe to always call.
   150  			if named, ok := fieldType.Underlying().(*types.Struct); ok {
   151  				f, err := findField(named, name, structTag)
   152  				if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
   153  					return nil, err
   154  				}
   155  				if f != nil && foundField == nil {
   156  					foundField = f
   157  				}
   158  			}
   159  		}
   160  
   161  		if !field.Exported() {
   162  			continue
   163  		}
   164  
   165  		if equalFieldName(field.Name(), name) && foundField == nil { // aqui!
   166  			foundField = field
   167  		}
   168  	}
   169  
   170  	if foundField == nil {
   171  		return nil, fmt.Errorf("no field named %s", name)
   172  	}
   173  
   174  	return foundField, nil
   175  }
   176  
   177  type BindError struct {
   178  	object    *Object
   179  	field     *Field
   180  	typ       types.Type
   181  	methodErr error
   182  	varErr    error
   183  }
   184  
   185  func (b BindError) Error() string {
   186  	return fmt.Sprintf(
   187  		"Unable to bind %s.%s to %s\n  %s\n  %s",
   188  		b.object.GQLType,
   189  		b.field.GQLName,
   190  		b.typ.String(),
   191  		b.methodErr.Error(),
   192  		b.varErr.Error(),
   193  	)
   194  }
   195  
   196  type BindErrors []BindError
   197  
   198  func (b BindErrors) Error() string {
   199  	var errs []string
   200  	for _, err := range b {
   201  		errs = append(errs, err.Error())
   202  	}
   203  	return strings.Join(errs, "\n\n")
   204  }
   205  
   206  func bindObject(t types.Type, object *Object, structTag string) BindErrors {
   207  	var errs BindErrors
   208  	for i := range object.Fields {
   209  		field := &object.Fields[i]
   210  
   211  		if field.ForceResolver {
   212  			continue
   213  		}
   214  
   215  		// first try binding to a method
   216  		methodErr := bindMethod(t, field)
   217  		if methodErr == nil {
   218  			continue
   219  		}
   220  
   221  		// otherwise try binding to a var
   222  		varErr := bindVar(t, field, structTag)
   223  
   224  		if varErr != nil {
   225  			errs = append(errs, BindError{
   226  				object:    object,
   227  				typ:       t,
   228  				field:     field,
   229  				varErr:    varErr,
   230  				methodErr: methodErr,
   231  			})
   232  		}
   233  	}
   234  	return errs
   235  }
   236  
   237  func bindMethod(t types.Type, field *Field) error {
   238  	namedType, ok := t.(*types.Named)
   239  	if !ok {
   240  		return fmt.Errorf("not a named type")
   241  	}
   242  
   243  	goName := field.GQLName
   244  	if field.GoFieldName != "" {
   245  		goName = field.GoFieldName
   246  	}
   247  	method := findMethod(namedType, goName)
   248  	if method == nil {
   249  		return fmt.Errorf("no method named %s", field.GQLName)
   250  	}
   251  	sig := method.Type().(*types.Signature)
   252  
   253  	if sig.Results().Len() == 1 {
   254  		field.NoErr = true
   255  	} else if sig.Results().Len() != 2 {
   256  		return fmt.Errorf("method has wrong number of args")
   257  	}
   258  	params := sig.Params()
   259  	// If the first argument is the context, remove it from the comparison and set
   260  	// the MethodHasContext flag so that the context will be passed to this model's method
   261  	if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
   262  		field.MethodHasContext = true
   263  		vars := make([]*types.Var, params.Len()-1)
   264  		for i := 1; i < params.Len(); i++ {
   265  			vars[i-1] = params.At(i)
   266  		}
   267  		params = types.NewTuple(vars...)
   268  	}
   269  
   270  	newArgs, err := matchArgs(field, params)
   271  	if err != nil {
   272  		return err
   273  	}
   274  
   275  	result := sig.Results().At(0)
   276  	if err := validateTypeBinding(field, result.Type()); err != nil {
   277  		return errors.Wrap(err, "method has wrong return type")
   278  	}
   279  
   280  	// success, args and return type match. Bind to method
   281  	field.GoFieldType = GoFieldMethod
   282  	field.GoReceiverName = "obj"
   283  	field.GoFieldName = method.Name()
   284  	field.Args = newArgs
   285  	return nil
   286  }
   287  
   288  func bindVar(t types.Type, field *Field, structTag string) error {
   289  	underlying, ok := t.Underlying().(*types.Struct)
   290  	if !ok {
   291  		return fmt.Errorf("not a struct")
   292  	}
   293  
   294  	goName := field.GQLName
   295  	if field.GoFieldName != "" {
   296  		goName = field.GoFieldName
   297  	}
   298  	structField, err := findField(underlying, goName, structTag)
   299  	if err != nil {
   300  		return err
   301  	}
   302  
   303  	if err := validateTypeBinding(field, structField.Type()); err != nil {
   304  		return errors.Wrap(err, "field has wrong type")
   305  	}
   306  
   307  	// success, bind to var
   308  	field.GoFieldType = GoFieldVariable
   309  	field.GoReceiverName = "obj"
   310  	field.GoFieldName = structField.Name()
   311  	return nil
   312  }
   313  
   314  func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
   315  	var newArgs []FieldArgument
   316  
   317  nextArg:
   318  	for j := 0; j < params.Len(); j++ {
   319  		param := params.At(j)
   320  		for _, oldArg := range field.Args {
   321  			if strings.EqualFold(oldArg.GQLName, param.Name()) {
   322  				if !field.ForceResolver {
   323  					oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
   324  				}
   325  				newArgs = append(newArgs, oldArg)
   326  				continue nextArg
   327  			}
   328  		}
   329  
   330  		// no matching arg found, abort
   331  		return nil, fmt.Errorf("arg %s not found on method", param.Name())
   332  	}
   333  	return newArgs, nil
   334  }
   335  
   336  func validateTypeBinding(field *Field, goType types.Type) error {
   337  	gqlType := normalizeVendor(field.Type.FullSignature())
   338  	goTypeStr := normalizeVendor(goType.String())
   339  
   340  	if equalTypes(goTypeStr, gqlType) {
   341  		field.Type.Modifiers = modifiersFromGoType(goType)
   342  		return nil
   343  	}
   344  
   345  	// deal with type aliases
   346  	underlyingStr := normalizeVendor(goType.Underlying().String())
   347  	if equalTypes(underlyingStr, gqlType) {
   348  		field.Type.Modifiers = modifiersFromGoType(goType)
   349  		pkg, typ := pkgAndType(goType.String())
   350  		field.AliasedType = &Ref{GoType: typ, Package: pkg}
   351  		return nil
   352  	}
   353  
   354  	return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
   355  }
   356  
   357  func modifiersFromGoType(t types.Type) []string {
   358  	var modifiers []string
   359  	for {
   360  		switch val := t.(type) {
   361  		case *types.Pointer:
   362  			modifiers = append(modifiers, modPtr)
   363  			t = val.Elem()
   364  		case *types.Array:
   365  			modifiers = append(modifiers, modList)
   366  			t = val.Elem()
   367  		case *types.Slice:
   368  			modifiers = append(modifiers, modList)
   369  			t = val.Elem()
   370  		default:
   371  			return modifiers
   372  		}
   373  	}
   374  }
   375  
   376  var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)
   377  
   378  func normalizeVendor(pkg string) string {
   379  	modifiers := modsRegex.FindAllString(pkg, 1)[0]
   380  	pkg = strings.TrimPrefix(pkg, modifiers)
   381  	parts := strings.Split(pkg, "/vendor/")
   382  	return modifiers + parts[len(parts)-1]
   383  }
   384  
   385  func equalTypes(goType string, gqlType string) bool {
   386  	return goType == gqlType || "*"+goType == gqlType || goType == "*"+gqlType || strings.Replace(goType, "[]*", "[]", -1) == gqlType
   387  }