github.com/codyleyhan/gqlgen@v0.4.4/codegen/util.go (about) 1 package codegen 2 3 import ( 4 "fmt" 5 "go/types" 6 "regexp" 7 "strings" 8 9 "github.com/pkg/errors" 10 "golang.org/x/tools/go/loader" 11 ) 12 13 func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) { 14 if pkgName == "" { 15 return nil, nil 16 } 17 fullName := typeName 18 if pkgName != "" { 19 fullName = pkgName + "." + typeName 20 } 21 22 pkgName, err := resolvePkg(pkgName) 23 if err != nil { 24 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error()) 25 } 26 27 pkg := prog.Imported[pkgName] 28 if pkg == nil { 29 return nil, errors.Errorf("required package was not loaded: %s", fullName) 30 } 31 32 for astNode, def := range pkg.Defs { 33 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() { 34 continue 35 } 36 37 return def, nil 38 } 39 40 return nil, errors.Errorf("unable to find type %s\n", fullName) 41 } 42 43 func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) { 44 def, err := findGoType(prog, pkgName, typeName) 45 if err != nil { 46 return nil, err 47 } 48 if def == nil { 49 return nil, nil 50 } 51 52 namedType, ok := def.Type().(*types.Named) 53 if !ok { 54 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type()) 55 } 56 57 return namedType, nil 58 } 59 60 func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) { 61 namedType, err := findGoNamedType(prog, pkgName, typeName) 62 if err != nil { 63 return nil, err 64 } 65 if namedType == nil { 66 return nil, nil 67 } 68 69 underlying, ok := namedType.Underlying().(*types.Interface) 70 if !ok { 71 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String()) 72 } 73 74 return underlying, nil 75 } 76 77 func findMethod(typ *types.Named, name string) *types.Func { 78 for i := 0; i < typ.NumMethods(); i++ { 79 method := typ.Method(i) 80 if !method.Exported() { 81 continue 82 } 83 84 if strings.EqualFold(method.Name(), name) { 85 return method 86 } 87 } 88 89 if s, ok := typ.Underlying().(*types.Struct); ok { 90 for i := 0; i < s.NumFields(); i++ { 91 field := s.Field(i) 92 if !field.Anonymous() { 93 continue 94 } 95 96 if named, ok := field.Type().(*types.Named); ok { 97 if f := findMethod(named, name); f != nil { 98 return f 99 } 100 } 101 } 102 } 103 104 return nil 105 } 106 107 func findField(typ *types.Struct, name string) *types.Var { 108 for i := 0; i < typ.NumFields(); i++ { 109 field := typ.Field(i) 110 if field.Anonymous() { 111 if named, ok := field.Type().(*types.Struct); ok { 112 if f := findField(named, name); f != nil { 113 return f 114 } 115 } 116 117 if named, ok := field.Type().Underlying().(*types.Struct); ok { 118 if f := findField(named, name); f != nil { 119 return f 120 } 121 } 122 } 123 124 if !field.Exported() { 125 continue 126 } 127 128 if strings.EqualFold(field.Name(), name) { 129 return field 130 } 131 } 132 return nil 133 } 134 135 type BindError struct { 136 object *Object 137 field *Field 138 typ types.Type 139 methodErr error 140 varErr error 141 } 142 143 func (b BindError) Error() string { 144 return fmt.Sprintf( 145 "Unable to bind %s.%s to %s\n %s\n %s", 146 b.object.GQLType, 147 b.field.GQLName, 148 b.typ.String(), 149 b.methodErr.Error(), 150 b.varErr.Error(), 151 ) 152 } 153 154 type BindErrors []BindError 155 156 func (b BindErrors) Error() string { 157 var errs []string 158 for _, err := range b { 159 errs = append(errs, err.Error()) 160 } 161 return strings.Join(errs, "\n\n") 162 } 163 164 func bindObject(t types.Type, object *Object, imports *Imports) BindErrors { 165 var errs BindErrors 166 for i := range object.Fields { 167 field := &object.Fields[i] 168 169 if field.ForceResolver { 170 continue 171 } 172 173 // first try binding to a method 174 methodErr := bindMethod(imports, t, field) 175 if methodErr == nil { 176 continue 177 } 178 179 // otherwise try binding to a var 180 varErr := bindVar(imports, t, field) 181 182 if varErr != nil { 183 errs = append(errs, BindError{ 184 object: object, 185 typ: t, 186 field: field, 187 varErr: varErr, 188 methodErr: methodErr, 189 }) 190 } 191 } 192 return errs 193 } 194 195 func bindMethod(imports *Imports, t types.Type, field *Field) error { 196 namedType, ok := t.(*types.Named) 197 if !ok { 198 return fmt.Errorf("not a named type") 199 } 200 201 goName := field.GQLName 202 if field.GoFieldName != "" { 203 goName = field.GoFieldName 204 } 205 method := findMethod(namedType, goName) 206 if method == nil { 207 return fmt.Errorf("no method named %s", field.GQLName) 208 } 209 sig := method.Type().(*types.Signature) 210 211 if sig.Results().Len() == 1 { 212 field.NoErr = true 213 } else if sig.Results().Len() != 2 { 214 return fmt.Errorf("method has wrong number of args") 215 } 216 newArgs, err := matchArgs(field, sig.Params()) 217 if err != nil { 218 return err 219 } 220 221 result := sig.Results().At(0) 222 if err := validateTypeBinding(imports, field, result.Type()); err != nil { 223 return errors.Wrap(err, "method has wrong return type") 224 } 225 226 // success, args and return type match. Bind to method 227 field.GoFieldType = GoFieldMethod 228 field.GoReceiverName = "obj" 229 field.GoFieldName = method.Name() 230 field.Args = newArgs 231 return nil 232 } 233 234 func bindVar(imports *Imports, t types.Type, field *Field) error { 235 underlying, ok := t.Underlying().(*types.Struct) 236 if !ok { 237 return fmt.Errorf("not a struct") 238 } 239 240 goName := field.GQLName 241 if field.GoFieldName != "" { 242 goName = field.GoFieldName 243 } 244 structField := findField(underlying, goName) 245 if structField == nil { 246 return fmt.Errorf("no field named %s", field.GQLName) 247 } 248 249 if err := validateTypeBinding(imports, field, structField.Type()); err != nil { 250 return errors.Wrap(err, "field has wrong type") 251 } 252 253 // success, bind to var 254 field.GoFieldType = GoFieldVariable 255 field.GoReceiverName = "obj" 256 field.GoFieldName = structField.Name() 257 return nil 258 } 259 260 func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) { 261 var newArgs []FieldArgument 262 263 nextArg: 264 for j := 0; j < params.Len(); j++ { 265 param := params.At(j) 266 for _, oldArg := range field.Args { 267 if strings.EqualFold(oldArg.GQLName, param.Name()) { 268 if !field.ForceResolver { 269 oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) 270 } 271 newArgs = append(newArgs, oldArg) 272 continue nextArg 273 } 274 } 275 276 // no matching arg found, abort 277 return nil, fmt.Errorf("arg %s not found on method", param.Name()) 278 } 279 return newArgs, nil 280 } 281 282 func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error { 283 gqlType := normalizeVendor(field.Type.FullSignature()) 284 goTypeStr := normalizeVendor(goType.String()) 285 286 if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType { 287 field.Type.Modifiers = modifiersFromGoType(goType) 288 return nil 289 } 290 291 // deal with type aliases 292 underlyingStr := normalizeVendor(goType.Underlying().String()) 293 if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType { 294 field.Type.Modifiers = modifiersFromGoType(goType) 295 pkg, typ := pkgAndType(goType.String()) 296 imp := imports.findByPath(pkg) 297 field.AliasedType = &Ref{GoType: typ, Import: imp} 298 return nil 299 } 300 301 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr) 302 } 303 304 func modifiersFromGoType(t types.Type) []string { 305 var modifiers []string 306 for { 307 switch val := t.(type) { 308 case *types.Pointer: 309 modifiers = append(modifiers, modPtr) 310 t = val.Elem() 311 case *types.Array: 312 modifiers = append(modifiers, modList) 313 t = val.Elem() 314 case *types.Slice: 315 modifiers = append(modifiers, modList) 316 t = val.Elem() 317 default: 318 return modifiers 319 } 320 } 321 } 322 323 var modsRegex = regexp.MustCompile(`^(\*|\[\])*`) 324 325 func normalizeVendor(pkg string) string { 326 modifiers := modsRegex.FindAllString(pkg, 1)[0] 327 pkg = strings.TrimPrefix(pkg, modifiers) 328 parts := strings.Split(pkg, "/vendor/") 329 return modifiers + parts[len(parts)-1] 330 }