git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/codegen/config/config.go (about) 1 package config 2 3 import ( 4 "fmt" 5 "io/ioutil" 6 "os" 7 "path/filepath" 8 "regexp" 9 "sort" 10 "strings" 11 12 "git.sr.ht/~sircmpwn/gqlgen/internal/code" 13 "github.com/pkg/errors" 14 "github.com/vektah/gqlparser/v2" 15 "github.com/vektah/gqlparser/v2/ast" 16 "gopkg.in/yaml.v2" 17 ) 18 19 type Config struct { 20 SchemaFilename StringList `yaml:"schema,omitempty"` 21 Exec PackageConfig `yaml:"exec"` 22 Model PackageConfig `yaml:"model,omitempty"` 23 Federation PackageConfig `yaml:"federation,omitempty"` 24 Resolver ResolverConfig `yaml:"resolver,omitempty"` 25 AutoBind []string `yaml:"autobind"` 26 Models TypeMap `yaml:"models,omitempty"` 27 StructTag string `yaml:"struct_tag,omitempty"` 28 Directives map[string]DirectiveConfig `yaml:"directives,omitempty"` 29 OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"` 30 SkipValidation bool `yaml:"skip_validation,omitempty"` 31 Sources []*ast.Source `yaml:"-"` 32 Packages *code.Packages `yaml:"-"` 33 Schema *ast.Schema `yaml:"-"` 34 35 // Deprecated use Federation instead. Will be removed next release 36 Federated bool `yaml:"federated,omitempty"` 37 } 38 39 var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"} 40 41 // DefaultConfig creates a copy of the default config 42 func DefaultConfig() *Config { 43 return &Config{ 44 SchemaFilename: StringList{"schema.graphql"}, 45 Model: PackageConfig{Filename: "models_gen.go"}, 46 Exec: PackageConfig{Filename: "generated.go"}, 47 Directives: map[string]DirectiveConfig{}, 48 Models: TypeMap{}, 49 } 50 } 51 52 // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories 53 // walking up the tree. The closest config file will be returned. 54 func LoadConfigFromDefaultLocations() (*Config, error) { 55 cfgFile, err := findCfg() 56 if err != nil { 57 return nil, err 58 } 59 60 err = os.Chdir(filepath.Dir(cfgFile)) 61 if err != nil { 62 return nil, errors.Wrap(err, "unable to enter config dir") 63 } 64 return LoadConfig(cfgFile) 65 } 66 67 var path2regex = strings.NewReplacer( 68 `.`, `\.`, 69 `*`, `.+`, 70 `\`, `[\\/]`, 71 `/`, `[\\/]`, 72 ) 73 74 // LoadConfig reads the gqlgen.yml config file 75 func LoadConfig(filename string) (*Config, error) { 76 config := DefaultConfig() 77 78 b, err := ioutil.ReadFile(filename) 79 if err != nil { 80 return nil, errors.Wrap(err, "unable to read config") 81 } 82 83 if err := yaml.UnmarshalStrict(b, config); err != nil { 84 return nil, errors.Wrap(err, "unable to parse config") 85 } 86 87 defaultDirectives := map[string]DirectiveConfig{ 88 "skip": {SkipRuntime: true}, 89 "include": {SkipRuntime: true}, 90 "deprecated": {SkipRuntime: true}, 91 } 92 93 for key, value := range defaultDirectives { 94 if _, defined := config.Directives[key]; !defined { 95 config.Directives[key] = value 96 } 97 } 98 99 preGlobbing := config.SchemaFilename 100 config.SchemaFilename = StringList{} 101 for _, f := range preGlobbing { 102 var matches []string 103 104 // for ** we want to override default globbing patterns and walk all 105 // subdirectories to match schema files. 106 if strings.Contains(f, "**") { 107 pathParts := strings.SplitN(f, "**", 2) 108 rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`) 109 // turn the rest of the glob into a regex, anchored only at the end because ** allows 110 // for any number of dirs in between and walk will let us match against the full path name 111 globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`) 112 113 if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error { 114 if err != nil { 115 return err 116 } 117 118 if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) { 119 matches = append(matches, path) 120 } 121 122 return nil 123 }); err != nil { 124 return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0]) 125 } 126 } else { 127 matches, err = filepath.Glob(f) 128 if err != nil { 129 return nil, errors.Wrapf(err, "failed to glob schema filename %s", f) 130 } 131 } 132 133 for _, m := range matches { 134 if config.SchemaFilename.Has(m) { 135 continue 136 } 137 config.SchemaFilename = append(config.SchemaFilename, m) 138 } 139 } 140 141 for _, filename := range config.SchemaFilename { 142 filename = filepath.ToSlash(filename) 143 var err error 144 var schemaRaw []byte 145 schemaRaw, err = ioutil.ReadFile(filename) 146 if err != nil { 147 return nil, errors.Wrap(err, "unable to open schema") 148 } 149 150 config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)}) 151 } 152 153 return config, nil 154 } 155 156 func (c *Config) Init() error { 157 if c.Packages == nil { 158 c.Packages = &code.Packages{} 159 } 160 161 if c.Schema == nil { 162 if err := c.LoadSchema(); err != nil { 163 return err 164 } 165 } 166 167 err := c.injectTypesFromSchema() 168 if err != nil { 169 return err 170 } 171 172 err = c.autobind() 173 if err != nil { 174 return err 175 } 176 177 c.injectBuiltins() 178 179 // prefetch all packages in one big packages.Load call 180 pkgs := []string{ 181 "git.sr.ht/~sircmpwn/gqlgen/graphql", 182 "git.sr.ht/~sircmpwn/gqlgen/graphql/introspection", 183 } 184 pkgs = append(pkgs, c.Models.ReferencedPackages()...) 185 pkgs = append(pkgs, c.AutoBind...) 186 c.Packages.LoadAll(pkgs...) 187 188 // check everything is valid on the way out 189 err = c.check() 190 if err != nil { 191 return err 192 } 193 194 return nil 195 } 196 197 func (c *Config) injectTypesFromSchema() error { 198 c.Directives["goModel"] = DirectiveConfig{ 199 SkipRuntime: true, 200 } 201 202 c.Directives["goField"] = DirectiveConfig{ 203 SkipRuntime: true, 204 } 205 206 for _, schemaType := range c.Schema.Types { 207 if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription { 208 continue 209 } 210 211 if bd := schemaType.Directives.ForName("goModel"); bd != nil { 212 if ma := bd.Arguments.ForName("model"); ma != nil { 213 if mv, err := ma.Value.Value(nil); err == nil { 214 c.Models.Add(schemaType.Name, mv.(string)) 215 } 216 } 217 if ma := bd.Arguments.ForName("models"); ma != nil { 218 if mvs, err := ma.Value.Value(nil); err == nil { 219 for _, mv := range mvs.([]interface{}) { 220 c.Models.Add(schemaType.Name, mv.(string)) 221 } 222 } 223 } 224 } 225 226 if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject { 227 for _, field := range schemaType.Fields { 228 if fd := field.Directives.ForName("goField"); fd != nil { 229 forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver 230 fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName 231 232 if ra := fd.Arguments.ForName("forceResolver"); ra != nil { 233 if fr, err := ra.Value.Value(nil); err == nil { 234 forceResolver = fr.(bool) 235 } 236 } 237 238 if na := fd.Arguments.ForName("name"); na != nil { 239 if fr, err := na.Value.Value(nil); err == nil { 240 fieldName = fr.(string) 241 } 242 } 243 244 if c.Models[schemaType.Name].Fields == nil { 245 c.Models[schemaType.Name] = TypeMapEntry{ 246 Model: c.Models[schemaType.Name].Model, 247 Fields: map[string]TypeMapField{}, 248 } 249 } 250 251 c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{ 252 FieldName: fieldName, 253 Resolver: forceResolver, 254 } 255 } 256 } 257 } 258 } 259 260 return nil 261 } 262 263 type TypeMapEntry struct { 264 Model StringList `yaml:"model"` 265 Fields map[string]TypeMapField `yaml:"fields,omitempty"` 266 } 267 268 type TypeMapField struct { 269 Resolver bool `yaml:"resolver"` 270 FieldName string `yaml:"fieldName"` 271 GeneratedMethod string `yaml:"-"` 272 } 273 274 type StringList []string 275 276 func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error { 277 var single string 278 err := unmarshal(&single) 279 if err == nil { 280 *a = []string{single} 281 return nil 282 } 283 284 var multi []string 285 err = unmarshal(&multi) 286 if err != nil { 287 return err 288 } 289 290 *a = multi 291 return nil 292 } 293 294 func (a StringList) Has(file string) bool { 295 for _, existing := range a { 296 if existing == file { 297 return true 298 } 299 } 300 return false 301 } 302 303 func (c *Config) check() error { 304 if c.Models == nil { 305 c.Models = TypeMap{} 306 } 307 308 type FilenamePackage struct { 309 Filename string 310 Package string 311 Declaree string 312 } 313 314 fileList := map[string][]FilenamePackage{} 315 316 if err := c.Models.Check(); err != nil { 317 return errors.Wrap(err, "config.models") 318 } 319 if err := c.Exec.Check(); err != nil { 320 return errors.Wrap(err, "config.exec") 321 } 322 fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{ 323 Filename: c.Exec.Filename, 324 Package: c.Exec.Package, 325 Declaree: "exec", 326 }) 327 328 if c.Model.IsDefined() { 329 if err := c.Model.Check(); err != nil { 330 return errors.Wrap(err, "config.model") 331 } 332 fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{ 333 Filename: c.Model.Filename, 334 Package: c.Model.Package, 335 Declaree: "model", 336 }) 337 } 338 if c.Resolver.IsDefined() { 339 if err := c.Resolver.Check(); err != nil { 340 return errors.Wrap(err, "config.resolver") 341 } 342 fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{ 343 Filename: c.Resolver.Filename, 344 Package: c.Resolver.Package, 345 Declaree: "resolver", 346 }) 347 } 348 if c.Federation.IsDefined() { 349 if err := c.Federation.Check(); err != nil { 350 return errors.Wrap(err, "config.federation") 351 } 352 fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{ 353 Filename: c.Federation.Filename, 354 Package: c.Federation.Package, 355 Declaree: "federation", 356 }) 357 if c.Federation.ImportPath() != c.Exec.ImportPath() { 358 return fmt.Errorf("federation and exec must be in the same package") 359 } 360 } 361 if c.Federated { 362 return fmt.Errorf("federated has been removed, instead use\nfederation:\n filename: path/to/federated.go") 363 } 364 365 for importPath, pkg := range fileList { 366 for _, file1 := range pkg { 367 for _, file2 := range pkg { 368 if file1.Package != file2.Package { 369 return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)", 370 file1.Declaree, 371 file2.Declaree, 372 importPath, 373 file1.Package, 374 file2.Package, 375 ) 376 } 377 } 378 } 379 } 380 381 return nil 382 } 383 384 type TypeMap map[string]TypeMapEntry 385 386 func (tm TypeMap) Exists(typeName string) bool { 387 _, ok := tm[typeName] 388 return ok 389 } 390 391 func (tm TypeMap) UserDefined(typeName string) bool { 392 m, ok := tm[typeName] 393 return ok && len(m.Model) > 0 394 } 395 396 func (tm TypeMap) Check() error { 397 for typeName, entry := range tm { 398 for _, model := range entry.Model { 399 if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") { 400 return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model) 401 } 402 } 403 } 404 return nil 405 } 406 407 func (tm TypeMap) ReferencedPackages() []string { 408 var pkgs []string 409 410 for _, typ := range tm { 411 for _, model := range typ.Model { 412 if model == "map[string]interface{}" || model == "interface{}" { 413 continue 414 } 415 pkg, _ := code.PkgAndType(model) 416 if pkg == "" || inStrSlice(pkgs, pkg) { 417 continue 418 } 419 pkgs = append(pkgs, code.QualifyPackagePath(pkg)) 420 } 421 } 422 423 sort.Slice(pkgs, func(i, j int) bool { 424 return pkgs[i] > pkgs[j] 425 }) 426 return pkgs 427 } 428 429 func (tm TypeMap) Add(name string, goType string) { 430 modelCfg := tm[name] 431 modelCfg.Model = append(modelCfg.Model, goType) 432 tm[name] = modelCfg 433 } 434 435 type DirectiveConfig struct { 436 SkipRuntime bool `yaml:"skip_runtime"` 437 } 438 439 func inStrSlice(haystack []string, needle string) bool { 440 for _, v := range haystack { 441 if needle == v { 442 return true 443 } 444 } 445 446 return false 447 } 448 449 // findCfg searches for the config file in this directory and all parents up the tree 450 // looking for the closest match 451 func findCfg() (string, error) { 452 dir, err := os.Getwd() 453 if err != nil { 454 return "", errors.Wrap(err, "unable to get working dir to findCfg") 455 } 456 457 cfg := findCfgInDir(dir) 458 459 for cfg == "" && dir != filepath.Dir(dir) { 460 dir = filepath.Dir(dir) 461 cfg = findCfgInDir(dir) 462 } 463 464 if cfg == "" { 465 return "", os.ErrNotExist 466 } 467 468 return cfg, nil 469 } 470 471 func findCfgInDir(dir string) string { 472 for _, cfgName := range cfgFilenames { 473 path := filepath.Join(dir, cfgName) 474 if _, err := os.Stat(path); err == nil { 475 return path 476 } 477 } 478 return "" 479 } 480 481 func (c *Config) autobind() error { 482 if len(c.AutoBind) == 0 { 483 return nil 484 } 485 486 ps := c.Packages.LoadAll(c.AutoBind...) 487 488 for _, t := range c.Schema.Types { 489 if c.Models.UserDefined(t.Name) { 490 continue 491 } 492 493 for i, p := range ps { 494 if p == nil { 495 return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i]) 496 } 497 if t := p.Types.Scope().Lookup(t.Name); t != nil { 498 c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name()) 499 break 500 } 501 } 502 } 503 504 for i, t := range c.Models { 505 for j, m := range t.Model { 506 pkg, typename := code.PkgAndType(m) 507 508 // skip anything that looks like an import path 509 if strings.Contains(pkg, "/") { 510 continue 511 } 512 513 for _, p := range ps { 514 if p.Name != pkg { 515 continue 516 } 517 if t := p.Types.Scope().Lookup(typename); t != nil { 518 c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name() 519 break 520 } 521 } 522 } 523 } 524 525 return nil 526 } 527 528 func (c *Config) injectBuiltins() { 529 builtins := TypeMap{ 530 "__Directive": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Directive"}}, 531 "__DirectiveLocation": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.String"}}, 532 "__Type": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Type"}}, 533 "__TypeKind": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.String"}}, 534 "__Field": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Field"}}, 535 "__EnumValue": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.EnumValue"}}, 536 "__InputValue": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.InputValue"}}, 537 "__Schema": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Schema"}}, 538 "Float": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Float"}}, 539 "String": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.String"}}, 540 "Boolean": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Boolean"}}, 541 "Int": {Model: StringList{ 542 "git.sr.ht/~sircmpwn/gqlgen/graphql.Int", 543 "git.sr.ht/~sircmpwn/gqlgen/graphql.Int32", 544 "git.sr.ht/~sircmpwn/gqlgen/graphql.Int64", 545 }}, 546 "ID": { 547 Model: StringList{ 548 "git.sr.ht/~sircmpwn/gqlgen/graphql.ID", 549 "git.sr.ht/~sircmpwn/gqlgen/graphql.IntID", 550 }, 551 }, 552 } 553 554 for typeName, entry := range builtins { 555 if !c.Models.Exists(typeName) { 556 c.Models[typeName] = entry 557 } 558 } 559 560 // These are additional types that are injected if defined in the schema as scalars. 561 extraBuiltins := TypeMap{ 562 "Time": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Time"}}, 563 "Map": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Map"}}, 564 "Upload": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Upload"}}, 565 "Any": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Any"}}, 566 } 567 568 for typeName, entry := range extraBuiltins { 569 if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar { 570 c.Models[typeName] = entry 571 } 572 } 573 } 574 575 func (c *Config) LoadSchema() error { 576 if c.Packages != nil { 577 c.Packages = &code.Packages{} 578 } 579 580 if err := c.check(); err != nil { 581 return err 582 } 583 584 schema, err := gqlparser.LoadSchema(c.Sources...) 585 if err != nil { 586 return err 587 } 588 589 if schema.Query == nil { 590 schema.Query = &ast.Definition{ 591 Kind: ast.Object, 592 Name: "Query", 593 } 594 schema.Types["Query"] = schema.Query 595 } 596 597 c.Schema = schema 598 return nil 599 } 600 601 func abs(path string) string { 602 absPath, err := filepath.Abs(path) 603 if err != nil { 604 panic(err) 605 } 606 return filepath.ToSlash(absPath) 607 }