git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/codegen/config/binder.go (about) 1 package config 2 3 import ( 4 "fmt" 5 "go/token" 6 "go/types" 7 8 "git.sr.ht/~sircmpwn/gqlgen/codegen/templates" 9 "git.sr.ht/~sircmpwn/gqlgen/internal/code" 10 "github.com/pkg/errors" 11 "github.com/vektah/gqlparser/v2/ast" 12 ) 13 14 // Binder connects graphql types to golang types using static analysis 15 type Binder struct { 16 pkgs *code.Packages 17 schema *ast.Schema 18 cfg *Config 19 References []*TypeReference 20 SawInvalid bool 21 } 22 23 func (c *Config) NewBinder() *Binder { 24 return &Binder{ 25 pkgs: c.Packages, 26 schema: c.Schema, 27 cfg: c, 28 } 29 } 30 31 func (b *Binder) TypePosition(typ types.Type) token.Position { 32 named, isNamed := typ.(*types.Named) 33 if !isNamed { 34 return token.Position{ 35 Filename: "unknown", 36 } 37 } 38 39 return b.ObjectPosition(named.Obj()) 40 } 41 42 func (b *Binder) ObjectPosition(typ types.Object) token.Position { 43 if typ == nil { 44 return token.Position{ 45 Filename: "unknown", 46 } 47 } 48 pkg := b.pkgs.Load(typ.Pkg().Path()) 49 return pkg.Fset.Position(typ.Pos()) 50 } 51 52 func (b *Binder) FindTypeFromName(name string) (types.Type, error) { 53 pkgName, typeName := code.PkgAndType(name) 54 return b.FindType(pkgName, typeName) 55 } 56 57 func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { 58 if pkgName == "" { 59 if typeName == "map[string]interface{}" { 60 return MapType, nil 61 } 62 63 if typeName == "interface{}" { 64 return InterfaceType, nil 65 } 66 } 67 68 obj, err := b.FindObject(pkgName, typeName) 69 if err != nil { 70 return nil, err 71 } 72 73 if fun, isFunc := obj.(*types.Func); isFunc { 74 return fun.Type().(*types.Signature).Params().At(0).Type(), nil 75 } 76 return obj.Type(), nil 77 } 78 79 var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()) 80 var InterfaceType = types.NewInterfaceType(nil, nil) 81 82 func (b *Binder) DefaultUserObject(name string) (types.Type, error) { 83 models := b.cfg.Models[name].Model 84 if len(models) == 0 { 85 return nil, fmt.Errorf(name + " not found in typemap") 86 } 87 88 if models[0] == "map[string]interface{}" { 89 return MapType, nil 90 } 91 92 if models[0] == "interface{}" { 93 return InterfaceType, nil 94 } 95 96 pkgName, typeName := code.PkgAndType(models[0]) 97 if pkgName == "" { 98 return nil, fmt.Errorf("missing package name for %s", name) 99 } 100 101 obj, err := b.FindObject(pkgName, typeName) 102 if err != nil { 103 return nil, err 104 } 105 106 return obj.Type(), nil 107 } 108 109 func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) { 110 if pkgName == "" { 111 return nil, fmt.Errorf("package cannot be nil") 112 } 113 fullName := typeName 114 if pkgName != "" { 115 fullName = pkgName + "." + typeName 116 } 117 118 pkg := b.pkgs.LoadWithTypes(pkgName) 119 if pkg == nil { 120 return nil, errors.Errorf("required package was not loaded: %s", fullName) 121 } 122 123 // function based marshalers take precedence 124 for astNode, def := range pkg.TypesInfo.Defs { 125 // only look at defs in the top scope 126 if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() { 127 continue 128 } 129 130 if astNode.Name == "Marshal"+typeName { 131 return def, nil 132 } 133 } 134 135 // then look for types directly 136 for astNode, def := range pkg.TypesInfo.Defs { 137 // only look at defs in the top scope 138 if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() { 139 continue 140 } 141 142 if astNode.Name == typeName { 143 return def, nil 144 } 145 } 146 147 return nil, errors.Errorf("unable to find type %s\n", fullName) 148 } 149 150 func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { 151 newRef := &TypeReference{ 152 GO: types.NewPointer(ref.GO), 153 GQL: ref.GQL, 154 CastType: ref.CastType, 155 Definition: ref.Definition, 156 Unmarshaler: ref.Unmarshaler, 157 Marshaler: ref.Marshaler, 158 IsMarshaler: ref.IsMarshaler, 159 } 160 161 b.References = append(b.References, newRef) 162 return newRef 163 } 164 165 // TypeReference is used by args and field types. The Definition can refer to both input and output types. 166 type TypeReference struct { 167 Definition *ast.Definition 168 GQL *ast.Type 169 GO types.Type 170 CastType types.Type // Before calling marshalling functions cast from/to this base type 171 Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function 172 Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function 173 IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler 174 } 175 176 func (ref *TypeReference) Elem() *TypeReference { 177 if p, isPtr := ref.GO.(*types.Pointer); isPtr { 178 return &TypeReference{ 179 GO: p.Elem(), 180 GQL: ref.GQL, 181 CastType: ref.CastType, 182 Definition: ref.Definition, 183 Unmarshaler: ref.Unmarshaler, 184 Marshaler: ref.Marshaler, 185 IsMarshaler: ref.IsMarshaler, 186 } 187 } 188 189 if ref.IsSlice() { 190 return &TypeReference{ 191 GO: ref.GO.(*types.Slice).Elem(), 192 GQL: ref.GQL.Elem, 193 CastType: ref.CastType, 194 Definition: ref.Definition, 195 Unmarshaler: ref.Unmarshaler, 196 Marshaler: ref.Marshaler, 197 IsMarshaler: ref.IsMarshaler, 198 } 199 } 200 return nil 201 } 202 203 func (t *TypeReference) IsPtr() bool { 204 _, isPtr := t.GO.(*types.Pointer) 205 return isPtr 206 } 207 208 func (t *TypeReference) IsNilable() bool { 209 return IsNilable(t.GO) 210 } 211 212 func (t *TypeReference) IsSlice() bool { 213 _, isSlice := t.GO.(*types.Slice) 214 return t.GQL.Elem != nil && isSlice 215 } 216 217 func (t *TypeReference) IsNamed() bool { 218 _, isSlice := t.GO.(*types.Named) 219 return isSlice 220 } 221 222 func (t *TypeReference) IsStruct() bool { 223 _, isStruct := t.GO.Underlying().(*types.Struct) 224 return isStruct 225 } 226 227 func (t *TypeReference) IsScalar() bool { 228 return t.Definition.Kind == ast.Scalar 229 } 230 231 func (t *TypeReference) UniquenessKey() string { 232 var nullability = "O" 233 if t.GQL.NonNull { 234 nullability = "N" 235 } 236 237 var elemNullability = "" 238 if t.GQL.Elem != nil && t.GQL.Elem.NonNull { 239 // Fix for #896 240 elemNullability = "áš„" 241 } 242 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability 243 } 244 245 func (t *TypeReference) MarshalFunc() string { 246 if t.Definition == nil { 247 panic(errors.New("Definition missing for " + t.GQL.Name())) 248 } 249 250 if t.Definition.Kind == ast.InputObject { 251 return "" 252 } 253 254 return "marshal" + t.UniquenessKey() 255 } 256 257 func (t *TypeReference) UnmarshalFunc() string { 258 if t.Definition == nil { 259 panic(errors.New("Definition missing for " + t.GQL.Name())) 260 } 261 262 if !t.Definition.IsInputType() { 263 return "" 264 } 265 266 return "unmarshal" + t.UniquenessKey() 267 } 268 269 func (b *Binder) PushRef(ret *TypeReference) { 270 b.References = append(b.References, ret) 271 } 272 273 func isMap(t types.Type) bool { 274 if t == nil { 275 return true 276 } 277 _, ok := t.(*types.Map) 278 return ok 279 } 280 281 func isIntf(t types.Type) bool { 282 if t == nil { 283 return true 284 } 285 _, ok := t.(*types.Interface) 286 return ok 287 } 288 289 func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) { 290 if !isValid(bindTarget) { 291 b.SawInvalid = true 292 return nil, fmt.Errorf("%s has an invalid type", schemaType.Name()) 293 } 294 295 var pkgName, typeName string 296 def := b.schema.Types[schemaType.Name()] 297 defer func() { 298 if err == nil && ret != nil { 299 b.PushRef(ret) 300 } 301 }() 302 303 if len(b.cfg.Models[schemaType.Name()].Model) == 0 { 304 return nil, fmt.Errorf("%s was not found", schemaType.Name()) 305 } 306 307 for _, model := range b.cfg.Models[schemaType.Name()].Model { 308 if model == "map[string]interface{}" { 309 if !isMap(bindTarget) { 310 continue 311 } 312 return &TypeReference{ 313 Definition: def, 314 GQL: schemaType, 315 GO: MapType, 316 }, nil 317 } 318 319 if model == "interface{}" { 320 if !isIntf(bindTarget) { 321 continue 322 } 323 return &TypeReference{ 324 Definition: def, 325 GQL: schemaType, 326 GO: InterfaceType, 327 }, nil 328 } 329 330 pkgName, typeName = code.PkgAndType(model) 331 if pkgName == "" { 332 return nil, fmt.Errorf("missing package name for %s", schemaType.Name()) 333 } 334 335 ref := &TypeReference{ 336 Definition: def, 337 GQL: schemaType, 338 } 339 340 obj, err := b.FindObject(pkgName, typeName) 341 if err != nil { 342 return nil, err 343 } 344 345 if fun, isFunc := obj.(*types.Func); isFunc { 346 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() 347 ref.Marshaler = fun 348 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) 349 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { 350 ref.GO = obj.Type() 351 ref.IsMarshaler = true 352 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { 353 // Special case for named types wrapping strings. Used by default enum implementations. 354 355 ref.GO = obj.Type() 356 ref.CastType = underlying 357 358 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) 359 if err != nil { 360 return nil, err 361 } 362 363 ref.Marshaler = underlyingRef.Marshaler 364 ref.Unmarshaler = underlyingRef.Unmarshaler 365 } else { 366 ref.GO = obj.Type() 367 } 368 369 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO) 370 371 if bindTarget != nil { 372 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil { 373 continue 374 } 375 ref.GO = bindTarget 376 } 377 378 return ref, nil 379 } 380 381 return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String()) 382 } 383 384 func isValid(t types.Type) bool { 385 basic, isBasic := t.(*types.Basic) 386 if !isBasic { 387 return true 388 } 389 return basic.Kind() != types.Invalid 390 } 391 392 func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { 393 if t.Elem != nil { 394 child := b.CopyModifiersFromAst(t.Elem, base) 395 if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers { 396 child = types.NewPointer(child) 397 } 398 return types.NewSlice(child) 399 } 400 401 var isInterface bool 402 if named, ok := base.(*types.Named); ok { 403 _, isInterface = named.Underlying().(*types.Interface) 404 } 405 406 if !isInterface && !IsNilable(base) && !t.NonNull { 407 return types.NewPointer(base) 408 } 409 410 return base 411 } 412 413 func IsNilable(t types.Type) bool { 414 if namedType, isNamed := t.(*types.Named); isNamed { 415 t = namedType.Underlying() 416 } 417 _, isPtr := t.(*types.Pointer) 418 _, isMap := t.(*types.Map) 419 _, isInterface := t.(*types.Interface) 420 return isPtr || isMap || isInterface 421 } 422 423 func hasMethod(it types.Type, name string) bool { 424 if ptr, isPtr := it.(*types.Pointer); isPtr { 425 it = ptr.Elem() 426 } 427 namedType, ok := it.(*types.Named) 428 if !ok { 429 return false 430 } 431 432 for i := 0; i < namedType.NumMethods(); i++ { 433 if namedType.Method(i).Name() == name { 434 return true 435 } 436 } 437 return false 438 } 439 440 func basicUnderlying(it types.Type) *types.Basic { 441 if ptr, isPtr := it.(*types.Pointer); isPtr { 442 it = ptr.Elem() 443 } 444 namedType, ok := it.(*types.Named) 445 if !ok { 446 return nil 447 } 448 449 if basic, ok := namedType.Underlying().(*types.Basic); ok { 450 return basic 451 } 452 453 return nil 454 }