github.com/matiasanaya/gqlgen@v0.6.0/codegen/util.go (about) 1 package codegen 2 3 import ( 4 "fmt" 5 "go/types" 6 "reflect" 7 "regexp" 8 "strings" 9 10 "github.com/pkg/errors" 11 "golang.org/x/tools/go/loader" 12 ) 13 14 func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) { 15 if pkgName == "" { 16 return nil, nil 17 } 18 fullName := typeName 19 if pkgName != "" { 20 fullName = pkgName + "." + typeName 21 } 22 23 pkgName, err := resolvePkg(pkgName) 24 if err != nil { 25 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error()) 26 } 27 28 pkg := prog.Imported[pkgName] 29 if pkg == nil { 30 return nil, errors.Errorf("required package was not loaded: %s", fullName) 31 } 32 33 for astNode, def := range pkg.Defs { 34 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() { 35 continue 36 } 37 38 return def, nil 39 } 40 41 return nil, errors.Errorf("unable to find type %s\n", fullName) 42 } 43 44 func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) { 45 def, err := findGoType(prog, pkgName, typeName) 46 if err != nil { 47 return nil, err 48 } 49 if def == nil { 50 return nil, nil 51 } 52 53 namedType, ok := def.Type().(*types.Named) 54 if !ok { 55 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type()) 56 } 57 58 return namedType, nil 59 } 60 61 func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) { 62 namedType, err := findGoNamedType(prog, pkgName, typeName) 63 if err != nil { 64 return nil, err 65 } 66 if namedType == nil { 67 return nil, nil 68 } 69 70 underlying, ok := namedType.Underlying().(*types.Interface) 71 if !ok { 72 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String()) 73 } 74 75 return underlying, nil 76 } 77 78 func findMethod(typ *types.Named, name string) *types.Func { 79 for i := 0; i < typ.NumMethods(); i++ { 80 method := typ.Method(i) 81 if !method.Exported() { 82 continue 83 } 84 85 if strings.EqualFold(method.Name(), name) { 86 return method 87 } 88 } 89 90 if s, ok := typ.Underlying().(*types.Struct); ok { 91 for i := 0; i < s.NumFields(); i++ { 92 field := s.Field(i) 93 if !field.Anonymous() { 94 continue 95 } 96 97 if named, ok := field.Type().(*types.Named); ok { 98 if f := findMethod(named, name); f != nil { 99 return f 100 } 101 } 102 } 103 } 104 105 return nil 106 } 107 108 // findField attempts to match the name to a struct field with the following 109 // priorites: 110 // 1. If struct tag is passed then struct tag has highest priority 111 // 2. Field in an embedded struct 112 // 3. Actual Field name 113 func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { 114 var foundField *types.Var 115 foundFieldWasTag := false 116 117 for i := 0; i < typ.NumFields(); i++ { 118 field := typ.Field(i) 119 120 if structTag != "" { 121 tags := reflect.StructTag(typ.Tag(i)) 122 if val, ok := tags.Lookup(structTag); ok { 123 if strings.EqualFold(val, name) { 124 if foundField != nil && foundFieldWasTag { 125 return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val) 126 } 127 128 foundField = field 129 foundFieldWasTag = true 130 } 131 } 132 } 133 134 if field.Anonymous() { 135 136 fieldType := field.Type() 137 138 if ptr, ok := fieldType.(*types.Pointer); ok { 139 fieldType = ptr.Elem() 140 } 141 142 // Type.Underlying() returns itself for all types except types.Named, where it returns a struct type. 143 // It should be safe to always call. 144 if named, ok := fieldType.Underlying().(*types.Struct); ok { 145 f, err := findField(named, name, structTag) 146 if err != nil && !strings.HasPrefix(err.Error(), "no field named") { 147 return nil, err 148 } 149 if f != nil && foundField == nil { 150 foundField = f 151 } 152 } 153 } 154 155 if !field.Exported() { 156 continue 157 } 158 159 if strings.EqualFold(field.Name(), name) && foundField == nil { 160 foundField = field 161 } 162 } 163 164 if foundField == nil { 165 return nil, fmt.Errorf("no field named %s", name) 166 } 167 168 return foundField, nil 169 } 170 171 type BindError struct { 172 object *Object 173 field *Field 174 typ types.Type 175 methodErr error 176 varErr error 177 } 178 179 func (b BindError) Error() string { 180 return fmt.Sprintf( 181 "Unable to bind %s.%s to %s\n %s\n %s", 182 b.object.GQLType, 183 b.field.GQLName, 184 b.typ.String(), 185 b.methodErr.Error(), 186 b.varErr.Error(), 187 ) 188 } 189 190 type BindErrors []BindError 191 192 func (b BindErrors) Error() string { 193 var errs []string 194 for _, err := range b { 195 errs = append(errs, err.Error()) 196 } 197 return strings.Join(errs, "\n\n") 198 } 199 200 func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors { 201 var errs BindErrors 202 for i := range object.Fields { 203 field := &object.Fields[i] 204 205 if field.ForceResolver { 206 continue 207 } 208 209 // first try binding to a method 210 methodErr := bindMethod(imports, t, field) 211 if methodErr == nil { 212 continue 213 } 214 215 // otherwise try binding to a var 216 varErr := bindVar(imports, t, field, structTag) 217 218 if varErr != nil { 219 errs = append(errs, BindError{ 220 object: object, 221 typ: t, 222 field: field, 223 varErr: varErr, 224 methodErr: methodErr, 225 }) 226 } 227 } 228 return errs 229 } 230 231 func bindMethod(imports *Imports, t types.Type, field *Field) error { 232 namedType, ok := t.(*types.Named) 233 if !ok { 234 return fmt.Errorf("not a named type") 235 } 236 237 goName := field.GQLName 238 if field.GoFieldName != "" { 239 goName = field.GoFieldName 240 } 241 method := findMethod(namedType, goName) 242 if method == nil { 243 return fmt.Errorf("no method named %s", field.GQLName) 244 } 245 sig := method.Type().(*types.Signature) 246 247 if sig.Results().Len() == 1 { 248 field.NoErr = true 249 } else if sig.Results().Len() != 2 { 250 return fmt.Errorf("method has wrong number of args") 251 } 252 newArgs, err := matchArgs(field, sig.Params()) 253 if err != nil { 254 return err 255 } 256 257 result := sig.Results().At(0) 258 if err := validateTypeBinding(imports, field, result.Type()); err != nil { 259 return errors.Wrap(err, "method has wrong return type") 260 } 261 262 // success, args and return type match. Bind to method 263 field.GoFieldType = GoFieldMethod 264 field.GoReceiverName = "obj" 265 field.GoFieldName = method.Name() 266 field.Args = newArgs 267 return nil 268 } 269 270 func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error { 271 underlying, ok := t.Underlying().(*types.Struct) 272 if !ok { 273 return fmt.Errorf("not a struct") 274 } 275 276 goName := field.GQLName 277 if field.GoFieldName != "" { 278 goName = field.GoFieldName 279 } 280 structField, err := findField(underlying, goName, structTag) 281 if err != nil { 282 return err 283 } 284 285 if err := validateTypeBinding(imports, field, structField.Type()); err != nil { 286 return errors.Wrap(err, "field has wrong type") 287 } 288 289 // success, bind to var 290 field.GoFieldType = GoFieldVariable 291 field.GoReceiverName = "obj" 292 field.GoFieldName = structField.Name() 293 return nil 294 } 295 296 func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) { 297 var newArgs []FieldArgument 298 299 nextArg: 300 for j := 0; j < params.Len(); j++ { 301 param := params.At(j) 302 for _, oldArg := range field.Args { 303 if strings.EqualFold(oldArg.GQLName, param.Name()) { 304 if !field.ForceResolver { 305 oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) 306 } 307 newArgs = append(newArgs, oldArg) 308 continue nextArg 309 } 310 } 311 312 // no matching arg found, abort 313 return nil, fmt.Errorf("arg %s not found on method", param.Name()) 314 } 315 return newArgs, nil 316 } 317 318 func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error { 319 gqlType := normalizeVendor(field.Type.FullSignature()) 320 goTypeStr := normalizeVendor(goType.String()) 321 322 if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType { 323 field.Type.Modifiers = modifiersFromGoType(goType) 324 return nil 325 } 326 327 // deal with type aliases 328 underlyingStr := normalizeVendor(goType.Underlying().String()) 329 if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType { 330 field.Type.Modifiers = modifiersFromGoType(goType) 331 pkg, typ := pkgAndType(goType.String()) 332 imp := imports.findByPath(pkg) 333 field.AliasedType = &Ref{GoType: typ, Import: imp} 334 return nil 335 } 336 337 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr) 338 } 339 340 func modifiersFromGoType(t types.Type) []string { 341 var modifiers []string 342 for { 343 switch val := t.(type) { 344 case *types.Pointer: 345 modifiers = append(modifiers, modPtr) 346 t = val.Elem() 347 case *types.Array: 348 modifiers = append(modifiers, modList) 349 t = val.Elem() 350 case *types.Slice: 351 modifiers = append(modifiers, modList) 352 t = val.Elem() 353 default: 354 return modifiers 355 } 356 } 357 } 358 359 var modsRegex = regexp.MustCompile(`^(\*|\[\])*`) 360 361 func normalizeVendor(pkg string) string { 362 modifiers := modsRegex.FindAllString(pkg, 1)[0] 363 pkg = strings.TrimPrefix(pkg, modifiers) 364 parts := strings.Split(pkg, "/vendor/") 365 return modifiers + parts[len(parts)-1] 366 }