github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/codegen/config/binder.go (about) 1 package config 2 3 import ( 4 "errors" 5 "fmt" 6 "go/token" 7 "go/types" 8 9 "golang.org/x/tools/go/packages" 10 11 "github.com/mstephano/gqlgen-schemagen/codegen/templates" 12 "github.com/mstephano/gqlgen-schemagen/internal/code" 13 "github.com/vektah/gqlparser/v2/ast" 14 ) 15 16 var ErrTypeNotFound = errors.New("unable to find type") 17 18 // Binder connects graphql types to golang types using static analysis 19 type Binder struct { 20 pkgs *code.Packages 21 schema *ast.Schema 22 cfg *Config 23 References []*TypeReference 24 SawInvalid bool 25 objectCache map[string]map[string]types.Object 26 } 27 28 func (c *Config) NewBinder() *Binder { 29 return &Binder{ 30 pkgs: c.Packages, 31 schema: c.Schema, 32 cfg: c, 33 } 34 } 35 36 func (b *Binder) TypePosition(typ types.Type) token.Position { 37 named, isNamed := typ.(*types.Named) 38 if !isNamed { 39 return token.Position{ 40 Filename: "unknown", 41 } 42 } 43 44 return b.ObjectPosition(named.Obj()) 45 } 46 47 func (b *Binder) ObjectPosition(typ types.Object) token.Position { 48 if typ == nil { 49 return token.Position{ 50 Filename: "unknown", 51 } 52 } 53 pkg := b.pkgs.Load(typ.Pkg().Path()) 54 return pkg.Fset.Position(typ.Pos()) 55 } 56 57 func (b *Binder) FindTypeFromName(name string) (types.Type, error) { 58 pkgName, typeName := code.PkgAndType(name) 59 return b.FindType(pkgName, typeName) 60 } 61 62 func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { 63 if pkgName == "" { 64 if typeName == "map[string]interface{}" { 65 return MapType, nil 66 } 67 68 if typeName == "interface{}" { 69 return InterfaceType, nil 70 } 71 } 72 73 obj, err := b.FindObject(pkgName, typeName) 74 if err != nil { 75 return nil, err 76 } 77 78 if fun, isFunc := obj.(*types.Func); isFunc { 79 return fun.Type().(*types.Signature).Params().At(0).Type(), nil 80 } 81 return obj.Type(), nil 82 } 83 84 var ( 85 MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()) 86 InterfaceType = types.NewInterfaceType(nil, nil) 87 ) 88 89 func (b *Binder) DefaultUserObject(name string) (types.Type, error) { 90 models := b.cfg.Models[name].Model 91 if len(models) == 0 { 92 return nil, fmt.Errorf(name + " not found in typemap") 93 } 94 95 if models[0] == "map[string]interface{}" { 96 return MapType, nil 97 } 98 99 if models[0] == "interface{}" { 100 return InterfaceType, nil 101 } 102 103 pkgName, typeName := code.PkgAndType(models[0]) 104 if pkgName == "" { 105 return nil, fmt.Errorf("missing package name for %s", name) 106 } 107 108 obj, err := b.FindObject(pkgName, typeName) 109 if err != nil { 110 return nil, err 111 } 112 113 return obj.Type(), nil 114 } 115 116 func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) { 117 if pkgName == "" { 118 return nil, fmt.Errorf("package cannot be nil") 119 } 120 121 pkg := b.pkgs.LoadWithTypes(pkgName) 122 if pkg == nil { 123 err := b.pkgs.Errors() 124 if err != nil { 125 return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err) 126 } 127 return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName) 128 } 129 130 if b.objectCache == nil { 131 b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count()) 132 } 133 134 defsIndex, ok := b.objectCache[pkgName] 135 if !ok { 136 defsIndex = indexDefs(pkg) 137 b.objectCache[pkgName] = defsIndex 138 } 139 140 // function based marshalers take precedence 141 if val, ok := defsIndex["Marshal"+typeName]; ok { 142 return val, nil 143 } 144 145 if val, ok := defsIndex[typeName]; ok { 146 return val, nil 147 } 148 149 return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName) 150 } 151 152 func indexDefs(pkg *packages.Package) map[string]types.Object { 153 res := make(map[string]types.Object) 154 155 scope := pkg.Types.Scope() 156 for astNode, def := range pkg.TypesInfo.Defs { 157 // only look at defs in the top scope 158 if def == nil { 159 continue 160 } 161 parent := def.Parent() 162 if parent == nil || parent != scope { 163 continue 164 } 165 166 if _, ok := res[astNode.Name]; !ok { 167 // The above check may not be really needed, it is only here to have a consistent behavior with 168 // previous implementation of FindObject() function which only honored the first inclusion of a def. 169 // If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups. 170 res[astNode.Name] = def 171 } 172 } 173 174 return res 175 } 176 177 func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { 178 newRef := *ref 179 newRef.GO = types.NewPointer(ref.GO) 180 b.References = append(b.References, &newRef) 181 return &newRef 182 } 183 184 // TypeReference is used by args and field types. The Definition can refer to both input and output types. 185 type TypeReference struct { 186 Definition *ast.Definition 187 GQL *ast.Type 188 GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target. 189 Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields. 190 CastType types.Type // Before calling marshalling functions cast from/to this base type 191 Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function 192 Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function 193 IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler 194 IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety. 195 PointersInUmarshalInput bool // Inverse values and pointers in return. 196 } 197 198 func (ref *TypeReference) Elem() *TypeReference { 199 if p, isPtr := ref.GO.(*types.Pointer); isPtr { 200 newRef := *ref 201 newRef.GO = p.Elem() 202 return &newRef 203 } 204 205 if ref.IsSlice() { 206 newRef := *ref 207 newRef.GO = ref.GO.(*types.Slice).Elem() 208 newRef.GQL = ref.GQL.Elem 209 return &newRef 210 } 211 return nil 212 } 213 214 func (t *TypeReference) IsPtr() bool { 215 _, isPtr := t.GO.(*types.Pointer) 216 return isPtr 217 } 218 219 // fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful) 220 func (t *TypeReference) IsPtrToPtr() bool { 221 if p, isPtr := t.GO.(*types.Pointer); isPtr { 222 _, isPtr := p.Elem().(*types.Pointer) 223 return isPtr 224 } 225 return false 226 } 227 228 func (t *TypeReference) IsNilable() bool { 229 return IsNilable(t.GO) 230 } 231 232 func (t *TypeReference) IsSlice() bool { 233 _, isSlice := t.GO.(*types.Slice) 234 return t.GQL.Elem != nil && isSlice 235 } 236 237 func (t *TypeReference) IsPtrToSlice() bool { 238 if t.IsPtr() { 239 _, isPointerToSlice := t.GO.(*types.Pointer).Elem().(*types.Slice) 240 return isPointerToSlice 241 } 242 return false 243 } 244 245 func (t *TypeReference) IsNamed() bool { 246 _, isSlice := t.GO.(*types.Named) 247 return isSlice 248 } 249 250 func (t *TypeReference) IsStruct() bool { 251 _, isStruct := t.GO.Underlying().(*types.Struct) 252 return isStruct 253 } 254 255 func (t *TypeReference) IsUnderlyingBasic() bool { 256 _, isUnderlyingBasic := t.GO.Underlying().(*types.Basic) 257 return isUnderlyingBasic 258 } 259 260 func (t *TypeReference) IsScalarID() bool { 261 return t.Definition.Kind == ast.Scalar && t.Marshaler.Name() == "MarshalID" 262 } 263 264 func (t *TypeReference) IsScalar() bool { 265 return t.Definition.Kind == ast.Scalar 266 } 267 268 func (t *TypeReference) UniquenessKey() string { 269 nullability := "O" 270 if t.GQL.NonNull { 271 nullability = "N" 272 } 273 274 elemNullability := "" 275 if t.GQL.Elem != nil && t.GQL.Elem.NonNull { 276 // Fix for #896 277 elemNullability = "áš„" 278 } 279 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability 280 } 281 282 func (t *TypeReference) MarshalFunc() string { 283 if t.Definition == nil { 284 panic(errors.New("Definition missing for " + t.GQL.Name())) 285 } 286 287 if t.Definition.Kind == ast.InputObject { 288 return "" 289 } 290 291 return "marshal" + t.UniquenessKey() 292 } 293 294 func (t *TypeReference) UnmarshalFunc() string { 295 if t.Definition == nil { 296 panic(errors.New("Definition missing for " + t.GQL.Name())) 297 } 298 299 if !t.Definition.IsInputType() { 300 return "" 301 } 302 303 return "unmarshal" + t.UniquenessKey() 304 } 305 306 func (t *TypeReference) IsTargetNilable() bool { 307 return IsNilable(t.Target) 308 } 309 310 func (b *Binder) PushRef(ret *TypeReference) { 311 b.References = append(b.References, ret) 312 } 313 314 func isMap(t types.Type) bool { 315 if t == nil { 316 return true 317 } 318 _, ok := t.(*types.Map) 319 return ok 320 } 321 322 func isIntf(t types.Type) bool { 323 if t == nil { 324 return true 325 } 326 _, ok := t.(*types.Interface) 327 return ok 328 } 329 330 func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) { 331 if !isValid(bindTarget) { 332 b.SawInvalid = true 333 return nil, fmt.Errorf("%s has an invalid type", schemaType.Name()) 334 } 335 336 var pkgName, typeName string 337 def := b.schema.Types[schemaType.Name()] 338 defer func() { 339 if err == nil && ret != nil { 340 b.PushRef(ret) 341 } 342 }() 343 344 if len(b.cfg.Models[schemaType.Name()].Model) == 0 { 345 return nil, fmt.Errorf("%s was not found", schemaType.Name()) 346 } 347 348 for _, model := range b.cfg.Models[schemaType.Name()].Model { 349 if model == "map[string]interface{}" { 350 if !isMap(bindTarget) { 351 continue 352 } 353 return &TypeReference{ 354 Definition: def, 355 GQL: schemaType, 356 GO: MapType, 357 }, nil 358 } 359 360 if model == "interface{}" { 361 if !isIntf(bindTarget) { 362 continue 363 } 364 return &TypeReference{ 365 Definition: def, 366 GQL: schemaType, 367 GO: InterfaceType, 368 }, nil 369 } 370 371 pkgName, typeName = code.PkgAndType(model) 372 if pkgName == "" { 373 return nil, fmt.Errorf("missing package name for %s", schemaType.Name()) 374 } 375 376 ref := &TypeReference{ 377 Definition: def, 378 GQL: schemaType, 379 } 380 381 obj, err := b.FindObject(pkgName, typeName) 382 if err != nil { 383 return nil, err 384 } 385 386 if fun, isFunc := obj.(*types.Func); isFunc { 387 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() 388 ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/mstephano/gqlgen-schemagen/graphql.ContextMarshaler" 389 ref.Marshaler = fun 390 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) 391 } else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") { 392 ref.GO = obj.Type() 393 ref.IsContext = true 394 ref.IsMarshaler = true 395 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { 396 ref.GO = obj.Type() 397 ref.IsMarshaler = true 398 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { 399 // TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595) 400 401 ref.GO = obj.Type() 402 ref.CastType = underlying 403 404 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) 405 if err != nil { 406 return nil, err 407 } 408 409 ref.Marshaler = underlyingRef.Marshaler 410 ref.Unmarshaler = underlyingRef.Unmarshaler 411 } else { 412 ref.GO = obj.Type() 413 } 414 415 ref.Target = ref.GO 416 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO) 417 418 if bindTarget != nil { 419 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil { 420 continue 421 } 422 ref.GO = bindTarget 423 } 424 425 ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput 426 427 return ref, nil 428 } 429 430 return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String()) 431 } 432 433 func isValid(t types.Type) bool { 434 basic, isBasic := t.(*types.Basic) 435 if !isBasic { 436 return true 437 } 438 return basic.Kind() != types.Invalid 439 } 440 441 func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { 442 if t.Elem != nil { 443 child := b.CopyModifiersFromAst(t.Elem, base) 444 if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers { 445 child = types.NewPointer(child) 446 } 447 return types.NewSlice(child) 448 } 449 450 var isInterface bool 451 if named, ok := base.(*types.Named); ok { 452 _, isInterface = named.Underlying().(*types.Interface) 453 } 454 455 if !isInterface && !IsNilable(base) && !t.NonNull { 456 return types.NewPointer(base) 457 } 458 459 return base 460 } 461 462 func IsNilable(t types.Type) bool { 463 if namedType, isNamed := t.(*types.Named); isNamed { 464 return IsNilable(namedType.Underlying()) 465 } 466 _, isPtr := t.(*types.Pointer) 467 _, isMap := t.(*types.Map) 468 _, isInterface := t.(*types.Interface) 469 _, isSlice := t.(*types.Slice) 470 _, isChan := t.(*types.Chan) 471 return isPtr || isMap || isInterface || isSlice || isChan 472 } 473 474 func hasMethod(it types.Type, name string) bool { 475 if ptr, isPtr := it.(*types.Pointer); isPtr { 476 it = ptr.Elem() 477 } 478 namedType, ok := it.(*types.Named) 479 if !ok { 480 return false 481 } 482 483 for i := 0; i < namedType.NumMethods(); i++ { 484 if namedType.Method(i).Name() == name { 485 return true 486 } 487 } 488 return false 489 } 490 491 func basicUnderlying(it types.Type) *types.Basic { 492 if ptr, isPtr := it.(*types.Pointer); isPtr { 493 it = ptr.Elem() 494 } 495 namedType, ok := it.(*types.Named) 496 if !ok { 497 return nil 498 } 499 500 if basic, ok := namedType.Underlying().(*types.Basic); ok { 501 return basic 502 } 503 504 return nil 505 }