github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/codegen/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"os"
     7  	"path/filepath"
     8  	"regexp"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/99designs/gqlgen/internal/code"
    13  	"github.com/pkg/errors"
    14  	"github.com/vektah/gqlparser/v2"
    15  	"github.com/vektah/gqlparser/v2/ast"
    16  	"gopkg.in/yaml.v2"
    17  )
    18  
// Config is the complete gqlgen configuration. The yaml-tagged fields are decoded from
// the config file by LoadConfig. The fields tagged `yaml:"-"` are never read from the
// file; they are filled in while loading: Sources by LoadConfig/LoadDefaultConfig,
// Schema by LoadSchema, and Packages by Init (or replaced by LoadSchema).
type Config struct {
	// SchemaFilename lists the schema files. Entries may be glob patterns, including "**"
	// for any number of directories; LoadConfig expands them to concrete file names.
	SchemaFilename           StringList                 `yaml:"schema,omitempty"`
	Exec                     PackageConfig              `yaml:"exec"`
	Model                    PackageConfig              `yaml:"model,omitempty"`
	Federation               PackageConfig              `yaml:"federation,omitempty"`
	Resolver                 ResolverConfig             `yaml:"resolver,omitempty"`
	AutoBind                 []string                   `yaml:"autobind"`
	Models                   TypeMap                    `yaml:"models,omitempty"`
	StructTag                string                     `yaml:"struct_tag,omitempty"`
	Directives               map[string]DirectiveConfig `yaml:"directives,omitempty"`
	OmitSliceElementPointers bool                       `yaml:"omit_slice_element_pointers,omitempty"`
	SkipValidation           bool                       `yaml:"skip_validation,omitempty"`
	// Sources holds the name and contents of every schema file that was read.
	Sources                  []*ast.Source              `yaml:"-"`
	Packages                 *code.Packages             `yaml:"-"`
	Schema                   *ast.Schema                `yaml:"-"`

	// Deprecated: use Federation instead. Will be removed next release.
	// check rejects a config that still sets it.
	Federated bool `yaml:"federated,omitempty"`
}
    38  
    39  var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
    40  
    41  // DefaultConfig creates a copy of the default config
    42  func DefaultConfig() *Config {
    43  	return &Config{
    44  		SchemaFilename: StringList{"schema.graphql"},
    45  		Model:          PackageConfig{Filename: "models_gen.go"},
    46  		Exec:           PackageConfig{Filename: "generated.go"},
    47  		Directives:     map[string]DirectiveConfig{},
    48  		Models:         TypeMap{},
    49  	}
    50  }
    51  
    52  // LoadDefaultConfig loads the default config so that it is ready to be used
    53  func LoadDefaultConfig() (*Config, error) {
    54  	config := DefaultConfig()
    55  
    56  	for _, filename := range config.SchemaFilename {
    57  		filename = filepath.ToSlash(filename)
    58  		var err error
    59  		var schemaRaw []byte
    60  		schemaRaw, err = ioutil.ReadFile(filename)
    61  		if err != nil {
    62  			return nil, errors.Wrap(err, "unable to open schema")
    63  		}
    64  
    65  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
    66  	}
    67  
    68  	return config, nil
    69  }
    70  
    71  // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
    72  // walking up the tree. The closest config file will be returned.
    73  func LoadConfigFromDefaultLocations() (*Config, error) {
    74  	cfgFile, err := findCfg()
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	err = os.Chdir(filepath.Dir(cfgFile))
    80  	if err != nil {
    81  		return nil, errors.Wrap(err, "unable to enter config dir")
    82  	}
    83  	return LoadConfig(cfgFile)
    84  }
    85  
    86  var path2regex = strings.NewReplacer(
    87  	`.`, `\.`,
    88  	`*`, `.+`,
    89  	`\`, `[\\/]`,
    90  	`/`, `[\\/]`,
    91  )
    92  
    93  // LoadConfig reads the gqlgen.yml config file
    94  func LoadConfig(filename string) (*Config, error) {
    95  	config := DefaultConfig()
    96  
    97  	b, err := ioutil.ReadFile(filename)
    98  	if err != nil {
    99  		return nil, errors.Wrap(err, "unable to read config")
   100  	}
   101  
   102  	if err := yaml.UnmarshalStrict(b, config); err != nil {
   103  		return nil, errors.Wrap(err, "unable to parse config")
   104  	}
   105  
   106  	defaultDirectives := map[string]DirectiveConfig{
   107  		"skip":       {SkipRuntime: true},
   108  		"include":    {SkipRuntime: true},
   109  		"deprecated": {SkipRuntime: true},
   110  	}
   111  
   112  	for key, value := range defaultDirectives {
   113  		if _, defined := config.Directives[key]; !defined {
   114  			config.Directives[key] = value
   115  		}
   116  	}
   117  
   118  	preGlobbing := config.SchemaFilename
   119  	config.SchemaFilename = StringList{}
   120  	for _, f := range preGlobbing {
   121  		var matches []string
   122  
   123  		// for ** we want to override default globbing patterns and walk all
   124  		// subdirectories to match schema files.
   125  		if strings.Contains(f, "**") {
   126  			pathParts := strings.SplitN(f, "**", 2)
   127  			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
   128  			// turn the rest of the glob into a regex, anchored only at the end because ** allows
   129  			// for any number of dirs in between and walk will let us match against the full path name
   130  			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
   131  
   132  			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
   133  				if err != nil {
   134  					return err
   135  				}
   136  
   137  				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
   138  					matches = append(matches, path)
   139  				}
   140  
   141  				return nil
   142  			}); err != nil {
   143  				return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0])
   144  			}
   145  		} else {
   146  			matches, err = filepath.Glob(f)
   147  			if err != nil {
   148  				return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
   149  			}
   150  		}
   151  
   152  		for _, m := range matches {
   153  			if config.SchemaFilename.Has(m) {
   154  				continue
   155  			}
   156  			config.SchemaFilename = append(config.SchemaFilename, m)
   157  		}
   158  	}
   159  
   160  	for _, filename := range config.SchemaFilename {
   161  		filename = filepath.ToSlash(filename)
   162  		var err error
   163  		var schemaRaw []byte
   164  		schemaRaw, err = ioutil.ReadFile(filename)
   165  		if err != nil {
   166  			return nil, errors.Wrap(err, "unable to open schema")
   167  		}
   168  
   169  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
   170  	}
   171  
   172  	return config, nil
   173  }
   174  
   175  func (c *Config) Init() error {
   176  	if c.Packages == nil {
   177  		c.Packages = &code.Packages{}
   178  	}
   179  
   180  	if c.Schema == nil {
   181  		if err := c.LoadSchema(); err != nil {
   182  			return err
   183  		}
   184  	}
   185  
   186  	err := c.injectTypesFromSchema()
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	err = c.autobind()
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	c.injectBuiltins()
   197  
   198  	// prefetch all packages in one big packages.Load call
   199  	pkgs := []string{
   200  		"github.com/99designs/gqlgen/graphql",
   201  		"github.com/99designs/gqlgen/graphql/introspection",
   202  	}
   203  	pkgs = append(pkgs, c.Models.ReferencedPackages()...)
   204  	pkgs = append(pkgs, c.AutoBind...)
   205  	c.Packages.LoadAll(pkgs...)
   206  
   207  	//  check everything is valid on the way out
   208  	err = c.check()
   209  	if err != nil {
   210  		return err
   211  	}
   212  
   213  	return nil
   214  }
   215  
   216  func (c *Config) injectTypesFromSchema() error {
   217  	c.Directives["goModel"] = DirectiveConfig{
   218  		SkipRuntime: true,
   219  	}
   220  
   221  	c.Directives["goField"] = DirectiveConfig{
   222  		SkipRuntime: true,
   223  	}
   224  
   225  	for _, schemaType := range c.Schema.Types {
   226  		if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription {
   227  			continue
   228  		}
   229  
   230  		if bd := schemaType.Directives.ForName("goModel"); bd != nil {
   231  			if ma := bd.Arguments.ForName("model"); ma != nil {
   232  				if mv, err := ma.Value.Value(nil); err == nil {
   233  					c.Models.Add(schemaType.Name, mv.(string))
   234  				}
   235  			}
   236  			if ma := bd.Arguments.ForName("models"); ma != nil {
   237  				if mvs, err := ma.Value.Value(nil); err == nil {
   238  					for _, mv := range mvs.([]interface{}) {
   239  						c.Models.Add(schemaType.Name, mv.(string))
   240  					}
   241  				}
   242  			}
   243  		}
   244  
   245  		if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
   246  			for _, field := range schemaType.Fields {
   247  				if fd := field.Directives.ForName("goField"); fd != nil {
   248  					forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
   249  					fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
   250  
   251  					if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
   252  						if fr, err := ra.Value.Value(nil); err == nil {
   253  							forceResolver = fr.(bool)
   254  						}
   255  					}
   256  
   257  					if na := fd.Arguments.ForName("name"); na != nil {
   258  						if fr, err := na.Value.Value(nil); err == nil {
   259  							fieldName = fr.(string)
   260  						}
   261  					}
   262  
   263  					if c.Models[schemaType.Name].Fields == nil {
   264  						c.Models[schemaType.Name] = TypeMapEntry{
   265  							Model:  c.Models[schemaType.Name].Model,
   266  							Fields: map[string]TypeMapField{},
   267  						}
   268  					}
   269  
   270  					c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
   271  						FieldName: fieldName,
   272  						Resolver:  forceResolver,
   273  					}
   274  				}
   275  			}
   276  		}
   277  	}
   278  
   279  	return nil
   280  }
   281  
// TypeMapEntry is the configuration of one GraphQL type in the `models:` section.
type TypeMapEntry struct {
	// Model lists the Go types the GraphQL type may be bound to, written as
	// "import/path.TypeName" (TypeMap.Check rejects a bare import path).
	Model  StringList              `yaml:"model"`
	// Fields holds per-field settings keyed by GraphQL field name.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}

// TypeMapField holds the settings for one field of a type in the `models:` section.
// injectTypesFromSchema also fills it from the @goField directive.
type TypeMapField struct {
	// Resolver is the forceResolver flag; presumably it requests a generated resolver for
	// the field even when a Go field could be bound — confirm in codegen.
	Resolver        bool   `yaml:"resolver"`
	// FieldName is set from @goField(name:); presumably the Go field/method name to bind
	// to instead of the default — confirm in codegen.
	FieldName       string `yaml:"fieldName"`
	// GeneratedMethod is never read from yaml and is not set anywhere in this file.
	GeneratedMethod string `yaml:"-"`
}
   292  
   293  type StringList []string
   294  
   295  func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
   296  	var single string
   297  	err := unmarshal(&single)
   298  	if err == nil {
   299  		*a = []string{single}
   300  		return nil
   301  	}
   302  
   303  	var multi []string
   304  	err = unmarshal(&multi)
   305  	if err != nil {
   306  		return err
   307  	}
   308  
   309  	*a = multi
   310  	return nil
   311  }
   312  
   313  func (a StringList) Has(file string) bool {
   314  	for _, existing := range a {
   315  		if existing == file {
   316  			return true
   317  		}
   318  	}
   319  	return false
   320  }
   321  
   322  func (c *Config) check() error {
   323  	if c.Models == nil {
   324  		c.Models = TypeMap{}
   325  	}
   326  
   327  	type FilenamePackage struct {
   328  		Filename string
   329  		Package  string
   330  		Declaree string
   331  	}
   332  
   333  	fileList := map[string][]FilenamePackage{}
   334  
   335  	if err := c.Models.Check(); err != nil {
   336  		return errors.Wrap(err, "config.models")
   337  	}
   338  	if err := c.Exec.Check(); err != nil {
   339  		return errors.Wrap(err, "config.exec")
   340  	}
   341  	fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
   342  		Filename: c.Exec.Filename,
   343  		Package:  c.Exec.Package,
   344  		Declaree: "exec",
   345  	})
   346  
   347  	if c.Model.IsDefined() {
   348  		if err := c.Model.Check(); err != nil {
   349  			return errors.Wrap(err, "config.model")
   350  		}
   351  		fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
   352  			Filename: c.Model.Filename,
   353  			Package:  c.Model.Package,
   354  			Declaree: "model",
   355  		})
   356  	}
   357  	if c.Resolver.IsDefined() {
   358  		if err := c.Resolver.Check(); err != nil {
   359  			return errors.Wrap(err, "config.resolver")
   360  		}
   361  		fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
   362  			Filename: c.Resolver.Filename,
   363  			Package:  c.Resolver.Package,
   364  			Declaree: "resolver",
   365  		})
   366  	}
   367  	if c.Federation.IsDefined() {
   368  		if err := c.Federation.Check(); err != nil {
   369  			return errors.Wrap(err, "config.federation")
   370  		}
   371  		fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
   372  			Filename: c.Federation.Filename,
   373  			Package:  c.Federation.Package,
   374  			Declaree: "federation",
   375  		})
   376  		if c.Federation.ImportPath() != c.Exec.ImportPath() {
   377  			return fmt.Errorf("federation and exec must be in the same package")
   378  		}
   379  	}
   380  	if c.Federated {
   381  		return fmt.Errorf("federated has been removed, instead use\nfederation:\n    filename: path/to/federated.go")
   382  	}
   383  
   384  	for importPath, pkg := range fileList {
   385  		for _, file1 := range pkg {
   386  			for _, file2 := range pkg {
   387  				if file1.Package != file2.Package {
   388  					return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
   389  						file1.Declaree,
   390  						file2.Declaree,
   391  						importPath,
   392  						file1.Package,
   393  						file2.Package,
   394  					)
   395  				}
   396  			}
   397  		}
   398  	}
   399  
   400  	return nil
   401  }
   402  
   403  type TypeMap map[string]TypeMapEntry
   404  
   405  func (tm TypeMap) Exists(typeName string) bool {
   406  	_, ok := tm[typeName]
   407  	return ok
   408  }
   409  
   410  func (tm TypeMap) UserDefined(typeName string) bool {
   411  	m, ok := tm[typeName]
   412  	return ok && len(m.Model) > 0
   413  }
   414  
   415  func (tm TypeMap) Check() error {
   416  	for typeName, entry := range tm {
   417  		for _, model := range entry.Model {
   418  			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
   419  				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
   420  			}
   421  		}
   422  	}
   423  	return nil
   424  }
   425  
   426  func (tm TypeMap) ReferencedPackages() []string {
   427  	var pkgs []string
   428  
   429  	for _, typ := range tm {
   430  		for _, model := range typ.Model {
   431  			if model == "map[string]interface{}" || model == "interface{}" {
   432  				continue
   433  			}
   434  			pkg, _ := code.PkgAndType(model)
   435  			if pkg == "" || inStrSlice(pkgs, pkg) {
   436  				continue
   437  			}
   438  			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
   439  		}
   440  	}
   441  
   442  	sort.Slice(pkgs, func(i, j int) bool {
   443  		return pkgs[i] > pkgs[j]
   444  	})
   445  	return pkgs
   446  }
   447  
   448  func (tm TypeMap) Add(name string, goType string) {
   449  	modelCfg := tm[name]
   450  	modelCfg.Model = append(modelCfg.Model, goType)
   451  	tm[name] = modelCfg
   452  }
   453  
// DirectiveConfig is the configuration of one schema directive in the `directives:` section.
type DirectiveConfig struct {
	// SkipRuntime is set by default for skip, include and deprecated in LoadConfig (unless
	// configured) and for goModel/goField in injectTypesFromSchema. Presumably it means no
	// runtime directive handler is generated — confirm in codegen.
	SkipRuntime bool `yaml:"skip_runtime"`
}
   457  
   458  func inStrSlice(haystack []string, needle string) bool {
   459  	for _, v := range haystack {
   460  		if needle == v {
   461  			return true
   462  		}
   463  	}
   464  
   465  	return false
   466  }
   467  
   468  // findCfg searches for the config file in this directory and all parents up the tree
   469  // looking for the closest match
   470  func findCfg() (string, error) {
   471  	dir, err := os.Getwd()
   472  	if err != nil {
   473  		return "", errors.Wrap(err, "unable to get working dir to findCfg")
   474  	}
   475  
   476  	cfg := findCfgInDir(dir)
   477  
   478  	for cfg == "" && dir != filepath.Dir(dir) {
   479  		dir = filepath.Dir(dir)
   480  		cfg = findCfgInDir(dir)
   481  	}
   482  
   483  	if cfg == "" {
   484  		return "", os.ErrNotExist
   485  	}
   486  
   487  	return cfg, nil
   488  }
   489  
   490  func findCfgInDir(dir string) string {
   491  	for _, cfgName := range cfgFilenames {
   492  		path := filepath.Join(dir, cfgName)
   493  		if _, err := os.Stat(path); err == nil {
   494  			return path
   495  		}
   496  	}
   497  	return ""
   498  }
   499  
   500  func (c *Config) autobind() error {
   501  	if len(c.AutoBind) == 0 {
   502  		return nil
   503  	}
   504  
   505  	ps := c.Packages.LoadAll(c.AutoBind...)
   506  
   507  	for _, t := range c.Schema.Types {
   508  		if c.Models.UserDefined(t.Name) {
   509  			continue
   510  		}
   511  
   512  		for i, p := range ps {
   513  			if p == nil {
   514  				return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
   515  			}
   516  			if t := p.Types.Scope().Lookup(t.Name); t != nil {
   517  				c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
   518  				break
   519  			}
   520  		}
   521  	}
   522  
   523  	for i, t := range c.Models {
   524  		for j, m := range t.Model {
   525  			pkg, typename := code.PkgAndType(m)
   526  
   527  			// skip anything that looks like an import path
   528  			if strings.Contains(pkg, "/") {
   529  				continue
   530  			}
   531  
   532  			for _, p := range ps {
   533  				if p.Name != pkg {
   534  					continue
   535  				}
   536  				if t := p.Types.Scope().Lookup(typename); t != nil {
   537  					c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
   538  					break
   539  				}
   540  			}
   541  		}
   542  	}
   543  
   544  	return nil
   545  }
   546  
   547  func (c *Config) injectBuiltins() {
   548  	builtins := TypeMap{
   549  		"__Directive":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
   550  		"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   551  		"__Type":              {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
   552  		"__TypeKind":          {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   553  		"__Field":             {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
   554  		"__EnumValue":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
   555  		"__InputValue":        {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
   556  		"__Schema":            {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
   557  		"Float":               {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
   558  		"String":              {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   559  		"Boolean":             {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
   560  		"Int": {Model: StringList{
   561  			"github.com/99designs/gqlgen/graphql.Int",
   562  			"github.com/99designs/gqlgen/graphql.Int32",
   563  			"github.com/99designs/gqlgen/graphql.Int64",
   564  		}},
   565  		"ID": {
   566  			Model: StringList{
   567  				"github.com/99designs/gqlgen/graphql.ID",
   568  				"github.com/99designs/gqlgen/graphql.IntID",
   569  			},
   570  		},
   571  	}
   572  
   573  	for typeName, entry := range builtins {
   574  		if !c.Models.Exists(typeName) {
   575  			c.Models[typeName] = entry
   576  		}
   577  	}
   578  
   579  	// These are additional types that are injected if defined in the schema as scalars.
   580  	extraBuiltins := TypeMap{
   581  		"Time":   {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
   582  		"Map":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
   583  		"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
   584  		"Any":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
   585  	}
   586  
   587  	for typeName, entry := range extraBuiltins {
   588  		if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
   589  			c.Models[typeName] = entry
   590  		}
   591  	}
   592  }
   593  
   594  func (c *Config) LoadSchema() error {
   595  	if c.Packages != nil {
   596  		c.Packages = &code.Packages{}
   597  	}
   598  
   599  	if err := c.check(); err != nil {
   600  		return err
   601  	}
   602  
   603  	schema, err := gqlparser.LoadSchema(c.Sources...)
   604  	if err != nil {
   605  		return err
   606  	}
   607  
   608  	if schema.Query == nil {
   609  		schema.Query = &ast.Definition{
   610  			Kind: ast.Object,
   611  			Name: "Query",
   612  		}
   613  		schema.Types["Query"] = schema.Query
   614  	}
   615  
   616  	c.Schema = schema
   617  	return nil
   618  }
   619  
   620  func abs(path string) string {
   621  	absPath, err := filepath.Abs(path)
   622  	if err != nil {
   623  		panic(err)
   624  	}
   625  	return filepath.ToSlash(absPath)
   626  }