github.com/tobiash/gqlgen@v0.5.1/codegen/util.go (about) 1 package codegen 2 3 import ( 4 "fmt" 5 "go/types" 6 "reflect" 7 "regexp" 8 "strings" 9 10 "github.com/pkg/errors" 11 "golang.org/x/tools/go/loader" 12 ) 13 14 func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) { 15 if pkgName == "" { 16 return nil, nil 17 } 18 fullName := typeName 19 if pkgName != "" { 20 fullName = pkgName + "." + typeName 21 } 22 23 pkgName, err := resolvePkg(pkgName) 24 if err != nil { 25 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error()) 26 } 27 28 pkg := prog.Imported[pkgName] 29 if pkg == nil { 30 return nil, errors.Errorf("required package was not loaded: %s", fullName) 31 } 32 33 for astNode, def := range pkg.Defs { 34 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() { 35 continue 36 } 37 38 return def, nil 39 } 40 41 return nil, errors.Errorf("unable to find type %s\n", fullName) 42 } 43 44 func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) { 45 def, err := findGoType(prog, pkgName, typeName) 46 if err != nil { 47 return nil, err 48 } 49 if def == nil { 50 return nil, nil 51 } 52 53 namedType, ok := def.Type().(*types.Named) 54 if !ok { 55 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type()) 56 } 57 58 return namedType, nil 59 } 60 61 func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) { 62 namedType, err := findGoNamedType(prog, pkgName, typeName) 63 if err != nil { 64 return nil, err 65 } 66 if namedType == nil { 67 return nil, nil 68 } 69 70 underlying, ok := namedType.Underlying().(*types.Interface) 71 if !ok { 72 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String()) 73 } 74 75 return underlying, nil 76 } 77 78 func findMethod(typ *types.Named, name string) *types.Func { 79 for i := 0; i < typ.NumMethods(); i++ { 80 method := typ.Method(i) 81 if !method.Exported() { 82 continue 83 } 84 85 if strings.EqualFold(method.Name(), name) { 86 return method 87 } 88 } 89 90 if s, ok := typ.Underlying().(*types.Struct); ok { 91 for i := 0; i < s.NumFields(); i++ { 92 field := s.Field(i) 93 if !field.Anonymous() { 94 continue 95 } 96 97 if named, ok := field.Type().(*types.Named); ok { 98 if f := findMethod(named, name); f != nil { 99 return f 100 } 101 } 102 } 103 } 104 105 return nil 106 } 107 108 // findField attempts to match the name to a struct field with the following 109 // priorites: 110 // 1. If struct tag is passed then struct tag has highest priority 111 // 2. Field in an embedded struct 112 // 3. Actual Field name 113 func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { 114 var foundField *types.Var 115 foundFieldWasTag := false 116 117 for i := 0; i < typ.NumFields(); i++ { 118 field := typ.Field(i) 119 120 if structTag != "" { 121 tags := reflect.StructTag(typ.Tag(i)) 122 if val, ok := tags.Lookup(structTag); ok { 123 if strings.EqualFold(val, name) { 124 if foundField != nil && foundFieldWasTag { 125 return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val) 126 } 127 128 foundField = field 129 foundFieldWasTag = true 130 } 131 } 132 } 133 134 if field.Anonymous() { 135 if named, ok := field.Type().(*types.Struct); ok { 136 f, err := findField(named, name, structTag) 137 if err != nil && !strings.HasPrefix(err.Error(), "no field named") { 138 return nil, err 139 } 140 if f != nil && foundField == nil { 141 foundField = f 142 } 143 } 144 145 if named, ok := field.Type().Underlying().(*types.Struct); ok { 146 f, err := findField(named, name, structTag) 147 if err != nil && !strings.HasPrefix(err.Error(), "no field named") { 148 return nil, err 149 } 150 if f != nil && foundField == nil { 151 foundField = f 152 } 153 } 154 } 155 156 if !field.Exported() { 157 continue 158 } 159 160 if strings.EqualFold(field.Name(), name) && foundField == nil { 161 foundField = field 162 } 163 } 164 165 if foundField == nil { 166 return nil, fmt.Errorf("no field named %s", name) 167 } 168 169 return foundField, nil 170 } 171 172 type BindError struct { 173 object *Object 174 field *Field 175 typ types.Type 176 methodErr error 177 varErr error 178 } 179 180 func (b BindError) Error() string { 181 return fmt.Sprintf( 182 "Unable to bind %s.%s to %s\n %s\n %s", 183 b.object.GQLType, 184 b.field.GQLName, 185 b.typ.String(), 186 b.methodErr.Error(), 187 b.varErr.Error(), 188 ) 189 } 190 191 type BindErrors []BindError 192 193 func (b BindErrors) Error() string { 194 var errs []string 195 for _, err := range b { 196 errs = append(errs, err.Error()) 197 } 198 return strings.Join(errs, "\n\n") 199 } 200 201 func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors { 202 var errs BindErrors 203 for i := range object.Fields { 204 field := &object.Fields[i] 205 206 if field.ForceResolver { 207 continue 208 } 209 210 // first try binding to a method 211 methodErr := bindMethod(imports, t, field) 212 if methodErr == nil { 213 continue 214 } 215 216 // otherwise try binding to a var 217 varErr := bindVar(imports, t, field, structTag) 218 219 if varErr != nil { 220 errs = append(errs, BindError{ 221 object: object, 222 typ: t, 223 field: field, 224 varErr: varErr, 225 methodErr: methodErr, 226 }) 227 } 228 } 229 return errs 230 } 231 232 func bindMethod(imports *Imports, t types.Type, field *Field) error { 233 namedType, ok := t.(*types.Named) 234 if !ok { 235 return fmt.Errorf("not a named type") 236 } 237 238 goName := field.GQLName 239 if field.GoFieldName != "" { 240 goName = field.GoFieldName 241 } 242 method := findMethod(namedType, goName) 243 if method == nil { 244 return fmt.Errorf("no method named %s", field.GQLName) 245 } 246 sig := method.Type().(*types.Signature) 247 248 if sig.Results().Len() == 1 { 249 field.NoErr = true 250 } else if sig.Results().Len() != 2 { 251 return fmt.Errorf("method has wrong number of args") 252 } 253 newArgs, err := matchArgs(field, sig.Params()) 254 if err != nil { 255 return err 256 } 257 258 result := sig.Results().At(0) 259 if err := validateTypeBinding(imports, field, result.Type()); err != nil { 260 return errors.Wrap(err, "method has wrong return type") 261 } 262 263 // success, args and return type match. Bind to method 264 field.GoFieldType = GoFieldMethod 265 field.GoReceiverName = "obj" 266 field.GoFieldName = method.Name() 267 field.Args = newArgs 268 return nil 269 } 270 271 func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error { 272 underlying, ok := t.Underlying().(*types.Struct) 273 if !ok { 274 return fmt.Errorf("not a struct") 275 } 276 277 goName := field.GQLName 278 if field.GoFieldName != "" { 279 goName = field.GoFieldName 280 } 281 structField, err := findField(underlying, goName, structTag) 282 if err != nil { 283 return err 284 } 285 286 if err := validateTypeBinding(imports, field, structField.Type()); err != nil { 287 return errors.Wrap(err, "field has wrong type") 288 } 289 290 // success, bind to var 291 field.GoFieldType = GoFieldVariable 292 field.GoReceiverName = "obj" 293 field.GoFieldName = structField.Name() 294 return nil 295 } 296 297 func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) { 298 var newArgs []FieldArgument 299 300 nextArg: 301 for j := 0; j < params.Len(); j++ { 302 param := params.At(j) 303 for _, oldArg := range field.Args { 304 if strings.EqualFold(oldArg.GQLName, param.Name()) { 305 if !field.ForceResolver { 306 oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) 307 } 308 newArgs = append(newArgs, oldArg) 309 continue nextArg 310 } 311 } 312 313 // no matching arg found, abort 314 return nil, fmt.Errorf("arg %s not found on method", param.Name()) 315 } 316 return newArgs, nil 317 } 318 319 func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error { 320 gqlType := normalizeVendor(field.Type.FullSignature()) 321 goTypeStr := normalizeVendor(goType.String()) 322 323 if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType { 324 field.Type.Modifiers = modifiersFromGoType(goType) 325 return nil 326 } 327 328 // deal with type aliases 329 underlyingStr := normalizeVendor(goType.Underlying().String()) 330 if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType { 331 field.Type.Modifiers = modifiersFromGoType(goType) 332 pkg, typ := pkgAndType(goType.String()) 333 imp := imports.findByPath(pkg) 334 field.AliasedType = &Ref{GoType: typ, Import: imp} 335 return nil 336 } 337 338 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr) 339 } 340 341 func modifiersFromGoType(t types.Type) []string { 342 var modifiers []string 343 for { 344 switch val := t.(type) { 345 case *types.Pointer: 346 modifiers = append(modifiers, modPtr) 347 t = val.Elem() 348 case *types.Array: 349 modifiers = append(modifiers, modList) 350 t = val.Elem() 351 case *types.Slice: 352 modifiers = append(modifiers, modList) 353 t = val.Elem() 354 default: 355 return modifiers 356 } 357 } 358 } 359 360 var modsRegex = regexp.MustCompile(`^(\*|\[\])*`) 361 362 func normalizeVendor(pkg string) string { 363 modifiers := modsRegex.FindAllString(pkg, 1)[0] 364 pkg = strings.TrimPrefix(pkg, modifiers) 365 parts := strings.Split(pkg, "/vendor/") 366 return modifiers + parts[len(parts)-1] 367 }