github.com/codyleyhan/gqlgen@v0.4.4/codegen/util.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/pkg/errors"
    10  	"golang.org/x/tools/go/loader"
    11  )
    12  
    13  func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
    14  	if pkgName == "" {
    15  		return nil, nil
    16  	}
    17  	fullName := typeName
    18  	if pkgName != "" {
    19  		fullName = pkgName + "." + typeName
    20  	}
    21  
    22  	pkgName, err := resolvePkg(pkgName)
    23  	if err != nil {
    24  		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
    25  	}
    26  
    27  	pkg := prog.Imported[pkgName]
    28  	if pkg == nil {
    29  		return nil, errors.Errorf("required package was not loaded: %s", fullName)
    30  	}
    31  
    32  	for astNode, def := range pkg.Defs {
    33  		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
    34  			continue
    35  		}
    36  
    37  		return def, nil
    38  	}
    39  
    40  	return nil, errors.Errorf("unable to find type %s\n", fullName)
    41  }
    42  
    43  func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
    44  	def, err := findGoType(prog, pkgName, typeName)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	if def == nil {
    49  		return nil, nil
    50  	}
    51  
    52  	namedType, ok := def.Type().(*types.Named)
    53  	if !ok {
    54  		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
    55  	}
    56  
    57  	return namedType, nil
    58  }
    59  
    60  func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
    61  	namedType, err := findGoNamedType(prog, pkgName, typeName)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	if namedType == nil {
    66  		return nil, nil
    67  	}
    68  
    69  	underlying, ok := namedType.Underlying().(*types.Interface)
    70  	if !ok {
    71  		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
    72  	}
    73  
    74  	return underlying, nil
    75  }
    76  
    // findMethod returns the exported method of typ, or of a type embedded in
// it, whose name matches name case-insensitively. It returns nil if there is
// none.
//
// Candidates are searched breadth-first by embedding depth. A method declared
// on typ wins over any promoted method, and a method promoted from a
// shallower embedded type wins over one from a deeper type. This matches Go's
// rule that a shallower selector shadows a deeper one; a depth-first search
// could pick a method that Go would not select for obj.Method().
// Only embedded fields whose type is itself a named type are descended into;
// embedded pointers are not followed. Embedding by value cannot be cyclic, so
// the search terminates.
func findMethod(typ *types.Named, name string) *types.Func {
	level := []*types.Named{typ}
	for len(level) > 0 {
		var next []*types.Named
		for _, named := range level {
			for i := 0; i < named.NumMethods(); i++ {
				method := named.Method(i)
				if method.Exported() && strings.EqualFold(method.Name(), name) {
					return method
				}
			}

			st, ok := named.Underlying().(*types.Struct)
			if !ok {
				continue
			}
			// Queue embedded named types; their methods are only considered
			// after every type at the current depth has been checked.
			for i := 0; i < st.NumFields(); i++ {
				field := st.Field(i)
				if !field.Anonymous() {
					continue
				}
				if embedded, ok := field.Type().(*types.Named); ok {
					next = append(next, embedded)
				}
			}
		}
		level = next
	}
	return nil
}
   106  
   // findField returns the exported field of typ whose name matches name
