github.com/tobiash/gqlgen@v0.5.1/codegen/util.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"reflect"
     7  	"regexp"
     8  	"strings"
     9  
    10  	"github.com/pkg/errors"
    11  	"golang.org/x/tools/go/loader"
    12  )
    13  
    14  func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
    15  	if pkgName == "" {
    16  		return nil, nil
    17  	}
    18  	fullName := typeName
    19  	if pkgName != "" {
    20  		fullName = pkgName + "." + typeName
    21  	}
    22  
    23  	pkgName, err := resolvePkg(pkgName)
    24  	if err != nil {
    25  		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
    26  	}
    27  
    28  	pkg := prog.Imported[pkgName]
    29  	if pkg == nil {
    30  		return nil, errors.Errorf("required package was not loaded: %s", fullName)
    31  	}
    32  
    33  	for astNode, def := range pkg.Defs {
    34  		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
    35  			continue
    36  		}
    37  
    38  		return def, nil
    39  	}
    40  
    41  	return nil, errors.Errorf("unable to find type %s\n", fullName)
    42  }
    43  
    44  func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
    45  	def, err := findGoType(prog, pkgName, typeName)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	if def == nil {
    50  		return nil, nil
    51  	}
    52  
    53  	namedType, ok := def.Type().(*types.Named)
    54  	if !ok {
    55  		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
    56  	}
    57  
    58  	return namedType, nil
    59  }
    60  
    61  func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
    62  	namedType, err := findGoNamedType(prog, pkgName, typeName)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	if namedType == nil {
    67  		return nil, nil
    68  	}
    69  
    70  	underlying, ok := namedType.Underlying().(*types.Interface)
    71  	if !ok {
    72  		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
    73  	}
    74  
    75  	return underlying, nil
    76  }
    77  
    78  func findMethod(typ *types.Named, name string) *types.Func {
    79  	for i := 0; i < typ.NumMethods(); i++ {
    80  		method := typ.Method(i)
    81  		if !method.Exported() {
    82  			continue
    83  		}
    84  
    85  		if strings.EqualFold(method.Name(), name) {
    86  			return method
    87  		}
    88  	}
    89  
    90  	if s, ok := typ.Underlying().(*types.Struct); ok {
    91  		for i := 0; i < s.NumFields(); i++ {
    92  			field := s.Field(i)
    93  			if !field.Anonymous() {
    94  				continue
    95  			}
    96  
    97  			if named, ok := field.Type().(*types.Named); ok {
    98  				if f := findMethod(named, name); f != nil {
    99  					return f
   100  				}
   101  			}
   102  		}
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  // findField attempts to match the name to a struct field with the following
   109  // priorites:
   110  // 1. If struct tag is passed then struct tag has highest priority
   111  // 2. Field in an embedded struct
   112  // 3. Actual Field name
   113  func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
   114  	var foundField *types.Var
   115  	foundFieldWasTag := false
   116  
   117  	for i := 0; i < typ.NumFields(); i++ {
   118  		field := typ.Field(i)
   119  
   120  		if structTag != "" {
   121  			tags := reflect.StructTag(typ.Tag(i))
   122  			if val, ok := tags.Lookup(structTag); ok {
   123  				if strings.EqualFold(val, name) {
   124  					if foundField != nil && foundFieldWasTag {
   125  						return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
   126  					}
   127  
   128  					foundField = field
   129  					foundFieldWasTag = true
   130  				}
   131  			}
   132  		}
   133  
   134  		if field.Anonymous() {
   135  			if named, ok := field.Type().(*types.Struct); ok {
   136  				f, err := findField(named, name, structTag)
   137  				if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
   138  					return nil, err
   139  				}
   140  				if f != nil && foundField == nil {
   141  					foundField = f
   142  				}
   143  			}
   144  
   145  			if named, ok := field.Type().Underlying().(*types.Struct); ok {
   146  				f, err := findField(named, name, structTag)
   147  				if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
   148  					return nil, err
   149  				}
   150  				if f != nil && foundField == nil {
   151  					foundField = f
   152  				}
   153  			}
   154  		}
   155  
   156  		if !field.Exported() {
   157  			continue
   158  		}
   159  
   160  		if strings.EqualFold(field.Name(), name) && foundField == nil {
   161  			foundField = field
   162  		}
   163  	}
   164  
   165  	if foundField == nil {
   166  		return nil, fmt.Errorf("no field named %s", name)
   167  	}
   168  
   169  	return foundField, nil
   170  }
   171  
   172  type BindError struct {
   173  	object    *Object
   174  	field     *Field
   175  	typ       types.Type
   176  	methodErr error
   177  	varErr    error
   178  }
   179  
   180  func (b BindError) Error() string {
   181  	return fmt.Sprintf(
   182  		"Unable to bind %s.%s to %s\n  %s\n  %s",
   183  		b.object.GQLType,
   184  		b.field.GQLName,
   185  		b.typ.String(),
   186  		b.methodErr.Error(),
   187  		b.varErr.Error(),
   188  	)
   189  }
   190  
   191  type BindErrors []BindError
   192  
   193  func (b BindErrors) Error() string {
   194  	var errs []string
   195  	for _, err := range b {
   196  		errs = append(errs, err.Error())
   197  	}
   198  	return strings.Join(errs, "\n\n")
   199  }
   200  
   201  func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
   202  	var errs BindErrors
   203  	for i := range object.Fields {
   204  		field := &object.Fields[i]
   205  
   206  		if field.ForceResolver {
   207  			continue
   208  		}
   209  
   210  		// first try binding to a method
   211  		methodErr := bindMethod(imports, t, field)
   212  		if methodErr == nil {
   213  			continue
   214  		}
   215  
   216  		// otherwise try binding to a var
   217  		varErr := bindVar(imports, t, field, structTag)
   218  
   219  		if varErr != nil {
   220  			errs = append(errs, BindError{
   221  				object:    object,
   222  				typ:       t,
   223  				field:     field,
   224  				varErr:    varErr,
   225  				methodErr: methodErr,
   226  			})
   227  		}
   228  	}
   229  	return errs
   230  }
   231  
   232  func bindMethod(imports *Imports, t types.Type, field *Field) error {
   233  	namedType, ok := t.(*types.Named)
   234  	if !ok {
   235  		return fmt.Errorf("not a named type")
   236  	}
   237  
   238  	goName := field.GQLName
   239  	if field.GoFieldName != "" {
   240  		goName = field.GoFieldName
   241  	}
   242  	method := findMethod(namedType, goName)
   243  	if method == nil {
   244  		return fmt.Errorf("no method named %s", field.GQLName)
   245  	}
   246  	sig := method.Type().(*types.Signature)
   247  
   248  	if sig.Results().Len() == 1 {
   249  		field.NoErr = true
   250  	} else if sig.Results().Len() != 2 {
   251  		return fmt.Errorf("method has wrong number of args")
   252  	}
   253  	newArgs, err := matchArgs(field, sig.Params())
   254  	if err != nil {
   255  		return err
   256  	}
   257  
   258  	result := sig.Results().At(0)
   259  	if err := validateTypeBinding(imports, field, result.Type()); err != nil {
   260  		return errors.Wrap(err, "method has wrong return type")
   261  	}
   262  
   263  	// success, args and return type match. Bind to method
   264  	field.GoFieldType = GoFieldMethod
   265  	field.GoReceiverName = "obj"
   266  	field.GoFieldName = method.Name()
   267  	field.Args = newArgs
   268  	return nil
   269  }
   270  
   271  func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
   272  	underlying, ok := t.Underlying().(*types.Struct)
   273  	if !ok {
   274  		return fmt.Errorf("not a struct")
   275  	}
   276  
   277  	goName := field.GQLName
   278  	if field.GoFieldName != "" {
   279  		goName = field.GoFieldName
   280  	}
   281  	structField, err := findField(underlying, goName, structTag)
   282  	if err != nil {
   283  		return err
   284  	}
   285  
   286  	if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
   287  		return errors.Wrap(err, "field has wrong type")
   288  	}
   289  
   290  	// success, bind to var
   291  	field.GoFieldType = GoFieldVariable
   292  	field.GoReceiverName = "obj"
   293  	field.GoFieldName = structField.Name()
   294  	return nil
   295  }
   296  
   297  func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
   298  	var newArgs []FieldArgument
   299  
   300  nextArg:
   301  	for j := 0; j < params.Len(); j++ {
   302  		param := params.At(j)
   303  		for _, oldArg := range field.Args {
   304  			if strings.EqualFold(oldArg.GQLName, param.Name()) {
   305  				if !field.ForceResolver {
   306  					oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
   307  				}
   308  				newArgs = append(newArgs, oldArg)
   309  				continue nextArg
   310  			}
   311  		}
   312  
   313  		// no matching arg found, abort
   314  		return nil, fmt.Errorf("arg %s not found on method", param.Name())
   315  	}
   316  	return newArgs, nil
   317  }
   318  
   319  func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
   320  	gqlType := normalizeVendor(field.Type.FullSignature())
   321  	goTypeStr := normalizeVendor(goType.String())
   322  
   323  	if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
   324  		field.Type.Modifiers = modifiersFromGoType(goType)
   325  		return nil
   326  	}
   327  
   328  	// deal with type aliases
   329  	underlyingStr := normalizeVendor(goType.Underlying().String())
   330  	if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
   331  		field.Type.Modifiers = modifiersFromGoType(goType)
   332  		pkg, typ := pkgAndType(goType.String())
   333  		imp := imports.findByPath(pkg)
   334  		field.AliasedType = &Ref{GoType: typ, Import: imp}
   335  		return nil
   336  	}
   337  
   338  	return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
   339  }
   340  
   341  func modifiersFromGoType(t types.Type) []string {
   342  	var modifiers []string
   343  	for {
   344  		switch val := t.(type) {
   345  		case *types.Pointer:
   346  			modifiers = append(modifiers, modPtr)
   347  			t = val.Elem()
   348  		case *types.Array:
   349  			modifiers = append(modifiers, modList)
   350  			t = val.Elem()
   351  		case *types.Slice:
   352  			modifiers = append(modifiers, modList)
   353  			t = val.Elem()
   354  		default:
   355  			return modifiers
   356  		}
   357  	}
   358  }
   359  
   360  var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)
   361  
   362  func normalizeVendor(pkg string) string {
   363  	modifiers := modsRegex.FindAllString(pkg, 1)[0]
   364  	pkg = strings.TrimPrefix(pkg, modifiers)
   365  	parts := strings.Split(pkg, "/vendor/")
   366  	return modifiers + parts[len(parts)-1]
   367  }