github.com/shippio/gqlgen@v0.0.0-20220912092219-633ea699ef07/codegen/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"regexp"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/99designs/gqlgen/internal/code"
    13  	"github.com/vektah/gqlparser/v2"
    14  	"github.com/vektah/gqlparser/v2/ast"
    15  	"gopkg.in/yaml.v3"
    16  )
    17  
// Config is the in-memory form of a gqlgen config file (see cfgFilenames).
// The YAML-tagged fields are decoded by LoadConfig on top of DefaultConfig.
// Fields tagged `yaml:"-"` are never read from YAML; they are filled in while
// loading: Sources by CompleteConfig or LoadDefaultConfig, Packages by Init
// (when nil), and Schema by LoadSchema.
type Config struct {
	SchemaFilename                StringList                 `yaml:"schema,omitempty"`
	Exec                          ExecConfig                 `yaml:"exec"`
	Model                         PackageConfig              `yaml:"model,omitempty"`
	Federation                    PackageConfig              `yaml:"federation,omitempty"`
	Resolver                      ResolverConfig             `yaml:"resolver,omitempty"`
	AutoBind                      []string                   `yaml:"autobind"`
	Models                        TypeMap                    `yaml:"models,omitempty"`
	StructTag                     string                     `yaml:"struct_tag,omitempty"`
	Directives                    map[string]DirectiveConfig `yaml:"directives,omitempty"`
	OmitSliceElementPointers      bool                       `yaml:"omit_slice_element_pointers,omitempty"`
	OmitGetters                   bool                       `yaml:"omit_getters,omitempty"`
	StructFieldsAlwaysPointers    bool                       `yaml:"struct_fields_always_pointers,omitempty"`
	ResolversAlwaysReturnPointers bool                       `yaml:"resolvers_always_return_pointers,omitempty"`
	SkipValidation                bool                       `yaml:"skip_validation,omitempty"`
	SkipModTidy                   bool                       `yaml:"skip_mod_tidy,omitempty"`
	Sources                       []*ast.Source              `yaml:"-"`
	Packages                      *code.Packages             `yaml:"-"`
	Schema                        *ast.Schema                `yaml:"-"`

	// Deprecated: use Federation instead. Will be removed next release.
	// check returns an error for any config that still sets this to true.
	Federated bool `yaml:"federated,omitempty"`
}
    41  
