github.com/apipluspower/gqlgen@v0.15.2/codegen/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"os"
     7  	"path/filepath"
     8  	"regexp"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/apipluspower/gqlgen/internal/code"
    13  	"github.com/vektah/gqlparser/v2"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  	"gopkg.in/yaml.v2"
    16  )
    17  
// Config is the gqlgen configuration. LoadConfig reads it from a YAML file and
// Init finishes preparing it (schema, type bindings and package cache).
//
// Fields read from YAML:
//   - SchemaFilename ("schema"): schema files, given as a single string or a
//     list; CompleteConfig expands glob patterns (including "**") into file
//     names.
//   - Exec, Model, Resolver, Federation: output package configs. check
//     validates them (Model, Resolver and Federation only when defined) and
//     requires Federation to share Exec's import path.
//   - AutoBind: Go packages that autobind searches for types named like
//     schema types.
//   - Models: GraphQL type name -> Go type bindings; also extended from
//     @goModel/@goField schema directives, autobind and the built-in defaults.
//   - Directives: per-directive settings; CompleteConfig adds skip, include
//     and deprecated with SkipRuntime unless they are already configured.
//   - StructTag, OmitSliceElementPointers, SkipValidation, SkipModTidy: not
//     read in this file; presumably consumed elsewhere in gqlgen.
//
// Fields populated at runtime (yaml:"-"):
//   - Sources: raw schema file contents, appended by CompleteConfig or
//     LoadDefaultConfig.
//   - Packages: cache of loaded Go packages; created by Init when nil, and an
//     existing cache is replaced with a fresh one by LoadSchema.
//   - Schema: the parsed schema, set by LoadSchema.
type Config struct {
	SchemaFilename           StringList                 `yaml:"schema,omitempty"`
	Exec                     ExecConfig                 `yaml:"exec"`
	Model                    PackageConfig              `yaml:"model,omitempty"`
	Federation               PackageConfig              `yaml:"federation,omitempty"`
	Resolver                 ResolverConfig             `yaml:"resolver,omitempty"`
	AutoBind                 []string                   `yaml:"autobind"`
	Models                   TypeMap                    `yaml:"models,omitempty"`
	StructTag                string                     `yaml:"struct_tag,omitempty"`
	Directives               map[string]DirectiveConfig `yaml:"directives,omitempty"`
	OmitSliceElementPointers bool                       `yaml:"omit_slice_element_pointers,omitempty"`
	SkipValidation           bool                       `yaml:"skip_validation,omitempty"`
	SkipModTidy              bool                       `yaml:"skip_mod_tidy,omitempty"`
	Sources                  []*ast.Source              `yaml:"-"`
	Packages                 *code.Packages             `yaml:"-"`
	Schema                   *ast.Schema                `yaml:"-"`

	// Deprecated: use Federation instead. Will be removed next release
	// Setting it makes check return an error explaining the replacement.
	Federated bool `yaml:"federated,omitempty"`
}
    38  
    39  var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
    40  
    41  // DefaultConfig creates a copy of the default config
    42  func DefaultConfig() *Config {
    43  	return &Config{
    44  		SchemaFilename: StringList{"schema.graphql"},
    45  		Model:          PackageConfig{Filename: "models_gen.go"},
    46  		Exec:           ExecConfig{Filename: "generated.go"},
    47  		Directives:     map[string]DirectiveConfig{},
    48  		Models:         TypeMap{},
    49  	}
    50  }
    51  
    52  // LoadDefaultConfig loads the default config so that it is ready to be used
    53  func LoadDefaultConfig() (*Config, error) {
    54  	config := DefaultConfig()
    55  
    56  	for _, filename := range config.SchemaFilename {
    57  		filename = filepath.ToSlash(filename)
    58  		var err error
    59  		var schemaRaw []byte
    60  		schemaRaw, err = ioutil.ReadFile(filename)
    61  		if err != nil {
    62  			return nil, fmt.Errorf("unable to open schema: %w", err)
    63  		}
    64  
    65  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
    66  	}
    67  
    68  	return config, nil
    69  }
    70  
    71  // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
    72  // walking up the tree. The closest config file will be returned.
    73  func LoadConfigFromDefaultLocations() (*Config, error) {
    74  	cfgFile, err := findCfg()
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	err = os.Chdir(filepath.Dir(cfgFile))
    80  	if err != nil {
    81  		return nil, fmt.Errorf("unable to enter config dir: %w", err)
    82  	}
    83  	return LoadConfig(cfgFile)
    84  }
    85  
    86  var path2regex = strings.NewReplacer(
    87  	`.`, `\.`,
    88  	`*`, `.+`,
    89  	`\`, `[\\/]`,
    90  	`/`, `[\\/]`,
    91  )
    92  
    93  // LoadConfig reads the gqlgen.yml config file
    94  func LoadConfig(filename string) (*Config, error) {
    95  	config := DefaultConfig()
    96  
    97  	b, err := ioutil.ReadFile(filename)
    98  	if err != nil {
    99  		return nil, fmt.Errorf("unable to read config: %w", err)
   100  	}
   101  
   102  	if err := yaml.UnmarshalStrict(b, config); err != nil {
   103  		return nil, fmt.Errorf("unable to parse config: %w", err)
   104  	}
   105  
   106  	if err := CompleteConfig(config); err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return config, nil
   111  }
   112  
   113  // CompleteConfig fills in the schema and other values to a config loaded from
   114  // YAML.
   115  func CompleteConfig(config *Config) error {
   116  	defaultDirectives := map[string]DirectiveConfig{
   117  		"skip":       {SkipRuntime: true},
   118  		"include":    {SkipRuntime: true},
   119  		"deprecated": {SkipRuntime: true},
   120  	}
   121  
   122  	for key, value := range defaultDirectives {
   123  		if _, defined := config.Directives[key]; !defined {
   124  			config.Directives[key] = value
   125  		}
   126  	}
   127  
   128  	preGlobbing := config.SchemaFilename
   129  	config.SchemaFilename = StringList{}
   130  	for _, f := range preGlobbing {
   131  		var matches []string
   132  
   133  		// for ** we want to override default globbing patterns and walk all
   134  		// subdirectories to match schema files.
   135  		if strings.Contains(f, "**") {
   136  			pathParts := strings.SplitN(f, "**", 2)
   137  			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
   138  			// turn the rest of the glob into a regex, anchored only at the end because ** allows
   139  			// for any number of dirs in between and walk will let us match against the full path name
   140  			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
   141  
   142  			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
   143  				if err != nil {
   144  					return err
   145  				}
   146  
   147  				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
   148  					matches = append(matches, path)
   149  				}
   150  
   151  				return nil
   152  			}); err != nil {
   153  				return fmt.Errorf("failed to walk schema at root %s: %w", pathParts[0], err)
   154  			}
   155  		} else {
   156  			var err error
   157  			matches, err = filepath.Glob(f)
   158  			if err != nil {
   159  				return fmt.Errorf("failed to glob schema filename %s: %w", f, err)
   160  			}
   161  		}
   162  
   163  		for _, m := range matches {
   164  			if config.SchemaFilename.Has(m) {
   165  				continue
   166  			}
   167  			config.SchemaFilename = append(config.SchemaFilename, m)
   168  		}
   169  	}
   170  
   171  	for _, filename := range config.SchemaFilename {
   172  		filename = filepath.ToSlash(filename)
   173  		var err error
   174  		var schemaRaw []byte
   175  		schemaRaw, err = ioutil.ReadFile(filename)
   176  		if err != nil {
   177  			return fmt.Errorf("unable to open schema: %w", err)
   178  		}
   179  
   180  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
   181  	}
   182  	return nil
   183  }
   184  
   185  func (c *Config) Init() error {
   186  	if c.Packages == nil {
   187  		c.Packages = &code.Packages{}
   188  	}
   189  
   190  	if c.Schema == nil {
   191  		if err := c.LoadSchema(); err != nil {
   192  			return err
   193  		}
   194  	}
   195  
   196  	err := c.injectTypesFromSchema()
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	err = c.autobind()
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	c.injectBuiltins()
   207  	// prefetch all packages in one big packages.Load call
   208  	c.Packages.LoadAll(c.packageList()...)
   209  
   210  	//  check everything is valid on the way out
   211  	err = c.check()
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	return nil
   217  }
   218  
   219  func (c *Config) packageList() []string {
   220  	pkgs := []string{
   221  		"github.com/apipluspower/gqlgen/graphql",
   222  		"github.com/apipluspower/gqlgen/graphql/introspection",
   223  	}
   224  	pkgs = append(pkgs, c.Models.ReferencedPackages()...)
   225  	pkgs = append(pkgs, c.AutoBind...)
   226  	return pkgs
   227  }
   228  
   229  func (c *Config) ReloadAllPackages() {
   230  	c.Packages.ReloadAll(c.packageList()...)
   231  }
   232  
   233  func (c *Config) injectTypesFromSchema() error {
   234  	c.Directives["goModel"] = DirectiveConfig{
   235  		SkipRuntime: true,
   236  	}
   237  
   238  	c.Directives["goField"] = DirectiveConfig{
   239  		SkipRuntime: true,
   240  	}
   241  
   242  	c.Directives["goTag"] = DirectiveConfig{
   243  		SkipRuntime: true,
   244  	}
   245  
   246  	for _, schemaType := range c.Schema.Types {
   247  		if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription {
   248  			continue
   249  		}
   250  
   251  		if bd := schemaType.Directives.ForName("goModel"); bd != nil {
   252  			if ma := bd.Arguments.ForName("model"); ma != nil {
   253  				if mv, err := ma.Value.Value(nil); err == nil {
   254  					c.Models.Add(schemaType.Name, mv.(string))
   255  				}
   256  			}
   257  			if ma := bd.Arguments.ForName("models"); ma != nil {
   258  				if mvs, err := ma.Value.Value(nil); err == nil {
   259  					for _, mv := range mvs.([]interface{}) {
   260  						c.Models.Add(schemaType.Name, mv.(string))
   261  					}
   262  				}
   263  			}
   264  		}
   265  
   266  		if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
   267  			for _, field := range schemaType.Fields {
   268  				if fd := field.Directives.ForName("goField"); fd != nil {
   269  					forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
   270  					fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
   271  
   272  					if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
   273  						if fr, err := ra.Value.Value(nil); err == nil {
   274  							forceResolver = fr.(bool)
   275  						}
   276  					}
   277  
   278  					if na := fd.Arguments.ForName("name"); na != nil {
   279  						if fr, err := na.Value.Value(nil); err == nil {
   280  							fieldName = fr.(string)
   281  						}
   282  					}
   283  
   284  					if c.Models[schemaType.Name].Fields == nil {
   285  						c.Models[schemaType.Name] = TypeMapEntry{
   286  							Model:  c.Models[schemaType.Name].Model,
   287  							Fields: map[string]TypeMapField{},
   288  						}
   289  					}
   290  
   291  					c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
   292  						FieldName: fieldName,
   293  						Resolver:  forceResolver,
   294  					}
   295  				}
   296  			}
   297  		}
   298  	}
   299  
   300  	return nil
   301  }
   302  
// TypeMapEntry is the binding configuration for one GraphQL type in
// Config.Models.
//
// Model lists the Go types the GraphQL type may be bound to (typically written
// as "import/path.TypeName"; YAML accepts a single string or a list). Fields
// holds per-field overrides keyed by GraphQL field name; injectTypesFromSchema
// also fills it from @goField directives.
type TypeMapEntry struct {
	Model  StringList              `yaml:"model"`
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
   307  
// TypeMapField overrides how a single GraphQL field is bound.
//
// Resolver corresponds to @goField(forceResolver:), presumably forcing a
// resolver method for the field. FieldName corresponds to @goField(name:),
// presumably the Go field to bind to. injectTypesFromSchema sets both from the
// directive. GeneratedMethod is excluded from YAML (yaml:"-") and is not set
// in this file.
type TypeMapField struct {
	Resolver        bool   `yaml:"resolver"`
	FieldName       string `yaml:"fieldName"`
	GeneratedMethod string `yaml:"-"`
}
   313  
   314  type StringList []string
   315  
   316  func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
   317  	var single string
   318  	err := unmarshal(&single)
   319  	if err == nil {
   320  		*a = []string{single}
   321  		return nil
   322  	}
   323  
   324  	var multi []string
   325  	err = unmarshal(&multi)
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	*a = multi
   331  	return nil
   332  }
   333  
   334  func (a StringList) Has(file string) bool {
   335  	for _, existing := range a {
   336  		if existing == file {
   337  			return true
   338  		}
   339  	}
   340  	return false
   341  }
   342  
   343  func (c *Config) check() error {
   344  	if c.Models == nil {
   345  		c.Models = TypeMap{}
   346  	}
   347  
   348  	type FilenamePackage struct {
   349  		Filename string
   350  		Package  string
   351  		Declaree string
   352  	}
   353  
   354  	fileList := map[string][]FilenamePackage{}
   355  
   356  	if err := c.Models.Check(); err != nil {
   357  		return fmt.Errorf("config.models: %w", err)
   358  	}
   359  	if err := c.Exec.Check(); err != nil {
   360  		return fmt.Errorf("config.exec: %w", err)
   361  	}
   362  	fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
   363  		Filename: c.Exec.Filename,
   364  		Package:  c.Exec.Package,
   365  		Declaree: "exec",
   366  	})
   367  
   368  	if c.Model.IsDefined() {
   369  		if err := c.Model.Check(); err != nil {
   370  			return fmt.Errorf("config.model: %w", err)
   371  		}
   372  		fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
   373  			Filename: c.Model.Filename,
   374  			Package:  c.Model.Package,
   375  			Declaree: "model",
   376  		})
   377  	}
   378  	if c.Resolver.IsDefined() {
   379  		if err := c.Resolver.Check(); err != nil {
   380  			return fmt.Errorf("config.resolver: %w", err)
   381  		}
   382  		fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
   383  			Filename: c.Resolver.Filename,
   384  			Package:  c.Resolver.Package,
   385  			Declaree: "resolver",
   386  		})
   387  	}
   388  	if c.Federation.IsDefined() {
   389  		if err := c.Federation.Check(); err != nil {
   390  			return fmt.Errorf("config.federation: %w", err)
   391  		}
   392  		fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
   393  			Filename: c.Federation.Filename,
   394  			Package:  c.Federation.Package,
   395  			Declaree: "federation",
   396  		})
   397  		if c.Federation.ImportPath() != c.Exec.ImportPath() {
   398  			return fmt.Errorf("federation and exec must be in the same package")
   399  		}
   400  	}
   401  	if c.Federated {
   402  		return fmt.Errorf("federated has been removed, instead use\nfederation:\n    filename: path/to/federated.go")
   403  	}
   404  
   405  	for importPath, pkg := range fileList {
   406  		for _, file1 := range pkg {
   407  			for _, file2 := range pkg {
   408  				if file1.Package != file2.Package {
   409  					return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
   410  						file1.Declaree,
   411  						file2.Declaree,
   412  						importPath,
   413  						file1.Package,
   414  						file2.Package,
   415  					)
   416  				}
   417  			}
   418  		}
   419  	}
   420  
   421  	return nil
   422  }
   423  
// TypeMap maps GraphQL type names to their Go binding configuration.
type TypeMap map[string]TypeMapEntry
   425  
   426  func (tm TypeMap) Exists(typeName string) bool {
   427  	_, ok := tm[typeName]
   428  	return ok
   429  }
   430  
   431  func (tm TypeMap) UserDefined(typeName string) bool {
   432  	m, ok := tm[typeName]
   433  	return ok && len(m.Model) > 0
   434  }
   435  
   436  func (tm TypeMap) Check() error {
   437  	for typeName, entry := range tm {
   438  		for _, model := range entry.Model {
   439  			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
   440  				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
   441  			}
   442  		}
   443  	}
   444  	return nil
   445  }
   446  
   447  func (tm TypeMap) ReferencedPackages() []string {
   448  	var pkgs []string
   449  
   450  	for _, typ := range tm {
   451  		for _, model := range typ.Model {
   452  			if model == "map[string]interface{}" || model == "interface{}" {
   453  				continue
   454  			}
   455  			pkg, _ := code.PkgAndType(model)
   456  			if pkg == "" || inStrSlice(pkgs, pkg) {
   457  				continue
   458  			}
   459  			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
   460  		}
   461  	}
   462  
   463  	sort.Slice(pkgs, func(i, j int) bool {
   464  		return pkgs[i] > pkgs[j]
   465  	})
   466  	return pkgs
   467  }
   468  
   469  func (tm TypeMap) Add(name string, goType string) {
   470  	modelCfg := tm[name]
   471  	modelCfg.Model = append(modelCfg.Model, goType)
   472  	tm[name] = modelCfg
   473  }
   474  
// DirectiveConfig holds per-directive settings from the "directives" config
// section.
//
// SkipRuntime is set by this file for the built-in skip, include and
// deprecated directives and for gqlgen's goModel, goField and goTag
// directives. Presumably it means no runtime implementation is generated for
// the directive — confirm against the codegen templates.
type DirectiveConfig struct {
	SkipRuntime bool `yaml:"skip_runtime"`
}
   478  
   479  func inStrSlice(haystack []string, needle string) bool {
   480  	for _, v := range haystack {
   481  		if needle == v {
   482  			return true
   483  		}
   484  	}
   485  
   486  	return false
   487  }
   488  
   489  // findCfg searches for the config file in this directory and all parents up the tree
   490  // looking for the closest match
   491  func findCfg() (string, error) {
   492  	dir, err := os.Getwd()
   493  	if err != nil {
   494  		return "", fmt.Errorf("unable to get working dir to findCfg: %w", err)
   495  	}
   496  
   497  	cfg := findCfgInDir(dir)
   498  
   499  	for cfg == "" && dir != filepath.Dir(dir) {
   500  		dir = filepath.Dir(dir)
   501  		cfg = findCfgInDir(dir)
   502  	}
   503  
   504  	if cfg == "" {
   505  		return "", os.ErrNotExist
   506  	}
   507  
   508  	return cfg, nil
   509  }
   510  
   511  func findCfgInDir(dir string) string {
   512  	for _, cfgName := range cfgFilenames {
   513  		path := filepath.Join(dir, cfgName)
   514  		if _, err := os.Stat(path); err == nil {
   515  			return path
   516  		}
   517  	}
   518  	return ""
   519  }
   520  
   521  func (c *Config) autobind() error {
   522  	if len(c.AutoBind) == 0 {
   523  		return nil
   524  	}
   525  
   526  	ps := c.Packages.LoadAll(c.AutoBind...)
   527  
   528  	for _, t := range c.Schema.Types {
   529  		if c.Models.UserDefined(t.Name) {
   530  			continue
   531  		}
   532  
   533  		for i, p := range ps {
   534  			if p == nil {
   535  				return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
   536  			}
   537  			if t := p.Types.Scope().Lookup(t.Name); t != nil {
   538  				c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
   539  				break
   540  			}
   541  		}
   542  	}
   543  
   544  	for i, t := range c.Models {
   545  		for j, m := range t.Model {
   546  			pkg, typename := code.PkgAndType(m)
   547  
   548  			// skip anything that looks like an import path
   549  			if strings.Contains(pkg, "/") {
   550  				continue
   551  			}
   552  
   553  			for _, p := range ps {
   554  				if p.Name != pkg {
   555  					continue
   556  				}
   557  				if t := p.Types.Scope().Lookup(typename); t != nil {
   558  					c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
   559  					break
   560  				}
   561  			}
   562  		}
   563  	}
   564  
   565  	return nil
   566  }
   567  
   568  func (c *Config) injectBuiltins() {
   569  	builtins := TypeMap{
   570  		"__Directive":         {Model: StringList{"github.com/apipluspower/gqlgen/graphql/introspection.Directive"}},
   571  		"__DirectiveLocation": {Model: StringList{"github.com/apipluspower/gqlgen/graphql.String"}},
   572  		"__Type":              {Model: StringList{"github.com/apipluspower/gqlgen/graphql/introspection.Type"}},
   573  		"__TypeKind":          {Model: StringList{"github.com/apipluspower/gqlgen/graphql.String"}},
   574  		"__Field":             {Model: StringList{"github.com/apipluspower/gqlgen/graphql/introspection.Field"}},
   575  		"__EnumValue":         {Model: StringList{"github.com/apipluspower/gqlgen/graphql/introspection.EnumValue"}},
   576  		"__InputValue":        {Model: StringList{"github.com/apipluspower/gqlgen/graphql/introspection.InputValue"}},
   577  		"__Schema":            {Model: StringList{"github.com/apipluspower/gqlgen/graphql/introspection.Schema"}},
   578  		"Float":               {Model: StringList{"github.com/apipluspower/gqlgen/graphql.FloatContext"}},
   579  		"String":              {Model: StringList{"github.com/apipluspower/gqlgen/graphql.String"}},
   580  		"Boolean":             {Model: StringList{"github.com/apipluspower/gqlgen/graphql.Boolean"}},
   581  		"Int": {Model: StringList{
   582  			"github.com/apipluspower/gqlgen/graphql.Int",
   583  			"github.com/apipluspower/gqlgen/graphql.Int32",
   584  			"github.com/apipluspower/gqlgen/graphql.Int64",
   585  		}},
   586  		"ID": {
   587  			Model: StringList{
   588  				"github.com/apipluspower/gqlgen/graphql.ID",
   589  				"github.com/apipluspower/gqlgen/graphql.IntID",
   590  			},
   591  		},
   592  	}
   593  
   594  	for typeName, entry := range builtins {
   595  		if !c.Models.Exists(typeName) {
   596  			c.Models[typeName] = entry
   597  		}
   598  	}
   599  
   600  	// These are additional types that are injected if defined in the schema as scalars.
   601  	extraBuiltins := TypeMap{
   602  		"Time":   {Model: StringList{"github.com/apipluspower/gqlgen/graphql.Time"}},
   603  		"Map":    {Model: StringList{"github.com/apipluspower/gqlgen/graphql.Map"}},
   604  		"Upload": {Model: StringList{"github.com/apipluspower/gqlgen/graphql.Upload"}},
   605  		"Any":    {Model: StringList{"github.com/apipluspower/gqlgen/graphql.Any"}},
   606  	}
   607  
   608  	for typeName, entry := range extraBuiltins {
   609  		if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
   610  			c.Models[typeName] = entry
   611  		}
   612  	}
   613  }
   614  
   615  func (c *Config) LoadSchema() error {
   616  	if c.Packages != nil {
   617  		c.Packages = &code.Packages{}
   618  	}
   619  
   620  	if err := c.check(); err != nil {
   621  		return err
   622  	}
   623  
   624  	schema, err := gqlparser.LoadSchema(c.Sources...)
   625  	if err != nil {
   626  		return err
   627  	}
   628  
   629  	if schema.Query == nil {
   630  		schema.Query = &ast.Definition{
   631  			Kind: ast.Object,
   632  			Name: "Query",
   633  		}
   634  		schema.Types["Query"] = schema.Query
   635  	}
   636  
   637  	c.Schema = schema
   638  	return nil
   639  }
   640  
   641  func abs(path string) string {
   642  	absPath, err := filepath.Abs(path)
   643  	if err != nil {
   644  		panic(err)
   645  	}
   646  	return filepath.ToSlash(absPath)
   647  }