github.com/matiasanaya/gqlgen@v0.6.0/codegen/util.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"reflect"
     7  	"regexp"
     8  	"strings"
     9  
    10  	"github.com/pkg/errors"
    11  	"golang.org/x/tools/go/loader"
    12  )
    13  
    14  func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
    15  	if pkgName == "" {
    16  		return nil, nil
    17  	}
    18  	fullName := typeName
    19  	if pkgName != "" {
    20  		fullName = pkgName + "." + typeName
    21  	}
    22  
    23  	pkgName, err := resolvePkg(pkgName)
    24  	if err != nil {
    25  		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
    26  	}
    27  
    28  	pkg := prog.Imported[pkgName]
    29  	if pkg == nil {
    30  		return nil, errors.Errorf("required package was not loaded: %s", fullName)
    31  	}
    32  
    33  	for astNode, def := range pkg.Defs {
    34  		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
    35  			continue
    36  		}
    37  
    38  		return def, nil
    39  	}
    40  
    41  	return nil, errors.Errorf("unable to find type %s\n", fullName)
    42  }
    43  
    44  func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
    45  	def, err := findGoType(prog, pkgName, typeName)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	if def == nil {
    50  		return nil, nil
    51  	}
    52  
    53  	namedType, ok := def.Type().(*types.Named)
    54  	if !ok {
    55  		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
    56  	}
    57  
    58  	return namedType, nil
    59  }
    60  
    61  func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
    62  	namedType, err := findGoNamedType(prog, pkgName, typeName)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	if namedType == nil {
    67  		return nil, nil
    68  	}
    69  
    70  	underlying, ok := namedType.Underlying().(*types.Interface)
    71  	if !ok {
    72  		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
    73  	}
    74  
    75  	return underlying, nil
    76  }
    77  
    78  func findMethod(typ *types.Named, name string) *types.Func {
    79  	for i := 0; i < typ.NumMethods(); i++ {
    80  		method := typ.Method(i)
    81  		if !method.Exported() {
    82  			continue
    83  		}
    84  
    85  		if strings.EqualFold(method.Name(), name) {
    86  			return method
    87  		}
    88  	}
    89  
    90  	if s, ok := typ.Underlying().(*types.Struct); ok {
    91  		for i := 0; i < s.NumFields(); i++ {
    92  			field := s.Field(i)
    93  			if !field.Anonymous() {
    94  				continue
    95  			}
    96  
    97  			if named, ok := field.Type().(*types.Named); ok {
    98  				if f := findMethod(named, name); f != nil {
    99  					return f
   100  				}
   101  			}
   102  		}
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  // findField attempts to match the name to a struct field with the following
   109  // priorites:
   110  // 1. If struct tag is passed then struct tag has highest priority
   111  // 2. Field in an embedded struct
   112  // 3. Actual Field name
   113  func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
   114  	var foundField *types.Var
   115  	foundFieldWasTag := false
   116  
   117  	for i := 0; i < typ.NumFields(); i++ {
   118  		field := typ.Field(i)
   119  
   120  		if structTag != "" {
   121  			tags := reflect.StructTag(typ.Tag(i))
   122  			if val, ok := tags.Lookup(structTag); ok {
   123  				if strings.EqualFold(val, name) {
   124  					if foundField != nil && foundFieldWasTag {
   125  						return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
   126  					}
   127  
   128  					foundField = field
   129  					foundFieldWasTag = true
   130  				}
   131  			}
   132  		}
   133  
   134  		if field.Anonymous() {
   135  
   136  			fieldType := field.Type()
   137  
   138  			if ptr, ok := fieldType.(*types.Pointer); ok {
   139  				fieldType = ptr.Elem()
   140  			}
   141  
   142  			// Type.Underlying() returns itself for all types except types.Named, where it returns a struct type.
   143  			// It should be safe to always call.
   144  			if named, ok := fieldType.Underlying().(*types.Struct); ok {
   145  				f, err := findField(named, name, structTag)
   146  				if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
   147  					return nil, err
   148  				}
   149  				if f != nil && foundField == nil {
   150  					foundField = f
   151  				}
   152  			}
   153  		}
   154  
   155  		if !field.Exported() {
   156  			continue
   157  		}
   158  
   159  		if strings.EqualFold(field.Name(), name) && foundField == nil {
   160  			foundField = field
   161  		}
   162  	}
   163  
   164  	if foundField == nil {
   165  		return nil, fmt.Errorf("no field named %s", name)
   166  	}
   167  
   168  	return foundField, nil
   169  }
   170  
   171  type BindError struct {
   172  	object    *Object
   173  	field     *Field
   174  	typ       types.Type
   175  	methodErr error
   176  	varErr    error
   177  }
   178  
   179  func (b BindError) Error() string {
   180  	return fmt.Sprintf(
   181  		"Unable to bind %s.%s to %s\n  %s\n  %s",
   182  		b.object.GQLType,
   183  		b.field.GQLName,
   184  		b.typ.String(),
   185  		b.methodErr.Error(),
   186  		b.varErr.Error(),
   187  	)
   188  }
   189  
   190  type BindErrors []BindError
   191  
   192  func (b BindErrors) Error() string {
   193  	var errs []string
   194  	for _, err := range b {
   195  		errs = append(errs, err.Error())
   196  	}
   197  	return strings.Join(errs, "\n\n")
   198  }
   199  
   200  func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
   201  	var errs BindErrors
   202  	for i := range object.Fields {
   203  		field := &object.Fields[i]
   204  
   205  		if field.ForceResolver {
   206  			continue
   207  		}
   208  
   209  		// first try binding to a method
   210  		methodErr := bindMethod(imports, t, field)
   211  		if methodErr == nil {
   212  			continue
   213  		}
   214  
   215  		// otherwise try binding to a var
   216  		varErr := bindVar(imports, t, field, structTag)
   217  
   218  		if varErr != nil {
   219  			errs = append(errs, BindError{
   220  				object:    object,
   221  				typ:       t,
   222  				field:     field,
   223  				varErr:    varErr,
   224  				methodErr: methodErr,
   225  			})
   226  		}
   227  	}
   228  	return errs
   229  }
   230  
   231  func bindMethod(imports *Imports, t types.Type, field *Field) error {
   232  	namedType, ok := t.(*types.Named)
   233  	if !ok {
   234  		return fmt.Errorf("not a named type")
   235  	}
   236  
   237  	goName := field.GQLName
   238  	if field.GoFieldName != "" {
   239  		goName = field.GoFieldName
   240  	}
   241  	method := findMethod(namedType, goName)
   242  	if method == nil {
   243  		return fmt.Errorf("no method named %s", field.GQLName)
   244  	}
   245  	sig := method.Type().(*types.Signature)
   246  
   247  	if sig.Results().Len() == 1 {
   248  		field.NoErr = true
   249  	} else if sig.Results().Len() != 2 {
   250  		return fmt.Errorf("method has wrong number of args")
   251  	}
   252  	newArgs, err := matchArgs(field, sig.Params())
   253  	if err != nil {
   254  		return err
   255  	}
   256  
   257  	result := sig.Results().At(0)
   258  	if err := validateTypeBinding(imports, field, result.Type()); err != nil {
   259  		return errors.Wrap(err, "method has wrong return type")
   260  	}
   261  
   262  	// success, args and return type match. Bind to method
   263  	field.GoFieldType = GoFieldMethod
   264  	field.GoReceiverName = "obj"
   265  	field.GoFieldName = method.Name()
   266  	field.Args = newArgs
   267  	return nil
   268  }
   269  
   270  func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
   271  	underlying, ok := t.Underlying().(*types.Struct)
   272  	if !ok {
   273  		return fmt.Errorf("not a struct")
   274  	}
   275  
   276  	goName := field.GQLName
   277  	if field.GoFieldName != "" {
   278  		goName = field.GoFieldName
   279  	}
   280  	structField, err := findField(underlying, goName, structTag)
   281  	if err != nil {
   282  		return err
   283  	}
   284  
   285  	if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
   286  		return errors.Wrap(err, "field has wrong type")
   287  	}
   288  
   289  	// success, bind to var
   290  	field.GoFieldType = GoFieldVariable
   291  	field.GoReceiverName = "obj"
   292  	field.GoFieldName = structField.Name()
   293  	return nil
   294  }
   295  
   296  func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
   297  	var newArgs []FieldArgument
   298  
   299  nextArg:
   300  	for j := 0; j < params.Len(); j++ {
   301  		param := params.At(j)
   302  		for _, oldArg := range field.Args {
   303  			if strings.EqualFold(oldArg.GQLName, param.Name()) {
   304  				if !field.ForceResolver {
   305  					oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
   306  				}
   307  				newArgs = append(newArgs, oldArg)
   308  				continue nextArg
   309  			}
   310  		}
   311  
   312  		// no matching arg found, abort
   313  		return nil, fmt.Errorf("arg %s not found on method", param.Name())
   314  	}
   315  	return newArgs, nil
   316  }
   317  
   318  func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
   319  	gqlType := normalizeVendor(field.Type.FullSignature())
   320  	goTypeStr := normalizeVendor(goType.String())
   321  
   322  	if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
   323  		field.Type.Modifiers = modifiersFromGoType(goType)
   324  		return nil
   325  	}
   326  
   327  	// deal with type aliases
   328  	underlyingStr := normalizeVendor(goType.Underlying().String())
   329  	if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
   330  		field.Type.Modifiers = modifiersFromGoType(goType)
   331  		pkg, typ := pkgAndType(goType.String())
   332  		imp := imports.findByPath(pkg)
   333  		field.AliasedType = &Ref{GoType: typ, Import: imp}
   334  		return nil
   335  	}
   336  
   337  	return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
   338  }
   339  
   340  func modifiersFromGoType(t types.Type) []string {
   341  	var modifiers []string
   342  	for {
   343  		switch val := t.(type) {
   344  		case *types.Pointer:
   345  			modifiers = append(modifiers, modPtr)
   346  			t = val.Elem()
   347  		case *types.Array:
   348  			modifiers = append(modifiers, modList)
   349  			t = val.Elem()
   350  		case *types.Slice:
   351  			modifiers = append(modifiers, modList)
   352  			t = val.Elem()
   353  		default:
   354  			return modifiers
   355  		}
   356  	}
   357  }
   358  
   359  var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)
   360  
   361  func normalizeVendor(pkg string) string {
   362  	modifiers := modsRegex.FindAllString(pkg, 1)[0]
   363  	pkg = strings.TrimPrefix(pkg, modifiers)
   364  	parts := strings.Split(pkg, "/vendor/")
   365  	return modifiers + parts[len(parts)-1]
   366  }