github.com/plages-ai/gqlgen@v0.7.2/codegen/util.go (about) 1 package codegen 2 3 import ( 4 "fmt" 5 "go/types" 6 "reflect" 7 "regexp" 8 "strings" 9 10 "github.com/pkg/errors" 11 "golang.org/x/tools/go/loader" 12 ) 13 14 func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) { 15 if pkgName == "" { 16 return nil, nil 17 } 18 fullName := typeName 19 if pkgName != "" { 20 fullName = pkgName + "." + typeName 21 } 22 23 pkgName, err := resolvePkg(pkgName) 24 if err != nil { 25 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error()) 26 } 27 28 pkg := prog.Imported[pkgName] 29 if pkg == nil { 30 return nil, errors.Errorf("required package was not loaded: %s", fullName) 31 } 32 33 for astNode, def := range pkg.Defs { 34 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() { 35 continue 36 } 37 38 return def, nil 39 } 40 41 return nil, errors.Errorf("unable to find type %s\n", fullName) 42 } 43 44 func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) { 45 def, err := findGoType(prog, pkgName, typeName) 46 if err != nil { 47 return nil, err 48 } 49 if def == nil { 50 return nil, nil 51 } 52 53 namedType, ok := def.Type().(*types.Named) 54 if !ok { 55 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type()) 56 } 57 58 return namedType, nil 59 } 60 61 func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) { 62 namedType, err := findGoNamedType(prog, pkgName, typeName) 63 if err != nil { 64 return nil, err 65 } 66 if namedType == nil { 67 return nil, nil 68 } 69 70 underlying, ok := namedType.Underlying().(*types.Interface) 71 if !ok { 72 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String()) 73 } 74 75 return underlying, nil 76 } 77 78 func findMethod(typ *types.Named, name string) *types.Func { 79 for i := 0; i < typ.NumMethods(); i++ { 80 method := typ.Method(i) 81 if !method.Exported() { 82 continue 83 } 84 85 if strings.EqualFold(method.Name(), name) { 86 return method 87 } 88 } 89 90 if s, ok := typ.Underlying().(*types.Struct); ok { 91 for i := 0; i < s.NumFields(); i++ { 92 field := s.Field(i) 93 if !field.Anonymous() { 94 continue 95 } 96 97 if named, ok := field.Type().(*types.Named); ok { 98 if f := findMethod(named, name); f != nil { 99 return f 100 } 101 } 102 } 103 } 104 105 return nil 106 } 107 108 func equalFieldName(source, target string) bool { 109 source = strings.Replace(source, "_", "", -1) 110 target = strings.Replace(target, "_", "", -1) 111 return strings.EqualFold(source, target) 112 } 113 114 // findField attempts to match the name to a struct field with the following 115 // priorites: 116 // 1. If struct tag is passed then struct tag has highest priority 117 // 2. Field in an embedded struct 118 // 3. Actual Field name 119 func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { 120 var foundField *types.Var 121 foundFieldWasTag := false 122 123 for i := 0; i < typ.NumFields(); i++ { 124 field := typ.Field(i) 125 126 if structTag != "" { 127 tags := reflect.StructTag(typ.Tag(i)) 128 if val, ok := tags.Lookup(structTag); ok { 129 if equalFieldName(val, name) { 130 if foundField != nil && foundFieldWasTag { 131 return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val) 132 } 133 134 foundField = field 135 foundFieldWasTag = true 136 } 137 } 138 } 139 140 if field.Anonymous() { 141 142 fieldType := field.Type() 143 144 if ptr, ok := fieldType.(*types.Pointer); ok { 145 fieldType = ptr.Elem() 146 } 147 148 // Type.Underlying() returns itself for all types except types.Named, where it returns a struct type. 149 // It should be safe to always call. 150 if named, ok := fieldType.Underlying().(*types.Struct); ok { 151 f, err := findField(named, name, structTag) 152 if err != nil && !strings.HasPrefix(err.Error(), "no field named") { 153 return nil, err 154 } 155 if f != nil && foundField == nil { 156 foundField = f 157 } 158 } 159 } 160 161 if !field.Exported() { 162 continue 163 } 164 165 if equalFieldName(field.Name(), name) && foundField == nil { // aqui! 166 foundField = field 167 } 168 } 169 170 if foundField == nil { 171 return nil, fmt.Errorf("no field named %s", name) 172 } 173 174 return foundField, nil 175 } 176 177 type BindError struct { 178 object *Object 179 field *Field 180 typ types.Type 181 methodErr error 182 varErr error 183 } 184 185 func (b BindError) Error() string { 186 return fmt.Sprintf( 187 "Unable to bind %s.%s to %s\n %s\n %s", 188 b.object.GQLType, 189 b.field.GQLName, 190 b.typ.String(), 191 b.methodErr.Error(), 192 b.varErr.Error(), 193 ) 194 } 195 196 type BindErrors []BindError 197 198 func (b BindErrors) Error() string { 199 var errs []string 200 for _, err := range b { 201 errs = append(errs, err.Error()) 202 } 203 return strings.Join(errs, "\n\n") 204 } 205 206 func bindObject(t types.Type, object *Object, structTag string) BindErrors { 207 var errs BindErrors 208 for i := range object.Fields { 209 field := &object.Fields[i] 210 211 if field.ForceResolver { 212 continue 213 } 214 215 // first try binding to a method 216 methodErr := bindMethod(t, field) 217 if methodErr == nil { 218 continue 219 } 220 221 // otherwise try binding to a var 222 varErr := bindVar(t, field, structTag) 223 224 if varErr != nil { 225 errs = append(errs, BindError{ 226 object: object, 227 typ: t, 228 field: field, 229 varErr: varErr, 230 methodErr: methodErr, 231 }) 232 } 233 } 234 return errs 235 } 236 237 func bindMethod(t types.Type, field *Field) error { 238 namedType, ok := t.(*types.Named) 239 if !ok { 240 return fmt.Errorf("not a named type") 241 } 242 243 goName := field.GQLName 244 if field.GoFieldName != "" { 245 goName = field.GoFieldName 246 } 247 method := findMethod(namedType, goName) 248 if method == nil { 249 return fmt.Errorf("no method named %s", field.GQLName) 250 } 251 sig := method.Type().(*types.Signature) 252 253 if sig.Results().Len() == 1 { 254 field.NoErr = true 255 } else if sig.Results().Len() != 2 { 256 return fmt.Errorf("method has wrong number of args") 257 } 258 params := sig.Params() 259 // If the first argument is the context, remove it from the comparison and set 260 // the MethodHasContext flag so that the context will be passed to this model's method 261 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { 262 field.MethodHasContext = true 263 vars := make([]*types.Var, params.Len()-1) 264 for i := 1; i < params.Len(); i++ { 265 vars[i-1] = params.At(i) 266 } 267 params = types.NewTuple(vars...) 268 } 269 270 newArgs, err := matchArgs(field, params) 271 if err != nil { 272 return err 273 } 274 275 result := sig.Results().At(0) 276 if err := validateTypeBinding(field, result.Type()); err != nil { 277 return errors.Wrap(err, "method has wrong return type") 278 } 279 280 // success, args and return type match. Bind to method 281 field.GoFieldType = GoFieldMethod 282 field.GoReceiverName = "obj" 283 field.GoFieldName = method.Name() 284 field.Args = newArgs 285 return nil 286 } 287 288 func bindVar(t types.Type, field *Field, structTag string) error { 289 underlying, ok := t.Underlying().(*types.Struct) 290 if !ok { 291 return fmt.Errorf("not a struct") 292 } 293 294 goName := field.GQLName 295 if field.GoFieldName != "" { 296 goName = field.GoFieldName 297 } 298 structField, err := findField(underlying, goName, structTag) 299 if err != nil { 300 return err 301 } 302 303 if err := validateTypeBinding(field, structField.Type()); err != nil { 304 return errors.Wrap(err, "field has wrong type") 305 } 306 307 // success, bind to var 308 field.GoFieldType = GoFieldVariable 309 field.GoReceiverName = "obj" 310 field.GoFieldName = structField.Name() 311 return nil 312 } 313 314 func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) { 315 var newArgs []FieldArgument 316 317 nextArg: 318 for j := 0; j < params.Len(); j++ { 319 param := params.At(j) 320 for _, oldArg := range field.Args { 321 if strings.EqualFold(oldArg.GQLName, param.Name()) { 322 if !field.ForceResolver { 323 oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) 324 } 325 newArgs = append(newArgs, oldArg) 326 continue nextArg 327 } 328 } 329 330 // no matching arg found, abort 331 return nil, fmt.Errorf("arg %s not found on method", param.Name()) 332 } 333 return newArgs, nil 334 } 335 336 func validateTypeBinding(field *Field, goType types.Type) error { 337 gqlType := normalizeVendor(field.Type.FullSignature()) 338 goTypeStr := normalizeVendor(goType.String()) 339 340 if equalTypes(goTypeStr, gqlType) { 341 field.Type.Modifiers = modifiersFromGoType(goType) 342 return nil 343 } 344 345 // deal with type aliases 346 underlyingStr := normalizeVendor(goType.Underlying().String()) 347 if equalTypes(underlyingStr, gqlType) { 348 field.Type.Modifiers = modifiersFromGoType(goType) 349 pkg, typ := pkgAndType(goType.String()) 350 field.AliasedType = &Ref{GoType: typ, Package: pkg} 351 return nil 352 } 353 354 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr) 355 } 356 357 func modifiersFromGoType(t types.Type) []string { 358 var modifiers []string 359 for { 360 switch val := t.(type) { 361 case *types.Pointer: 362 modifiers = append(modifiers, modPtr) 363 t = val.Elem() 364 case *types.Array: 365 modifiers = append(modifiers, modList) 366 t = val.Elem() 367 case *types.Slice: 368 modifiers = append(modifiers, modList) 369 t = val.Elem() 370 default: 371 return modifiers 372 } 373 } 374 } 375 376 var modsRegex = regexp.MustCompile(`^(\*|\[\])*`) 377 378 func normalizeVendor(pkg string) string { 379 modifiers := modsRegex.FindAllString(pkg, 1)[0] 380 pkg = strings.TrimPrefix(pkg, modifiers) 381 parts := strings.Split(pkg, "/vendor/") 382 return modifiers + parts[len(parts)-1] 383 } 384 385 func equalTypes(goType string, gqlType string) bool { 386 return goType == gqlType || "*"+goType == gqlType || goType == "*"+gqlType || strings.Replace(goType, "[]*", "[]", -1) == gqlType 387 }