// cfgFilenames lists the config file names findCfgInDir probes, in priority
// order: the first one that exists in a directory wins.
var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
    43  
    44  // DefaultConfig creates a copy of the default config
    45  func DefaultConfig() *Config {
    46  	return &Config{
    47  		SchemaFilename:                StringList{"schema.graphql"},
    48  		Model:                         PackageConfig{Filename: "models_gen.go"},
    49  		Exec:                          ExecConfig{Filename: "generated.go"},
    50  		Directives:                    map[string]DirectiveConfig{},
    51  		Models:                        TypeMap{},
    52  		StructFieldsAlwaysPointers:    true,
    53  		ResolversAlwaysReturnPointers: true,
    54  	}
    55  }
    56  
    57  // LoadDefaultConfig loads the default config so that it is ready to be used
    58  func LoadDefaultConfig() (*Config, error) {
    59  	config := DefaultConfig()
    60  
    61  	for _, filename := range config.SchemaFilename {
    62  		filename = filepath.ToSlash(filename)
    63  		var err error
    64  		var schemaRaw []byte
    65  		schemaRaw, err = os.ReadFile(filename)
    66  		if err != nil {
    67  			return nil, fmt.Errorf("unable to open schema: %w", err)
    68  		}
    69  
    70  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
    71  	}
    72  
    73  	return config, nil
    74  }
    75  
    76  // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
    77  // walking up the tree. The closest config file will be returned.
    78  func LoadConfigFromDefaultLocations() (*Config, error) {
    79  	cfgFile, err := findCfg()
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	err = os.Chdir(filepath.Dir(cfgFile))
    85  	if err != nil {
    86  		return nil, fmt.Errorf("unable to enter config dir: %w", err)
    87  	}
    88  	return LoadConfig(cfgFile)
    89  }
    90  
    91  var path2regex = strings.NewReplacer(
    92  	`.`, `\.`,
    93  	`*`, `.+`,
    94  	`\`, `[\\/]`,
    95  	`/`, `[\\/]`,
    96  )
    97  
    98  // LoadConfig reads the gqlgen.yml config file
    99  func LoadConfig(filename string) (*Config, error) {
   100  	config := DefaultConfig()
   101  
   102  	b, err := os.ReadFile(filename)
   103  	if err != nil {
   104  		return nil, fmt.Errorf("unable to read config: %w", err)
   105  	}
   106  
   107  	dec := yaml.NewDecoder(bytes.NewReader(b))
   108  	dec.KnownFields(true)
   109  
   110  	if err := dec.Decode(config); err != nil {
   111  		return nil, fmt.Errorf("unable to parse config: %w", err)
   112  	}
   113  
   114  	if err := CompleteConfig(config); err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	return config, nil
   119  }
   120  
   121  // CompleteConfig fills in the schema and other values to a config loaded from
   122  // YAML.
   123  func CompleteConfig(config *Config) error {
   124  	defaultDirectives := map[string]DirectiveConfig{
   125  		"skip":        {SkipRuntime: true},
   126  		"include":     {SkipRuntime: true},
   127  		"deprecated":  {SkipRuntime: true},
   128  		"specifiedBy": {SkipRuntime: true},
   129  	}
   130  
   131  	for key, value := range defaultDirectives {
   132  		if _, defined := config.Directives[key]; !defined {
   133  			config.Directives[key] = value
   134  		}
   135  	}
   136  
   137  	preGlobbing := config.SchemaFilename
   138  	config.SchemaFilename = StringList{}
   139  	for _, f := range preGlobbing {
   140  		var matches []string
   141  
   142  		// for ** we want to override default globbing patterns and walk all
   143  		// subdirectories to match schema files.
   144  		if strings.Contains(f, "**") {
   145  			pathParts := strings.SplitN(f, "**", 2)
   146  			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
   147  			// turn the rest of the glob into a regex, anchored only at the end because ** allows
   148  			// for any number of dirs in between and walk will let us match against the full path name
   149  			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
   150  
   151  			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
   152  				if err != nil {
   153  					return err
   154  				}
   155  
   156  				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
   157  					matches = append(matches, path)
   158  				}
   159  
   160  				return nil
   161  			}); err != nil {
   162  				return fmt.Errorf("failed to walk schema at root %s: %w", pathParts[0], err)
   163  			}
   164  		} else {
   165  			var err error
   166  			matches, err = filepath.Glob(f)
   167  			if err != nil {
   168  				return fmt.Errorf("failed to glob schema filename %s: %w", f, err)
   169  			}
   170  		}
   171  
   172  		for _, m := range matches {
   173  			if config.SchemaFilename.Has(m) {
   174  				continue
   175  			}
   176  			config.SchemaFilename = append(config.SchemaFilename, m)
   177  		}
   178  	}
   179  
   180  	for _, filename := range config.SchemaFilename {
   181  		filename = filepath.ToSlash(filename)
   182  		var err error
   183  		var schemaRaw []byte
   184  		schemaRaw, err = os.ReadFile(filename)
   185  		if err != nil {
   186  			return fmt.Errorf("unable to open schema: %w", err)
   187  		}
   188  
   189  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
   190  	}
   191  	return nil
   192  }
   193  
   194  func (c *Config) Init() error {
   195  	if c.Packages == nil {
   196  		c.Packages = &code.Packages{}
   197  	}
   198  
   199  	if c.Schema == nil {
   200  		if err := c.LoadSchema(); err != nil {
   201  			return err
   202  		}
   203  	}
   204  
   205  	err := c.injectTypesFromSchema()
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	err = c.autobind()
   211  	if err != nil {
   212  		return err
   213  	}
   214  
   215  	c.injectBuiltins()
   216  	// prefetch all packages in one big packages.Load call
   217  	c.Packages.LoadAll(c.packageList()...)
   218  
   219  	//  check everything is valid on the way out
   220  	err = c.check()
   221  	if err != nil {
   222  		return err
   223  	}
   224  
   225  	return nil
   226  }
   227  
   228  func (c *Config) packageList() []string {
   229  	pkgs := []string{
   230  		"github.com/99designs/gqlgen/graphql",
   231  		"github.com/99designs/gqlgen/graphql/introspection",
   232  	}
   233  	pkgs = append(pkgs, c.Models.ReferencedPackages()...)
   234  	pkgs = append(pkgs, c.AutoBind...)
   235  	return pkgs
   236  }
   237  
   238  func (c *Config) ReloadAllPackages() {
   239  	c.Packages.ReloadAll(c.packageList()...)
   240  }
   241  
   242  func (c *Config) injectTypesFromSchema() error {
   243  	c.Directives["goModel"] = DirectiveConfig{
   244  		SkipRuntime: true,
   245  	}
   246  
   247  	c.Directives["goField"] = DirectiveConfig{
   248  		SkipRuntime: true,
   249  	}
   250  
   251  	c.Directives["goTag"] = DirectiveConfig{
   252  		SkipRuntime: true,
   253  	}
   254  
   255  	for _, schemaType := range c.Schema.Types {
   256  		if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription {
   257  			continue
   258  		}
   259  
   260  		if bd := schemaType.Directives.ForName("goModel"); bd != nil {
   261  			if ma := bd.Arguments.ForName("model"); ma != nil {
   262  				if mv, err := ma.Value.Value(nil); err == nil {
   263  					c.Models.Add(schemaType.Name, mv.(string))
   264  				}
   265  			}
   266  			if ma := bd.Arguments.ForName("models"); ma != nil {
   267  				if mvs, err := ma.Value.Value(nil); err == nil {
   268  					for _, mv := range mvs.([]interface{}) {
   269  						c.Models.Add(schemaType.Name, mv.(string))
   270  					}
   271  				}
   272  			}
   273  		}
   274  
   275  		if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
   276  			for _, field := range schemaType.Fields {
   277  				if fd := field.Directives.ForName("goField"); fd != nil {
   278  					forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
   279  					fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
   280  
   281  					if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
   282  						if fr, err := ra.Value.Value(nil); err == nil {
   283  							forceResolver = fr.(bool)
   284  						}
   285  					}
   286  
   287  					if na := fd.Arguments.ForName("name"); na != nil {
   288  						if fr, err := na.Value.Value(nil); err == nil {
   289  							fieldName = fr.(string)
   290  						}
   291  					}
   292  
   293  					if c.Models[schemaType.Name].Fields == nil {
   294  						c.Models[schemaType.Name] = TypeMapEntry{
   295  							Model:  c.Models[schemaType.Name].Model,
   296  							Fields: map[string]TypeMapField{},
   297  						}
   298  					}
   299  
   300  					c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
   301  						FieldName: fieldName,
   302  						Resolver:  forceResolver,
   303  					}
   304  				}
   305  			}
   306  		}
   307  	}
   308  
   309  	return nil
   310  }
   311  
// TypeMapEntry is the binding configuration for one GraphQL type.
// Model holds the Go type(s) the GraphQL type may map to, written as
// "import/path.TypeName"; TypeMap.Check rejects entries that name only a
// package. Fields holds per-field overrides keyed by GraphQL field name.
type TypeMapEntry struct {
	Model  StringList              `yaml:"model"`
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}

// TypeMapField overrides how a single GraphQL field is bound.
// Resolver and FieldName come either from the YAML config or from the
// @goField(forceResolver:, name:) schema directive. GeneratedMethod is never
// read from YAML; presumably it is filled in during code generation — TODO
// confirm.
type TypeMapField struct {
	Resolver        bool   `yaml:"resolver"`
	FieldName       string `yaml:"fieldName"`
	GeneratedMethod string `yaml:"-"`
}
   322  
   323  type StringList []string
   324  
   325  func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
   326  	var single string
   327  	err := unmarshal(&single)
   328  	if err == nil {
   329  		*a = []string{single}
   330  		return nil
   331  	}
   332  
   333  	var multi []string
   334  	err = unmarshal(&multi)
   335  	if err != nil {
   336  		return err
   337  	}
   338  
   339  	*a = multi
   340  	return nil
   341  }
   342  
   343  func (a StringList) Has(file string) bool {
   344  	for _, existing := range a {
   345  		if existing == file {
   346  			return true
   347  		}
   348  	}
   349  	return false
   350  }
   351  
   352  func (c *Config) check() error {
   353  	if c.Models == nil {
   354  		c.Models = TypeMap{}
   355  	}
   356  
   357  	type FilenamePackage struct {
   358  		Filename string
   359  		Package  string
   360  		Declaree string
   361  	}
   362  
   363  	fileList := map[string][]FilenamePackage{}
   364  
   365  	if err := c.Models.Check(); err != nil {
   366  		return fmt.Errorf("config.models: %w", err)
   367  	}
   368  	if err := c.Exec.Check(); err != nil {
   369  		return fmt.Errorf("config.exec: %w", err)
   370  	}
   371  	fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
   372  		Filename: c.Exec.Filename,
   373  		Package:  c.Exec.Package,
   374  		Declaree: "exec",
   375  	})
   376  
   377  	if c.Model.IsDefined() {
   378  		if err := c.Model.Check(); err != nil {
   379  			return fmt.Errorf("config.model: %w", err)
   380  		}
   381  		fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
   382  			Filename: c.Model.Filename,
   383  			Package:  c.Model.Package,
   384  			Declaree: "model",
   385  		})
   386  	}
   387  	if c.Resolver.IsDefined() {
   388  		if err := c.Resolver.Check(); err != nil {
   389  			return fmt.Errorf("config.resolver: %w", err)
   390  		}
   391  		fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
   392  			Filename: c.Resolver.Filename,
   393  			Package:  c.Resolver.Package,
   394  			Declaree: "resolver",
   395  		})
   396  	}
   397  	if c.Federation.IsDefined() {
   398  		if err := c.Federation.Check(); err != nil {
   399  			return fmt.Errorf("config.federation: %w", err)
   400  		}
   401  		fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
   402  			Filename: c.Federation.Filename,
   403  			Package:  c.Federation.Package,
   404  			Declaree: "federation",
   405  		})
   406  		if c.Federation.ImportPath() != c.Exec.ImportPath() {
   407  			return fmt.Errorf("federation and exec must be in the same package")
   408  		}
   409  	}
   410  	if c.Federated {
   411  		return fmt.Errorf("federated has been removed, instead use\nfederation:\n    filename: path/to/federated.go")
   412  	}
   413  
   414  	for importPath, pkg := range fileList {
   415  		for _, file1 := range pkg {
   416  			for _, file2 := range pkg {
   417  				if file1.Package != file2.Package {
   418  					return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
   419  						file1.Declaree,
   420  						file2.Declaree,
   421  						importPath,
   422  						file1.Package,
   423  						file2.Package,
   424  					)
   425  				}
   426  			}
   427  		}
   428  	}
   429  
   430  	return nil
   431  }
   432  
   433  type TypeMap map[string]TypeMapEntry
   434  
   435  func (tm TypeMap) Exists(typeName string) bool {
   436  	_, ok := tm[typeName]
   437  	return ok
   438  }
   439  
   440  func (tm TypeMap) UserDefined(typeName string) bool {
   441  	m, ok := tm[typeName]
   442  	return ok && len(m.Model) > 0
   443  }
   444  
   445  func (tm TypeMap) Check() error {
   446  	for typeName, entry := range tm {
   447  		for _, model := range entry.Model {
   448  			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
   449  				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
   450  			}
   451  		}
   452  	}
   453  	return nil
   454  }
   455  
   456  func (tm TypeMap) ReferencedPackages() []string {
   457  	var pkgs []string
   458  
   459  	for _, typ := range tm {
   460  		for _, model := range typ.Model {
   461  			if model == "map[string]interface{}" || model == "interface{}" {
   462  				continue
   463  			}
   464  			pkg, _ := code.PkgAndType(model)
   465  			if pkg == "" || inStrSlice(pkgs, pkg) {
   466  				continue
   467  			}
   468  			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
   469  		}
   470  	}
   471  
   472  	sort.Slice(pkgs, func(i, j int) bool {
   473  		return pkgs[i] > pkgs[j]
   474  	})
   475  	return pkgs
   476  }
   477  
   478  func (tm TypeMap) Add(name string, goType string) {
   479  	modelCfg := tm[name]
   480  	modelCfg.Model = append(modelCfg.Model, goType)
   481  	tm[name] = modelCfg
   482  }
   483  
// DirectiveConfig is the per-directive configuration found under
// `directives:` in the config file. The directives gqlgen registers itself
// (goModel, goField, goTag, skip, include, deprecated, specifiedBy) all have
// SkipRuntime set. Presumably SkipRuntime suppresses runtime directive
// handling in the generated code — TODO confirm against the codegen.
type DirectiveConfig struct {
	SkipRuntime bool `yaml:"skip_runtime"`
}
   487  
   488  func inStrSlice(haystack []string, needle string) bool {
   489  	for _, v := range haystack {
   490  		if needle == v {
   491  			return true
   492  		}
   493  	}
   494  
   495  	return false
   496  }
   497  
   498  // findCfg searches for the config file in this directory and all parents up the tree
   499  // looking for the closest match
   500  func findCfg() (string, error) {
   501  	dir, err := os.Getwd()
   502  	if err != nil {
   503  		return "", fmt.Errorf("unable to get working dir to findCfg: %w", err)
   504  	}
   505  
   506  	cfg := findCfgInDir(dir)
   507  
   508  	for cfg == "" && dir != filepath.Dir(dir) {
   509  		dir = filepath.Dir(dir)
   510  		cfg = findCfgInDir(dir)
   511  	}
   512  
   513  	if cfg == "" {
   514  		return "", os.ErrNotExist
   515  	}
   516  
   517  	return cfg, nil
   518  }
   519  
   520  func findCfgInDir(dir string) string {
   521  	for _, cfgName := range cfgFilenames {
   522  		path := filepath.Join(dir, cfgName)
   523  		if _, err := os.Stat(path); err == nil {
   524  			return path
   525  		}
   526  	}
   527  	return ""
   528  }
   529  
   530  func (c *Config) autobind() error {
   531  	if len(c.AutoBind) == 0 {
   532  		return nil
   533  	}
   534  
   535  	ps := c.Packages.LoadAll(c.AutoBind...)
   536  
   537  	for _, t := range c.Schema.Types {
   538  		if c.Models.UserDefined(t.Name) {
   539  			continue
   540  		}
   541  
   542  		for i, p := range ps {
   543  			if p == nil || p.Module == nil {
   544  				return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
   545  			}
   546  			if t := p.Types.Scope().Lookup(t.Name); t != nil {
   547  				c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
   548  				break
   549  			}
   550  		}
   551  	}
   552  
   553  	for i, t := range c.Models {
   554  		for j, m := range t.Model {
   555  			pkg, typename := code.PkgAndType(m)
   556  
   557  			// skip anything that looks like an import path
   558  			if strings.Contains(pkg, "/") {
   559  				continue
   560  			}
   561  
   562  			for _, p := range ps {
   563  				if p.Name != pkg {
   564  					continue
   565  				}
   566  				if t := p.Types.Scope().Lookup(typename); t != nil {
   567  					c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
   568  					break
   569  				}
   570  			}
   571  		}
   572  	}
   573  
   574  	return nil
   575  }
   576  
   577  func (c *Config) injectBuiltins() {
   578  	builtins := TypeMap{
   579  		"__Directive":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
   580  		"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   581  		"__Type":              {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
   582  		"__TypeKind":          {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   583  		"__Field":             {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
   584  		"__EnumValue":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
   585  		"__InputValue":        {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
   586  		"__Schema":            {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
   587  		"Float":               {Model: StringList{"github.com/99designs/gqlgen/graphql.FloatContext"}},
   588  		"String":              {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   589  		"Boolean":             {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
   590  		"Int": {Model: StringList{
   591  			"github.com/99designs/gqlgen/graphql.Int",
   592  			"github.com/99designs/gqlgen/graphql.Int32",
   593  			"github.com/99designs/gqlgen/graphql.Int64",
   594  		}},
   595  		"ID": {
   596  			Model: StringList{
   597  				"github.com/99designs/gqlgen/graphql.ID",
   598  				"github.com/99designs/gqlgen/graphql.IntID",
   599  			},
   600  		},
   601  	}
   602  
   603  	for typeName, entry := range builtins {
   604  		if !c.Models.Exists(typeName) {
   605  			c.Models[typeName] = entry
   606  		}
   607  	}
   608  
   609  	// These are additional types that are injected if defined in the schema as scalars.
   610  	extraBuiltins := TypeMap{
   611  		"Time":   {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
   612  		"Map":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
   613  		"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
   614  		"Any":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
   615  	}
   616  
   617  	for typeName, entry := range extraBuiltins {
   618  		if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
   619  			c.Models[typeName] = entry
   620  		}
   621  	}
   622  }
   623  
   624  func (c *Config) LoadSchema() error {
   625  	if c.Packages != nil {
   626  		c.Packages = &code.Packages{}
   627  	}
   628  
   629  	if err := c.check(); err != nil {
   630  		return err
   631  	}
   632  
   633  	schema, err := gqlparser.LoadSchema(c.Sources...)
   634  	if err != nil {
   635  		return err
   636  	}
   637  
   638  	if schema.Query == nil {
   639  		schema.Query = &ast.Definition{
   640  			Kind: ast.Object,
   641  			Name: "Query",
   642  		}
   643  		schema.Types["Query"] = schema.Query
   644  	}
   645  
   646  	c.Schema = schema
   647  	return nil
   648  }
   649  
   650  func abs(path string) string {
   651  	absPath, err := filepath.Abs(path)
   652  	if err != nil {
   653  		panic(err)
   654  	}
   655  	return filepath.ToSlash(absPath)
   656  }