git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/codegen/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"os"
     7  	"path/filepath"
     8  	"regexp"
     9  	"sort"
    10  	"strings"
    11  
    12  	"git.sr.ht/~sircmpwn/gqlgen/internal/code"
    13  	"github.com/pkg/errors"
    14  	"github.com/vektah/gqlparser/v2"
    15  	"github.com/vektah/gqlparser/v2/ast"
    16  	"gopkg.in/yaml.v2"
    17  )
    18  
// Config is the in-memory form of a gqlgen.yml file.
//
// Fields with a yaml key are decoded from the config file by LoadConfig.
// Fields tagged `yaml:"-"` are filled in at runtime:
//   - Sources is read from the schema files by LoadConfig.
//   - Schema is parsed from Sources by LoadSchema.
//   - Packages is the Go package cache created by Init (and reset by LoadSchema).
type Config struct {
	SchemaFilename           StringList                 `yaml:"schema,omitempty"`
	Exec                     PackageConfig              `yaml:"exec"`
	Model                    PackageConfig              `yaml:"model,omitempty"`
	Federation               PackageConfig              `yaml:"federation,omitempty"`
	Resolver                 ResolverConfig             `yaml:"resolver,omitempty"`
	AutoBind                 []string                   `yaml:"autobind"`
	Models                   TypeMap                    `yaml:"models,omitempty"`
	StructTag                string                     `yaml:"struct_tag,omitempty"`
	Directives               map[string]DirectiveConfig `yaml:"directives,omitempty"`
	OmitSliceElementPointers bool                       `yaml:"omit_slice_element_pointers,omitempty"`
	SkipValidation           bool                       `yaml:"skip_validation,omitempty"`
	Sources                  []*ast.Source              `yaml:"-"`
	Packages                 *code.Packages             `yaml:"-"`
	Schema                   *ast.Schema                `yaml:"-"`

	// Deprecated: use Federation instead. Will be removed next release.
	// check rejects any config that still sets this field.
	Federated bool `yaml:"federated,omitempty"`
}
    38  
// cfgFilenames are the config file names findCfgInDir probes, in order of
// preference. The first one that exists in a directory wins.
var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
    40  
    41  // DefaultConfig creates a copy of the default config
    42  func DefaultConfig() *Config {
    43  	return &Config{
    44  		SchemaFilename: StringList{"schema.graphql"},
    45  		Model:          PackageConfig{Filename: "models_gen.go"},
    46  		Exec:           PackageConfig{Filename: "generated.go"},
    47  		Directives:     map[string]DirectiveConfig{},
    48  		Models:         TypeMap{},
    49  	}
    50  }
    51  
    52  // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
    53  // walking up the tree. The closest config file will be returned.
    54  func LoadConfigFromDefaultLocations() (*Config, error) {
    55  	cfgFile, err := findCfg()
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	err = os.Chdir(filepath.Dir(cfgFile))
    61  	if err != nil {
    62  		return nil, errors.Wrap(err, "unable to enter config dir")
    63  	}
    64  	return LoadConfig(cfgFile)
    65  }
    66  
