github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/codegen/config/config.go (about) 1 package config 2 3 import ( 4 "bytes" 5 "fmt" 6 "os" 7 "path/filepath" 8 "regexp" 9 "sort" 10 "strings" 11 12 "github.com/mstephano/gqlgen-schemagen/internal/code" 13 "github.com/vektah/gqlparser/v2" 14 "github.com/vektah/gqlparser/v2/ast" 15 "gopkg.in/yaml.v3" 16 ) 17 18 type Config struct { 19 SchemaFilename StringList `yaml:"schema,omitempty"` 20 Exec ExecConfig `yaml:"exec"` 21 Model PackageConfig `yaml:"model,omitempty"` 22 Federation PackageConfig `yaml:"federation,omitempty"` 23 Resolver ResolverConfig `yaml:"resolver,omitempty"` 24 AutoBind []string `yaml:"autobind"` 25 Models TypeMap `yaml:"models,omitempty"` 26 StructTag string `yaml:"struct_tag,omitempty"` 27 Directives map[string]DirectiveConfig `yaml:"directives,omitempty"` 28 OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"` 29 OmitGetters bool `yaml:"omit_getters,omitempty"` 30 StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"` 31 ReturnPointersInUmarshalInput bool `yaml:"return_pointers_in_unmarshalinput,omitempty"` 32 ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"` 33 SkipValidation bool `yaml:"skip_validation,omitempty"` 34 SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"` 35 Sources []*ast.Source `yaml:"-"` 36 Packages *code.Packages `yaml:"-"` 37 Schema *ast.Schema `yaml:"-"` 38 39 // Deprecated: use Federation instead. Will be removed next release 40 Federated bool `yaml:"federated,omitempty"` 41 } 42 43 var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"} 44 45 // DefaultConfig creates a copy of the default config 46 func DefaultConfig() *Config { 47 return &Config{ 48 SchemaFilename: StringList{"schema.graphql"}, 49 Model: PackageConfig{Filename: "models_gen.go"}, 50 Exec: ExecConfig{Filename: "generated.go"}, 51 Directives: map[string]DirectiveConfig{}, 52 Models: TypeMap{}, 53 StructFieldsAlwaysPointers: true, 54 ReturnPointersInUmarshalInput: false, 55 ResolversAlwaysReturnPointers: true, 56 } 57 } 58 59 // LoadDefaultConfig loads the default config so that it is ready to be used 60 func LoadDefaultConfig() (*Config, error) { 61 config := DefaultConfig() 62 63 for _, filename := range config.SchemaFilename { 64 filename = filepath.ToSlash(filename) 65 var err error 66 var schemaRaw []byte 67 schemaRaw, err = os.ReadFile(filename) 68 if err != nil { 69 return nil, fmt.Errorf("unable to open schema: %w", err) 70 } 71 72 config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)}) 73 } 74 75 return config, nil 76 } 77 78 // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories 79 // walking up the tree. The closest config file will be returned. 80 func LoadConfigFromDefaultLocations() (*Config, error) { 81 cfgFile, err := findCfg() 82 if err != nil { 83 return nil, err 84 } 85 86 err = os.Chdir(filepath.Dir(cfgFile)) 87 if err != nil { 88 return nil, fmt.Errorf("unable to enter config dir: %w", err) 89 } 90 return LoadConfig(cfgFile) 91 } 92 93 var path2regex = strings.NewReplacer( 94 `.`, `\.`, 95 `*`, `.+`, 96 `\`, `[\\/]`, 97 `/`, `[\\/]`, 98 ) 99 100 // LoadConfig reads the gqlgen.yml config file 101 func LoadConfig(filename string) (*Config, error) { 102 config := DefaultConfig() 103 104 b, err := os.ReadFile(filename) 105 if err != nil { 106 return nil, fmt.Errorf("unable to read config: %w", err) 107 } 108 109 dec := yaml.NewDecoder(bytes.NewReader(b)) 110 dec.KnownFields(true) 111 112 if err := dec.Decode(config); err != nil { 113 return nil, fmt.Errorf("unable to parse config: %w", err) 114 } 115 116 if err := CompleteConfig(config); err != nil { 117 return nil, err 118 } 119 120 return config, nil 121 } 122 123 // CompleteConfig fills in the schema and other values to a config loaded from 124 // YAML. 125 func CompleteConfig(config *Config) error { 126 defaultDirectives := map[string]DirectiveConfig{ 127 "skip": {SkipRuntime: true}, 128 "include": {SkipRuntime: true}, 129 "deprecated": {SkipRuntime: true}, 130 "specifiedBy": {SkipRuntime: true}, 131 } 132 133 for key, value := range defaultDirectives { 134 if _, defined := config.Directives[key]; !defined { 135 config.Directives[key] = value 136 } 137 } 138 139 preGlobbing := config.SchemaFilename 140 config.SchemaFilename = StringList{} 141 for _, f := range preGlobbing { 142 var matches []string 143 144 // for ** we want to override default globbing patterns and walk all 145 // subdirectories to match schema files. 146 if strings.Contains(f, "**") { 147 pathParts := strings.SplitN(f, "**", 2) 148 rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`) 149 // turn the rest of the glob into a regex, anchored only at the end because ** allows 150 // for any number of dirs in between and walk will let us match against the full path name 151 globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`) 152 153 if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error { 154 if err != nil { 155 return err 156 } 157 158 if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) { 159 matches = append(matches, path) 160 } 161 162 return nil 163 }); err != nil { 164 return fmt.Errorf("failed to walk schema at root %s: %w", pathParts[0], err) 165 } 166 } else { 167 var err error 168 matches, err = filepath.Glob(f) 169 if err != nil { 170 return fmt.Errorf("failed to glob schema filename %s: %w", f, err) 171 } 172 } 173 174 for _, m := range matches { 175 if config.SchemaFilename.Has(m) { 176 continue 177 } 178 config.SchemaFilename = append(config.SchemaFilename, m) 179 } 180 } 181 182 for _, filename := range config.SchemaFilename { 183 filename = filepath.ToSlash(filename) 184 var err error 185 var schemaRaw []byte 186 schemaRaw, err = os.ReadFile(filename) 187 if err != nil { 188 return fmt.Errorf("unable to open schema: %w", err) 189 } 190 191 config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)}) 192 } 193 return nil 194 } 195 196 func (c *Config) Init() error { 197 if c.Packages == nil { 198 c.Packages = &code.Packages{} 199 } 200 201 if c.Schema == nil { 202 if err := c.LoadSchema(); err != nil { 203 return err 204 } 205 } 206 207 err := c.injectTypesFromSchema() 208 if err != nil { 209 return err 210 } 211 212 err = c.autobind() 213 if err != nil { 214 return err 215 } 216 217 c.injectBuiltins() 218 // prefetch all packages in one big packages.Load call 219 c.Packages.LoadAll(c.packageList()...) 220 221 // check everything is valid on the way out 222 err = c.check() 223 if err != nil { 224 return err 225 } 226 227 return nil 228 } 229 230 func (c *Config) packageList() []string { 231 pkgs := []string{ 232 "github.com/mstephano/gqlgen-schemagen/graphql", 233 "github.com/mstephano/gqlgen-schemagen/graphql/introspection", 234 } 235 pkgs = append(pkgs, c.Models.ReferencedPackages()...) 236 pkgs = append(pkgs, c.AutoBind...) 237 return pkgs 238 } 239 240 func (c *Config) ReloadAllPackages() { 241 c.Packages.ReloadAll(c.packageList()...) 242 } 243 244 func (c *Config) injectTypesFromSchema() error { 245 c.Directives["goModel"] = DirectiveConfig{ 246 SkipRuntime: true, 247 } 248 249 c.Directives["goField"] = DirectiveConfig{ 250 SkipRuntime: true, 251 } 252 253 c.Directives["goTag"] = DirectiveConfig{ 254 SkipRuntime: true, 255 } 256 257 for _, schemaType := range c.Schema.Types { 258 if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription { 259 continue 260 } 261 262 if bd := schemaType.Directives.ForName("goModel"); bd != nil { 263 if ma := bd.Arguments.ForName("model"); ma != nil { 264 if mv, err := ma.Value.Value(nil); err == nil { 265 c.Models.Add(schemaType.Name, mv.(string)) 266 } 267 } 268 if ma := bd.Arguments.ForName("models"); ma != nil { 269 if mvs, err := ma.Value.Value(nil); err == nil { 270 for _, mv := range mvs.([]interface{}) { 271 c.Models.Add(schemaType.Name, mv.(string)) 272 } 273 } 274 } 275 } 276 277 if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject { 278 for _, field := range schemaType.Fields { 279 if fd := field.Directives.ForName("goField"); fd != nil { 280 forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver 281 fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName 282 283 if ra := fd.Arguments.ForName("forceResolver"); ra != nil { 284 if fr, err := ra.Value.Value(nil); err == nil { 285 forceResolver = fr.(bool) 286 } 287 } 288 289 if na := fd.Arguments.ForName("name"); na != nil { 290 if fr, err := na.Value.Value(nil); err == nil { 291 fieldName = fr.(string) 292 } 293 } 294 295 if c.Models[schemaType.Name].Fields == nil { 296 c.Models[schemaType.Name] = TypeMapEntry{ 297 Model: c.Models[schemaType.Name].Model, 298 Fields: map[string]TypeMapField{}, 299 } 300 } 301 302 c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{ 303 FieldName: fieldName, 304 Resolver: forceResolver, 305 } 306 } 307 } 308 } 309 } 310 311 return nil 312 } 313 314 type TypeMapEntry struct { 315 Model StringList `yaml:"model"` 316 Fields map[string]TypeMapField `yaml:"fields,omitempty"` 317 } 318 319 type TypeMapField struct { 320 Resolver bool `yaml:"resolver"` 321 FieldName string `yaml:"fieldName"` 322 GeneratedMethod string `yaml:"-"` 323 } 324 325 type StringList []string 326 327 func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error { 328 var single string 329 err := unmarshal(&single) 330 if err == nil { 331 *a = []string{single} 332 return nil 333 } 334 335 var multi []string 336 err = unmarshal(&multi) 337 if err != nil { 338 return err 339 } 340 341 *a = multi 342 return nil 343 } 344 345 func (a StringList) Has(file string) bool { 346 for _, existing := range a { 347 if existing == file { 348 return true 349 } 350 } 351 return false 352 } 353 354 func (c *Config) check() error { 355 if c.Models == nil { 356 c.Models = TypeMap{} 357 } 358 359 type FilenamePackage struct { 360 Filename string 361 Package string 362 Declaree string 363 } 364 365 fileList := map[string][]FilenamePackage{} 366 367 if err := c.Models.Check(); err != nil { 368 return fmt.Errorf("config.models: %w", err) 369 } 370 if err := c.Exec.Check(); err != nil { 371 return fmt.Errorf("config.exec: %w", err) 372 } 373 fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{ 374 Filename: c.Exec.Filename, 375 Package: c.Exec.Package, 376 Declaree: "exec", 377 }) 378 379 if c.Model.IsDefined() { 380 if err := c.Model.Check(); err != nil { 381 return fmt.Errorf("config.model: %w", err) 382 } 383 fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{ 384 Filename: c.Model.Filename, 385 Package: c.Model.Package, 386 Declaree: "model", 387 }) 388 } 389 if c.Resolver.IsDefined() { 390 if err := c.Resolver.Check(); err != nil { 391 return fmt.Errorf("config.resolver: %w", err) 392 } 393 fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{ 394 Filename: c.Resolver.Filename, 395 Package: c.Resolver.Package, 396 Declaree: "resolver", 397 }) 398 } 399 if c.Federation.IsDefined() { 400 if err := c.Federation.Check(); err != nil { 401 return fmt.Errorf("config.federation: %w", err) 402 } 403 fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{ 404 Filename: c.Federation.Filename, 405 Package: c.Federation.Package, 406 Declaree: "federation", 407 }) 408 if c.Federation.ImportPath() != c.Exec.ImportPath() { 409 return fmt.Errorf("federation and exec must be in the same package") 410 } 411 } 412 if c.Federated { 413 return fmt.Errorf("federated has been removed, instead use\nfederation:\n filename: path/to/federated.go") 414 } 415 416 for importPath, pkg := range fileList { 417 for _, file1 := range pkg { 418 for _, file2 := range pkg { 419 if file1.Package != file2.Package { 420 return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)", 421 file1.Declaree, 422 file2.Declaree, 423 importPath, 424 file1.Package, 425 file2.Package, 426 ) 427 } 428 } 429 } 430 } 431 432 return nil 433 } 434 435 type TypeMap map[string]TypeMapEntry 436 437 func (tm TypeMap) Exists(typeName string) bool { 438 _, ok := tm[typeName] 439 return ok 440 } 441 442 func (tm TypeMap) UserDefined(typeName string) bool { 443 m, ok := tm[typeName] 444 return ok && len(m.Model) > 0 445 } 446 447 func (tm TypeMap) Check() error { 448 for typeName, entry := range tm { 449 for _, model := range entry.Model { 450 if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") { 451 return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model) 452 } 453 } 454 } 455 return nil 456 } 457 458 func (tm TypeMap) ReferencedPackages() []string { 459 var pkgs []string 460 461 for _, typ := range tm { 462 for _, model := range typ.Model { 463 if model == "map[string]interface{}" || model == "interface{}" { 464 continue 465 } 466 pkg, _ := code.PkgAndType(model) 467 if pkg == "" || inStrSlice(pkgs, pkg) { 468 continue 469 } 470 pkgs = append(pkgs, code.QualifyPackagePath(pkg)) 471 } 472 } 473 474 sort.Slice(pkgs, func(i, j int) bool { 475 return pkgs[i] > pkgs[j] 476 }) 477 return pkgs 478 } 479 480 func (tm TypeMap) Add(name string, goType string) { 481 modelCfg := tm[name] 482 modelCfg.Model = append(modelCfg.Model, goType) 483 tm[name] = modelCfg 484 } 485 486 type DirectiveConfig struct { 487 SkipRuntime bool `yaml:"skip_runtime"` 488 } 489 490 func inStrSlice(haystack []string, needle string) bool { 491 for _, v := range haystack { 492 if needle == v { 493 return true 494 } 495 } 496 497 return false 498 } 499 500 // findCfg searches for the config file in this directory and all parents up the tree 501 // looking for the closest match 502 func findCfg() (string, error) { 503 dir, err := os.Getwd() 504 if err != nil { 505 return "", fmt.Errorf("unable to get working dir to findCfg: %w", err) 506 } 507 508 cfg := findCfgInDir(dir) 509 510 for cfg == "" && dir != filepath.Dir(dir) { 511 dir = filepath.Dir(dir) 512 cfg = findCfgInDir(dir) 513 } 514 515 if cfg == "" { 516 return "", os.ErrNotExist 517 } 518 519 return cfg, nil 520 } 521 522 func findCfgInDir(dir string) string { 523 for _, cfgName := range cfgFilenames { 524 path := filepath.Join(dir, cfgName) 525 if _, err := os.Stat(path); err == nil { 526 return path 527 } 528 } 529 return "" 530 } 531 532 func (c *Config) autobind() error { 533 if len(c.AutoBind) == 0 { 534 return nil 535 } 536 537 ps := c.Packages.LoadAll(c.AutoBind...) 538 539 for _, t := range c.Schema.Types { 540 if c.Models.UserDefined(t.Name) { 541 continue 542 } 543 544 for i, p := range ps { 545 if p == nil || p.Module == nil { 546 return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i]) 547 } 548 if t := p.Types.Scope().Lookup(t.Name); t != nil { 549 c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name()) 550 break 551 } 552 } 553 } 554 555 for i, t := range c.Models { 556 for j, m := range t.Model { 557 pkg, typename := code.PkgAndType(m) 558 559 // skip anything that looks like an import path 560 if strings.Contains(pkg, "/") { 561 continue 562 } 563 564 for _, p := range ps { 565 if p.Name != pkg { 566 continue 567 } 568 if t := p.Types.Scope().Lookup(typename); t != nil { 569 c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name() 570 break 571 } 572 } 573 } 574 } 575 576 return nil 577 } 578 579 func (c *Config) injectBuiltins() { 580 builtins := TypeMap{ 581 "__Directive": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql/introspection.Directive"}}, 582 "__DirectiveLocation": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.String"}}, 583 "__Type": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql/introspection.Type"}}, 584 "__TypeKind": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.String"}}, 585 "__Field": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql/introspection.Field"}}, 586 "__EnumValue": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql/introspection.EnumValue"}}, 587 "__InputValue": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql/introspection.InputValue"}}, 588 "__Schema": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql/introspection.Schema"}}, 589 "Float": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.FloatContext"}}, 590 "String": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.String"}}, 591 "Boolean": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.Boolean"}}, 592 "Int": {Model: StringList{ 593 "github.com/mstephano/gqlgen-schemagen/graphql.Int", 594 "github.com/mstephano/gqlgen-schemagen/graphql.Int32", 595 "github.com/mstephano/gqlgen-schemagen/graphql.Int64", 596 }}, 597 "ID": { 598 Model: StringList{ 599 "github.com/mstephano/gqlgen-schemagen/graphql.ID", 600 "github.com/mstephano/gqlgen-schemagen/graphql.IntID", 601 }, 602 }, 603 } 604 605 for typeName, entry := range builtins { 606 if !c.Models.Exists(typeName) { 607 c.Models[typeName] = entry 608 } 609 } 610 611 // These are additional types that are injected if defined in the schema as scalars. 612 extraBuiltins := TypeMap{ 613 "Time": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.Time"}}, 614 "Map": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.Map"}}, 615 "Upload": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.Upload"}}, 616 "Any": {Model: StringList{"github.com/mstephano/gqlgen-schemagen/graphql.Any"}}, 617 } 618 619 for typeName, entry := range extraBuiltins { 620 if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar { 621 c.Models[typeName] = entry 622 } 623 } 624 } 625 626 func (c *Config) LoadSchema() error { 627 if c.Packages != nil { 628 c.Packages = &code.Packages{} 629 } 630 631 if err := c.check(); err != nil { 632 return err 633 } 634 635 schema, err := gqlparser.LoadSchema(c.Sources...) 636 if err != nil { 637 return err 638 } 639 640 if schema.Query == nil { 641 schema.Query = &ast.Definition{ 642 Kind: ast.Object, 643 Name: "Query", 644 } 645 schema.Types["Query"] = schema.Query 646 } 647 648 c.Schema = schema 649 return nil 650 } 651 652 func abs(path string) string { 653 absPath, err := filepath.Abs(path) 654 if err != nil { 655 panic(err) 656 } 657 return filepath.ToSlash(absPath) 658 }