// case-insensitively, including fields promoted through embedded structs. It
// returns nil if nothing matches.
//
// The search is breadth-first: every field at one embedding depth is checked
// before any field one level deeper, and within a depth the first match in
// declaration order wins. This mirrors Go's rule that a shallower field
// shadows a deeper one. For example, in struct{ Base; Name string } the outer
// Name is returned, not Base.Name. Embedded fields are descended into when
// their underlying type is a struct; embedded pointers are not followed.
// Embedding by value cannot be cyclic, so the search terminates.
func findField(typ *types.Struct, name string) *types.Var {
	level := []*types.Struct{typ}
	for len(level) > 0 {
		var next []*types.Struct
		for _, st := range level {
			for i := 0; i < st.NumFields(); i++ {
				field := st.Field(i)
				if field.Exported() && strings.EqualFold(field.Name(), name) {
					return field
				}
				// Queue embedded structs whether or not the embedded field is
				// exported: exported fields of an unexported embedded type are
				// still promoted.
				if field.Anonymous() {
					if embedded, ok := field.Type().Underlying().(*types.Struct); ok {
						next = append(next, embedded)
					}
				}
			}
		}
		level = next
	}
	return nil
}
   134  
   135  type BindError struct {
   136  	object    *Object
   137  	field     *Field
   138  	typ       types.Type
   139  	methodErr error
   140  	varErr    error
   141  }
   142  
   143  func (b BindError) Error() string {
   144  	return fmt.Sprintf(
   145  		"Unable to bind %s.%s to %s\n  %s\n  %s",
   146  		b.object.GQLType,
   147  		b.field.GQLName,
   148  		b.typ.String(),
   149  		b.methodErr.Error(),
   150  		b.varErr.Error(),
   151  	)
   152  }
   153  
   154  type BindErrors []BindError
   155  
   156  func (b BindErrors) Error() string {
   157  	var errs []string
   158  	for _, err := range b {
   159  		errs = append(errs, err.Error())
   160  	}
   161  	return strings.Join(errs, "\n\n")
   162  }
   163  
   164  func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
   165  	var errs BindErrors
   166  	for i := range object.Fields {
   167  		field := &object.Fields[i]
   168  
   169  		if field.ForceResolver {
   170  			continue
   171  		}
   172  
   173  		// first try binding to a method
   174  		methodErr := bindMethod(imports, t, field)
   175  		if methodErr == nil {
   176  			continue
   177  		}
   178  
   179  		// otherwise try binding to a var
   180  		varErr := bindVar(imports, t, field)
   181  
   182  		if varErr != nil {
   183  			errs = append(errs, BindError{
   184  				object:    object,
   185  				typ:       t,
   186  				field:     field,
   187  				varErr:    varErr,
   188  				methodErr: methodErr,
   189  			})
   190  		}
   191  	}
   192  	return errs
   193  }
   194  
   195  func bindMethod(imports *Imports, t types.Type, field *Field) error {
   196  	namedType, ok := t.(*types.Named)
   197  	if !ok {
   198  		return fmt.Errorf("not a named type")
   199  	}
   200  
   201  	goName := field.GQLName
   202  	if field.GoFieldName != "" {
   203  		goName = field.GoFieldName
   204  	}
   205  	method := findMethod(namedType, goName)
   206  	if method == nil {
   207  		return fmt.Errorf("no method named %s", field.GQLName)
   208  	}
   209  	sig := method.Type().(*types.Signature)
   210  
   211  	if sig.Results().Len() == 1 {
   212  		field.NoErr = true
   213  	} else if sig.Results().Len() != 2 {
   214  		return fmt.Errorf("method has wrong number of args")
   215  	}
   216  	newArgs, err := matchArgs(field, sig.Params())
   217  	if err != nil {
   218  		return err
   219  	}
   220  
   221  	result := sig.Results().At(0)
   222  	if err := validateTypeBinding(imports, field, result.Type()); err != nil {
   223  		return errors.Wrap(err, "method has wrong return type")
   224  	}
   225  
   226  	// success, args and return type match. Bind to method
   227  	field.GoFieldType = GoFieldMethod
   228  	field.GoReceiverName = "obj"
   229  	field.GoFieldName = method.Name()
   230  	field.Args = newArgs
   231  	return nil
   232  }
   233  
   234  func bindVar(imports *Imports, t types.Type, field *Field) error {
   235  	underlying, ok := t.Underlying().(*types.Struct)
   236  	if !ok {
   237  		return fmt.Errorf("not a struct")
   238  	}
   239  
   240  	goName := field.GQLName
   241  	if field.GoFieldName != "" {
   242  		goName = field.GoFieldName
   243  	}
   244  	structField := findField(underlying, goName)
   245  	if structField == nil {
   246  		return fmt.Errorf("no field named %s", field.GQLName)
   247  	}
   248  
   249  	if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
   250  		return errors.Wrap(err, "field has wrong type")
   251  	}
   252  
   253  	// success, bind to var
   254  	field.GoFieldType = GoFieldVariable
   255  	field.GoReceiverName = "obj"
   256  	field.GoFieldName = structField.Name()
   257  	return nil
   258  }
   259  
   260  func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
   261  	var newArgs []FieldArgument
   262  
   263  nextArg:
   264  	for j := 0; j < params.Len(); j++ {
   265  		param := params.At(j)
   266  		for _, oldArg := range field.Args {
   267  			if strings.EqualFold(oldArg.GQLName, param.Name()) {
   268  				if !field.ForceResolver {
   269  					oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
   270  				}
   271  				newArgs = append(newArgs, oldArg)
   272  				continue nextArg
   273  			}
   274  		}
   275  
   276  		// no matching arg found, abort
   277  		return nil, fmt.Errorf("arg %s not found on method", param.Name())
   278  	}
   279  	return newArgs, nil
   280  }
   281  
   282  func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
   283  	gqlType := normalizeVendor(field.Type.FullSignature())
   284  	goTypeStr := normalizeVendor(goType.String())
   285  
   286  	if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
   287  		field.Type.Modifiers = modifiersFromGoType(goType)
   288  		return nil
   289  	}
   290  
   291  	// deal with type aliases
   292  	underlyingStr := normalizeVendor(goType.Underlying().String())
   293  	if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
   294  		field.Type.Modifiers = modifiersFromGoType(goType)
   295  		pkg, typ := pkgAndType(goType.String())
   296  		imp := imports.findByPath(pkg)
   297  		field.AliasedType = &Ref{GoType: typ, Import: imp}
   298  		return nil
   299  	}
   300  
   301  	return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
   302  }
   303  
   304  func modifiersFromGoType(t types.Type) []string {
   305  	var modifiers []string
   306  	for {
   307  		switch val := t.(type) {
   308  		case *types.Pointer:
   309  			modifiers = append(modifiers, modPtr)
   310  			t = val.Elem()
   311  		case *types.Array:
   312  			modifiers = append(modifiers, modList)
   313  			t = val.Elem()
   314  		case *types.Slice:
   315  			modifiers = append(modifiers, modList)
   316  			t = val.Elem()
   317  		default:
   318  			return modifiers
   319  		}
   320  	}
   321  }
   322  
   // modsRegex matches the run of pointer ("*") and slice ("[]") modifiers at
// the start of a type string.
var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)

// vendorPathRegex matches an import path up to and including its last
// "/vendor/" segment. A match stops at any character that delimits a package
// path inside a go/types type string (whitespace, brackets, "*", parentheses,
// braces, commas, semicolons). It may not start with '.', so a variadic "..."
// before the path is kept.
var vendorPathRegex = regexp.MustCompile(`[^\s\[\]*(),{};.][^\s\[\]*(),{};]*/vendor/`)

// normalizeVendor removes vendor-directory prefixes from every package path in
// the type string pkg. For example, "[]*github.com/me/app/vendor/github.com/x/y.T"
// becomes "[]*github.com/x/y.T".
//
// All other type syntax is preserved: "map[string]a/vendor/b.T" becomes
// "map[string]b.T". It does not collapse to "b.T", which would compare equal
// to an unrelated non-map type.
func normalizeVendor(pkg string) string {
	return vendorPathRegex.ReplaceAllString(pkg, "")
}