// path2regex turns the part of a schema glob that follows "**" into a regular
// expression (used by LoadConfig):
//   - "." is escaped so it only matches a literal dot;
//   - "*" becomes ".+", one or more of any character, path separators included;
//   - "\" and "/" both become a class matching either separator, so the same
//     pattern works for Windows and Unix paths.
//
// NOTE(review): other regex metacharacters ("(", "[", "+", "?", ...) pass
// through unescaped and are interpreted as regex syntax.
var path2regex = strings.NewReplacer(
	`.`, `\.`,
	`*`, `.+`,
	`\`, `[\\/]`,
	`/`, `[\\/]`,
)
    73  
    74  // LoadConfig reads the gqlgen.yml config file
    75  func LoadConfig(filename string) (*Config, error) {
    76  	config := DefaultConfig()
    77  
    78  	b, err := ioutil.ReadFile(filename)
    79  	if err != nil {
    80  		return nil, errors.Wrap(err, "unable to read config")
    81  	}
    82  
    83  	if err := yaml.UnmarshalStrict(b, config); err != nil {
    84  		return nil, errors.Wrap(err, "unable to parse config")
    85  	}
    86  
    87  	defaultDirectives := map[string]DirectiveConfig{
    88  		"skip":       {SkipRuntime: true},
    89  		"include":    {SkipRuntime: true},
    90  		"deprecated": {SkipRuntime: true},
    91  	}
    92  
    93  	for key, value := range defaultDirectives {
    94  		if _, defined := config.Directives[key]; !defined {
    95  			config.Directives[key] = value
    96  		}
    97  	}
    98  
    99  	preGlobbing := config.SchemaFilename
   100  	config.SchemaFilename = StringList{}
   101  	for _, f := range preGlobbing {
   102  		var matches []string
   103  
   104  		// for ** we want to override default globbing patterns and walk all
   105  		// subdirectories to match schema files.
   106  		if strings.Contains(f, "**") {
   107  			pathParts := strings.SplitN(f, "**", 2)
   108  			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
   109  			// turn the rest of the glob into a regex, anchored only at the end because ** allows
   110  			// for any number of dirs in between and walk will let us match against the full path name
   111  			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
   112  
   113  			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
   114  				if err != nil {
   115  					return err
   116  				}
   117  
   118  				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
   119  					matches = append(matches, path)
   120  				}
   121  
   122  				return nil
   123  			}); err != nil {
   124  				return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0])
   125  			}
   126  		} else {
   127  			matches, err = filepath.Glob(f)
   128  			if err != nil {
   129  				return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
   130  			}
   131  		}
   132  
   133  		for _, m := range matches {
   134  			if config.SchemaFilename.Has(m) {
   135  				continue
   136  			}
   137  			config.SchemaFilename = append(config.SchemaFilename, m)
   138  		}
   139  	}
   140  
   141  	for _, filename := range config.SchemaFilename {
   142  		filename = filepath.ToSlash(filename)
   143  		var err error
   144  		var schemaRaw []byte
   145  		schemaRaw, err = ioutil.ReadFile(filename)
   146  		if err != nil {
   147  			return nil, errors.Wrap(err, "unable to open schema")
   148  		}
   149  
   150  		config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
   151  	}
   152  
   153  	return config, nil
   154  }
   155  
   156  func (c *Config) Init() error {
   157  	if c.Packages == nil {
   158  		c.Packages = &code.Packages{}
   159  	}
   160  
   161  	if c.Schema == nil {
   162  		if err := c.LoadSchema(); err != nil {
   163  			return err
   164  		}
   165  	}
   166  
   167  	err := c.injectTypesFromSchema()
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	err = c.autobind()
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	c.injectBuiltins()
   178  
   179  	// prefetch all packages in one big packages.Load call
   180  	pkgs := []string{
   181  		"git.sr.ht/~sircmpwn/gqlgen/graphql",
   182  		"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection",
   183  	}
   184  	pkgs = append(pkgs, c.Models.ReferencedPackages()...)
   185  	pkgs = append(pkgs, c.AutoBind...)
   186  	c.Packages.LoadAll(pkgs...)
   187  
   188  	//  check everything is valid on the way out
   189  	err = c.check()
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	return nil
   195  }
   196  
   197  func (c *Config) injectTypesFromSchema() error {
   198  	c.Directives["goModel"] = DirectiveConfig{
   199  		SkipRuntime: true,
   200  	}
   201  
   202  	c.Directives["goField"] = DirectiveConfig{
   203  		SkipRuntime: true,
   204  	}
   205  
   206  	for _, schemaType := range c.Schema.Types {
   207  		if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription {
   208  			continue
   209  		}
   210  
   211  		if bd := schemaType.Directives.ForName("goModel"); bd != nil {
   212  			if ma := bd.Arguments.ForName("model"); ma != nil {
   213  				if mv, err := ma.Value.Value(nil); err == nil {
   214  					c.Models.Add(schemaType.Name, mv.(string))
   215  				}
   216  			}
   217  			if ma := bd.Arguments.ForName("models"); ma != nil {
   218  				if mvs, err := ma.Value.Value(nil); err == nil {
   219  					for _, mv := range mvs.([]interface{}) {
   220  						c.Models.Add(schemaType.Name, mv.(string))
   221  					}
   222  				}
   223  			}
   224  		}
   225  
   226  		if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
   227  			for _, field := range schemaType.Fields {
   228  				if fd := field.Directives.ForName("goField"); fd != nil {
   229  					forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
   230  					fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
   231  
   232  					if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
   233  						if fr, err := ra.Value.Value(nil); err == nil {
   234  							forceResolver = fr.(bool)
   235  						}
   236  					}
   237  
   238  					if na := fd.Arguments.ForName("name"); na != nil {
   239  						if fr, err := na.Value.Value(nil); err == nil {
   240  							fieldName = fr.(string)
   241  						}
   242  					}
   243  
   244  					if c.Models[schemaType.Name].Fields == nil {
   245  						c.Models[schemaType.Name] = TypeMapEntry{
   246  							Model:  c.Models[schemaType.Name].Model,
   247  							Fields: map[string]TypeMapField{},
   248  						}
   249  					}
   250  
   251  					c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
   252  						FieldName: fieldName,
   253  						Resolver:  forceResolver,
   254  					}
   255  				}
   256  			}
   257  		}
   258  	}
   259  
   260  	return nil
   261  }
   262  
// TypeMapEntry is the configuration for one GraphQL type in the "models"
// section of gqlgen.yml.
//
// Model lists the Go types the GraphQL type may be bound to. Entries are
// normally "<import path>.<TypeName>"; TypeMap.Check rejects an entry whose
// last "." does not come after its last "/". In YAML it may be a single string
// or a list (see StringList).
//
// Fields holds per-field overrides keyed by GraphQL field name.
type TypeMapEntry struct {
	Model  StringList              `yaml:"model"`
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}

// TypeMapField overrides the binding of a single GraphQL field.
//
// It comes either from the "fields" section of a model in gqlgen.yml or from
// a @goField directive in the schema. In the directive case, Resolver is set
// from its forceResolver argument and FieldName from its name argument.
//
// GeneratedMethod is never read from or written to YAML, and nothing in this
// file sets it.
type TypeMapField struct {
	Resolver        bool   `yaml:"resolver"`
	FieldName       string `yaml:"fieldName"`
	GeneratedMethod string `yaml:"-"`
}
   273  
   274  type StringList []string
   275  
   276  func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
   277  	var single string
   278  	err := unmarshal(&single)
   279  	if err == nil {
   280  		*a = []string{single}
   281  		return nil
   282  	}
   283  
   284  	var multi []string
   285  	err = unmarshal(&multi)
   286  	if err != nil {
   287  		return err
   288  	}
   289  
   290  	*a = multi
   291  	return nil
   292  }
   293  
   294  func (a StringList) Has(file string) bool {
   295  	for _, existing := range a {
   296  		if existing == file {
   297  			return true
   298  		}
   299  	}
   300  	return false
   301  }
   302  
   303  func (c *Config) check() error {
   304  	if c.Models == nil {
   305  		c.Models = TypeMap{}
   306  	}
   307  
   308  	type FilenamePackage struct {
   309  		Filename string
   310  		Package  string
   311  		Declaree string
   312  	}
   313  
   314  	fileList := map[string][]FilenamePackage{}
   315  
   316  	if err := c.Models.Check(); err != nil {
   317  		return errors.Wrap(err, "config.models")
   318  	}
   319  	if err := c.Exec.Check(); err != nil {
   320  		return errors.Wrap(err, "config.exec")
   321  	}
   322  	fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
   323  		Filename: c.Exec.Filename,
   324  		Package:  c.Exec.Package,
   325  		Declaree: "exec",
   326  	})
   327  
   328  	if c.Model.IsDefined() {
   329  		if err := c.Model.Check(); err != nil {
   330  			return errors.Wrap(err, "config.model")
   331  		}
   332  		fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
   333  			Filename: c.Model.Filename,
   334  			Package:  c.Model.Package,
   335  			Declaree: "model",
   336  		})
   337  	}
   338  	if c.Resolver.IsDefined() {
   339  		if err := c.Resolver.Check(); err != nil {
   340  			return errors.Wrap(err, "config.resolver")
   341  		}
   342  		fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
   343  			Filename: c.Resolver.Filename,
   344  			Package:  c.Resolver.Package,
   345  			Declaree: "resolver",
   346  		})
   347  	}
   348  	if c.Federation.IsDefined() {
   349  		if err := c.Federation.Check(); err != nil {
   350  			return errors.Wrap(err, "config.federation")
   351  		}
   352  		fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
   353  			Filename: c.Federation.Filename,
   354  			Package:  c.Federation.Package,
   355  			Declaree: "federation",
   356  		})
   357  		if c.Federation.ImportPath() != c.Exec.ImportPath() {
   358  			return fmt.Errorf("federation and exec must be in the same package")
   359  		}
   360  	}
   361  	if c.Federated {
   362  		return fmt.Errorf("federated has been removed, instead use\nfederation:\n    filename: path/to/federated.go")
   363  	}
   364  
   365  	for importPath, pkg := range fileList {
   366  		for _, file1 := range pkg {
   367  			for _, file2 := range pkg {
   368  				if file1.Package != file2.Package {
   369  					return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
   370  						file1.Declaree,
   371  						file2.Declaree,
   372  						importPath,
   373  						file1.Package,
   374  						file2.Package,
   375  					)
   376  				}
   377  			}
   378  		}
   379  	}
   380  
   381  	return nil
   382  }
   383  
   384  type TypeMap map[string]TypeMapEntry
   385  
   386  func (tm TypeMap) Exists(typeName string) bool {
   387  	_, ok := tm[typeName]
   388  	return ok
   389  }
   390  
   391  func (tm TypeMap) UserDefined(typeName string) bool {
   392  	m, ok := tm[typeName]
   393  	return ok && len(m.Model) > 0
   394  }
   395  
   396  func (tm TypeMap) Check() error {
   397  	for typeName, entry := range tm {
   398  		for _, model := range entry.Model {
   399  			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
   400  				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
   401  			}
   402  		}
   403  	}
   404  	return nil
   405  }
   406  
   407  func (tm TypeMap) ReferencedPackages() []string {
   408  	var pkgs []string
   409  
   410  	for _, typ := range tm {
   411  		for _, model := range typ.Model {
   412  			if model == "map[string]interface{}" || model == "interface{}" {
   413  				continue
   414  			}
   415  			pkg, _ := code.PkgAndType(model)
   416  			if pkg == "" || inStrSlice(pkgs, pkg) {
   417  				continue
   418  			}
   419  			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
   420  		}
   421  	}
   422  
   423  	sort.Slice(pkgs, func(i, j int) bool {
   424  		return pkgs[i] > pkgs[j]
   425  	})
   426  	return pkgs
   427  }
   428  
   429  func (tm TypeMap) Add(name string, goType string) {
   430  	modelCfg := tm[name]
   431  	modelCfg.Model = append(modelCfg.Model, goType)
   432  	tm[name] = modelCfg
   433  }
   434  
// DirectiveConfig is the per-directive configuration under "directives" in
// gqlgen.yml.
//
// SkipRuntime is set by LoadConfig for the built-in skip, include and
// deprecated directives, and by injectTypesFromSchema for goModel and goField.
// Its effect is implemented outside this file.
// NOTE(review): presumably it stops generated code from invoking a runtime
// handler for the directive — confirm.
type DirectiveConfig struct {
	SkipRuntime bool `yaml:"skip_runtime"`
}
   438  
   439  func inStrSlice(haystack []string, needle string) bool {
   440  	for _, v := range haystack {
   441  		if needle == v {
   442  			return true
   443  		}
   444  	}
   445  
   446  	return false
   447  }
   448  
   449  // findCfg searches for the config file in this directory and all parents up the tree
   450  // looking for the closest match
   451  func findCfg() (string, error) {
   452  	dir, err := os.Getwd()
   453  	if err != nil {
   454  		return "", errors.Wrap(err, "unable to get working dir to findCfg")
   455  	}
   456  
   457  	cfg := findCfgInDir(dir)
   458  
   459  	for cfg == "" && dir != filepath.Dir(dir) {
   460  		dir = filepath.Dir(dir)
   461  		cfg = findCfgInDir(dir)
   462  	}
   463  
   464  	if cfg == "" {
   465  		return "", os.ErrNotExist
   466  	}
   467  
   468  	return cfg, nil
   469  }
   470  
   471  func findCfgInDir(dir string) string {
   472  	for _, cfgName := range cfgFilenames {
   473  		path := filepath.Join(dir, cfgName)
   474  		if _, err := os.Stat(path); err == nil {
   475  			return path
   476  		}
   477  	}
   478  	return ""
   479  }
   480  
   481  func (c *Config) autobind() error {
   482  	if len(c.AutoBind) == 0 {
   483  		return nil
   484  	}
   485  
   486  	ps := c.Packages.LoadAll(c.AutoBind...)
   487  
   488  	for _, t := range c.Schema.Types {
   489  		if c.Models.UserDefined(t.Name) {
   490  			continue
   491  		}
   492  
   493  		for i, p := range ps {
   494  			if p == nil {
   495  				return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
   496  			}
   497  			if t := p.Types.Scope().Lookup(t.Name); t != nil {
   498  				c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
   499  				break
   500  			}
   501  		}
   502  	}
   503  
   504  	for i, t := range c.Models {
   505  		for j, m := range t.Model {
   506  			pkg, typename := code.PkgAndType(m)
   507  
   508  			// skip anything that looks like an import path
   509  			if strings.Contains(pkg, "/") {
   510  				continue
   511  			}
   512  
   513  			for _, p := range ps {
   514  				if p.Name != pkg {
   515  					continue
   516  				}
   517  				if t := p.Types.Scope().Lookup(typename); t != nil {
   518  					c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
   519  					break
   520  				}
   521  			}
   522  		}
   523  	}
   524  
   525  	return nil
   526  }
   527  
   528  func (c *Config) injectBuiltins() {
   529  	builtins := TypeMap{
   530  		"__Directive":         {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Directive"}},
   531  		"__DirectiveLocation": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.String"}},
   532  		"__Type":              {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Type"}},
   533  		"__TypeKind":          {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.String"}},
   534  		"__Field":             {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Field"}},
   535  		"__EnumValue":         {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.EnumValue"}},
   536  		"__InputValue":        {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.InputValue"}},
   537  		"__Schema":            {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql/introspection.Schema"}},
   538  		"Float":               {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Float"}},
   539  		"String":              {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.String"}},
   540  		"Boolean":             {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Boolean"}},
   541  		"Int": {Model: StringList{
   542  			"git.sr.ht/~sircmpwn/gqlgen/graphql.Int",
   543  			"git.sr.ht/~sircmpwn/gqlgen/graphql.Int32",
   544  			"git.sr.ht/~sircmpwn/gqlgen/graphql.Int64",
   545  		}},
   546  		"ID": {
   547  			Model: StringList{
   548  				"git.sr.ht/~sircmpwn/gqlgen/graphql.ID",
   549  				"git.sr.ht/~sircmpwn/gqlgen/graphql.IntID",
   550  			},
   551  		},
   552  	}
   553  
   554  	for typeName, entry := range builtins {
   555  		if !c.Models.Exists(typeName) {
   556  			c.Models[typeName] = entry
   557  		}
   558  	}
   559  
   560  	// These are additional types that are injected if defined in the schema as scalars.
   561  	extraBuiltins := TypeMap{
   562  		"Time":   {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Time"}},
   563  		"Map":    {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Map"}},
   564  		"Upload": {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Upload"}},
   565  		"Any":    {Model: StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Any"}},
   566  	}
   567  
   568  	for typeName, entry := range extraBuiltins {
   569  		if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
   570  			c.Models[typeName] = entry
   571  		}
   572  	}
   573  }
   574  
   575  func (c *Config) LoadSchema() error {
   576  	if c.Packages != nil {
   577  		c.Packages = &code.Packages{}
   578  	}
   579  
   580  	if err := c.check(); err != nil {
   581  		return err
   582  	}
   583  
   584  	schema, err := gqlparser.LoadSchema(c.Sources...)
   585  	if err != nil {
   586  		return err
   587  	}
   588  
   589  	if schema.Query == nil {
   590  		schema.Query = &ast.Definition{
   591  			Kind: ast.Object,
   592  			Name: "Query",
   593  		}
   594  		schema.Types["Query"] = schema.Query
   595  	}
   596  
   597  	c.Schema = schema
   598  	return nil
   599  }
   600  
   601  func abs(path string) string {
   602  	absPath, err := filepath.Abs(path)
   603  	if err != nil {
   604  		panic(err)
   605  	}
   606  	return filepath.ToSlash(absPath)
   607  }