github.com/fortexxx/gqlgen@v0.10.3-0.20191216030626-ca5ea8b21ead/codegen/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"regexp"
    10  	"sort"
    11  	"strings"
    12  
    13  	"github.com/99designs/gqlgen/internal/code"
    14  	"github.com/pkg/errors"
    15  	"github.com/vektah/gqlparser"
    16  	"github.com/vektah/gqlparser/ast"
    17  	"golang.org/x/tools/go/packages"
    18  	"gopkg.in/yaml.v2"
    19  )
    20  
    21  type Config struct {
    22  	SchemaFilename           StringList                 `yaml:"schema,omitempty"`
    23  	Exec                     PackageConfig              `yaml:"exec"`
    24  	Model                    PackageConfig              `yaml:"model,omitempty"`
    25  	Resolver                 PackageConfig              `yaml:"resolver,omitempty"`
    26  	Service           PackageConfig              		`yaml:"service,omitempty"`
    27  	AutoBind                 []string                   `yaml:"autobind"`
    28  	Models                   TypeMap                    `yaml:"models,omitempty"`
    29  	StructTag                string                     `yaml:"struct_tag,omitempty"`
    30  	Directives               map[string]DirectiveConfig `yaml:"directives,omitempty"`
    31  	Federated                bool                       `yaml:"federated,omitempty"`
    32  	AdditionalSources        []*ast.Source              `yaml:"-"`
    33  	OmitSliceElementPointers bool                       `yaml:"omit_slice_element_pointers,omitempty"`
    34  	SkipValidation           bool                       `yaml:"skip_validation,omitempty"`
    35  }
    36  
    37  var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
    38  
    39  // DefaultConfig creates a copy of the default config
    40  func DefaultConfig() *Config {
    41  	return &Config{
    42  		SchemaFilename: StringList{"schema.graphql"},
    43  		Model:          PackageConfig{Filename: "models_gen.go"},
    44  		Exec:           PackageConfig{Filename: "generated.go"},
    45  		Service:        PackageConfig{Filename: "service.go"},
    46  		Directives:     map[string]DirectiveConfig{},
    47  	}
    48  }
    49  
    50  // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
    51  // walking up the tree. The closest config file will be returned.
    52  func LoadConfigFromDefaultLocations() (*Config, error) {
    53  	cfgFile, err := findCfg()
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	err = os.Chdir(filepath.Dir(cfgFile))
    59  	if err != nil {
    60  		return nil, errors.Wrap(err, "unable to enter config dir")
    61  	}
    62  	return LoadConfig(cfgFile)
    63  }
    64  
    65  var path2regex = strings.NewReplacer(
    66  	`.`, `\.`,
    67  	`*`, `.+`,
    68  	`\`, `[\\/]`,
    69  	`/`, `[\\/]`,
    70  )
    71  
    72  // LoadConfig reads the gqlgen.yml config file
    73  func LoadConfig(filename string) (*Config, error) {
    74  	config := DefaultConfig()
    75  
    76  	b, err := ioutil.ReadFile(filename)
    77  	if err != nil {
    78  		return nil, errors.Wrap(err, "unable to read config")
    79  	}
    80  
    81  	if err := yaml.UnmarshalStrict(b, config); err != nil {
    82  		return nil, errors.Wrap(err, "unable to parse config")
    83  	}
    84  
    85  	defaultDirectives := map[string]DirectiveConfig{
    86  		"skip":       {SkipRuntime: true},
    87  		"include":    {SkipRuntime: true},
    88  		"deprecated": {SkipRuntime: true},
    89  	}
    90  
    91  	for key, value := range defaultDirectives {
    92  		if _, defined := config.Directives[key]; !defined {
    93  			config.Directives[key] = value
    94  		}
    95  	}
    96  
    97  	preGlobbing := config.SchemaFilename
    98  	config.SchemaFilename = StringList{}
    99  	for _, f := range preGlobbing {
   100  		var matches []string
   101  
   102  		// for ** we want to override default globbing patterns and walk all
   103  		// subdirectories to match schema files.
   104  		if strings.Contains(f, "**") {
   105  			pathParts := strings.SplitN(f, "**", 2)
   106  			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
   107  			// turn the rest of the glob into a regex, anchored only at the end because ** allows
   108  			// for any number of dirs in between and walk will let us match against the full path name
   109  			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
   110  
   111  			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
   112  				if err != nil {
   113  					return err
   114  				}
   115  
   116  				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
   117  					matches = append(matches, path)
   118  				}
   119  
   120  				return nil
   121  			}); err != nil {
   122  				return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0])
   123  			}
   124  		} else {
   125  			matches, err = filepath.Glob(f)
   126  			if err != nil {
   127  				return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
   128  			}
   129  		}
   130  
   131  		for _, m := range matches {
   132  			if config.SchemaFilename.Has(m) {
   133  				continue
   134  			}
   135  			config.SchemaFilename = append(config.SchemaFilename, m)
   136  		}
   137  	}
   138  
   139  	return config, nil
   140  }
   141  
   142  type PackageConfig struct {
   143  	Filename string `yaml:"filename,omitempty"`
   144  	Package  string `yaml:"package,omitempty"`
   145  	Type     string `yaml:"type,omitempty"`
   146  }
   147  
   148  type TypeMapEntry struct {
   149  	Model  StringList              `yaml:"model"`
   150  	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
   151  }
   152  
   153  type TypeMapField struct {
   154  	Resolver  bool   `yaml:"resolver"`
   155  	FieldName string `yaml:"fieldName"`
   156  }
   157  
   158  type StringList []string
   159  
   160  func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
   161  	var single string
   162  	err := unmarshal(&single)
   163  	if err == nil {
   164  		*a = []string{single}
   165  		return nil
   166  	}
   167  
   168  	var multi []string
   169  	err = unmarshal(&multi)
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	*a = multi
   175  	return nil
   176  }
   177  
   178  func (a StringList) Has(file string) bool {
   179  	for _, existing := range a {
   180  		if existing == file {
   181  			return true
   182  		}
   183  	}
   184  	return false
   185  }
   186  
   187  func (c *PackageConfig) normalize() error {
   188  	if c.Filename == "" {
   189  		return errors.New("Filename is required")
   190  	}
   191  	c.Filename = abs(c.Filename)
   192  	// If Package is not set, first attempt to load the package at the output dir. If that fails
   193  	// fallback to just the base dir name of the output filename.
   194  	if c.Package == "" {
   195  		c.Package = code.NameForDir(c.Dir())
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  func (c *PackageConfig) ImportPath() string {
   202  	return code.ImportPathForDir(c.Dir())
   203  }
   204  
   205  func (c *PackageConfig) Dir() string {
   206  	return filepath.Dir(c.Filename)
   207  }
   208  
   209  func (c *PackageConfig) Check() error {
   210  	if strings.ContainsAny(c.Package, "./\\") {
   211  		return fmt.Errorf("package should be the output package name only, do not include the output filename")
   212  	}
   213  	if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
   214  		return fmt.Errorf("filename should be path to a go source file")
   215  	}
   216  
   217  	return c.normalize()
   218  }
   219  
   220  func (c *PackageConfig) Pkg() *types.Package {
   221  	return types.NewPackage(c.ImportPath(), c.Dir())
   222  }
   223  
   224  func (c *PackageConfig) IsDefined() bool {
   225  	return c.Filename != ""
   226  }
   227  
   228  func (c *Config) Check() error {
   229  	if err := c.Models.Check(); err != nil {
   230  		return errors.Wrap(err, "config.models")
   231  	}
   232  	if err := c.Exec.Check(); err != nil {
   233  		return errors.Wrap(err, "config.exec")
   234  	}
   235  	if c.Model.IsDefined() {
   236  		if err := c.Model.Check(); err != nil {
   237  			return errors.Wrap(err, "config.model")
   238  		}
   239  	}
   240  	if c.Resolver.IsDefined() {
   241  		if err := c.Resolver.Check(); err != nil {
   242  			return errors.Wrap(err, "config.resolver")
   243  		}
   244  	}
   245  	if c.Service.IsDefined() {
   246  		if err := c.Service.Check(); err != nil {
   247  			return errors.Wrap(err, "config.service")
   248  		}
   249  	}
   250  
   251  	// check packages names against conflict, if present in the same dir
   252  	// and check filenames for uniqueness
   253  	packageConfigList := []PackageConfig{}
   254  	if c.Model.IsDefined() {
   255  		packageConfigList = append(packageConfigList, c.Model)
   256  	}
   257  	packageConfigList = append(packageConfigList, []PackageConfig{
   258  		c.Exec,
   259  		c.Resolver,
   260  	}...)
   261  	filesMap := make(map[string]bool)
   262  	pkgConfigsByDir := make(map[string]PackageConfig)
   263  	for _, current := range packageConfigList {
   264  		_, fileFound := filesMap[current.Filename]
   265  		if fileFound {
   266  			return fmt.Errorf("filename %s defined more than once", current.Filename)
   267  		}
   268  		filesMap[current.Filename] = true
   269  		previous, inSameDir := pkgConfigsByDir[current.Dir()]
   270  		if inSameDir && current.Package != previous.Package {
   271  			return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(current.Filename), stripPath(previous.Filename))
   272  		}
   273  		pkgConfigsByDir[current.Dir()] = current
   274  	}
   275  
   276  	return c.normalize()
   277  }
   278  
   279  func stripPath(path string) string {
   280  	return filepath.Base(path)
   281  }
   282  
   283  type TypeMap map[string]TypeMapEntry
   284  
   285  func (tm TypeMap) Exists(typeName string) bool {
   286  	_, ok := tm[typeName]
   287  	return ok
   288  }
   289  
   290  func (tm TypeMap) UserDefined(typeName string) bool {
   291  	m, ok := tm[typeName]
   292  	return ok && len(m.Model) > 0
   293  }
   294  
   295  func (tm TypeMap) Check() error {
   296  	for typeName, entry := range tm {
   297  		for _, model := range entry.Model {
   298  			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
   299  				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
   300  			}
   301  		}
   302  	}
   303  	return nil
   304  }
   305  
   306  func (tm TypeMap) ReferencedPackages() []string {
   307  	var pkgs []string
   308  
   309  	for _, typ := range tm {
   310  		for _, model := range typ.Model {
   311  			if model == "map[string]interface{}" || model == "interface{}" {
   312  				continue
   313  			}
   314  			pkg, _ := code.PkgAndType(model)
   315  			if pkg == "" || inStrSlice(pkgs, pkg) {
   316  				continue
   317  			}
   318  			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
   319  		}
   320  	}
   321  
   322  	sort.Slice(pkgs, func(i, j int) bool {
   323  		return pkgs[i] > pkgs[j]
   324  	})
   325  	return pkgs
   326  }
   327  
   328  func (tm TypeMap) Add(name string, goType string) {
   329  	modelCfg := tm[name]
   330  	modelCfg.Model = append(modelCfg.Model, goType)
   331  	tm[name] = modelCfg
   332  }
   333  
   334  type DirectiveConfig struct {
   335  	SkipRuntime bool `yaml:"skip_runtime"`
   336  }
   337  
   338  func inStrSlice(haystack []string, needle string) bool {
   339  	for _, v := range haystack {
   340  		if needle == v {
   341  			return true
   342  		}
   343  	}
   344  
   345  	return false
   346  }
   347  
   348  // findCfg searches for the config file in this directory and all parents up the tree
   349  // looking for the closest match
   350  func findCfg() (string, error) {
   351  	dir, err := os.Getwd()
   352  	if err != nil {
   353  		return "", errors.Wrap(err, "unable to get working dir to findCfg")
   354  	}
   355  
   356  	cfg := findCfgInDir(dir)
   357  
   358  	for cfg == "" && dir != filepath.Dir(dir) {
   359  		dir = filepath.Dir(dir)
   360  		cfg = findCfgInDir(dir)
   361  	}
   362  
   363  	if cfg == "" {
   364  		return "", os.ErrNotExist
   365  	}
   366  
   367  	return cfg, nil
   368  }
   369  
   370  func findCfgInDir(dir string) string {
   371  	for _, cfgName := range cfgFilenames {
   372  		path := filepath.Join(dir, cfgName)
   373  		if _, err := os.Stat(path); err == nil {
   374  			return path
   375  		}
   376  	}
   377  	return ""
   378  }
   379  
   380  func (c *Config) normalize() error {
   381  	if c.Model.IsDefined() {
   382  		if err := c.Model.normalize(); err != nil {
   383  			return errors.Wrap(err, "model")
   384  		}
   385  	}
   386  
   387  	if err := c.Exec.normalize(); err != nil {
   388  		return errors.Wrap(err, "exec")
   389  	}
   390  
   391  	if c.Resolver.IsDefined() {
   392  		if err := c.Resolver.normalize(); err != nil {
   393  			return errors.Wrap(err, "resolver")
   394  		}
   395  	}
   396  
   397  	if c.Service.IsDefined() {
   398  		if err := c.Service.normalize(); err != nil {
   399  			return errors.Wrap(err, "service")
   400  		}
   401  	}
   402  
   403  	if c.Models == nil {
   404  		c.Models = TypeMap{}
   405  	}
   406  
   407  	return nil
   408  }
   409  
   410  func (c *Config) Autobind(s *ast.Schema) error {
   411  	if len(c.AutoBind) == 0 {
   412  		return nil
   413  	}
   414  	ps, err := packages.Load(&packages.Config{Mode: packages.LoadTypes}, c.AutoBind...)
   415  	if err != nil {
   416  		return err
   417  	}
   418  
   419  	for _, t := range s.Types {
   420  		if c.Models.UserDefined(t.Name) {
   421  			continue
   422  		}
   423  
   424  		for _, p := range ps {
   425  			if t := p.Types.Scope().Lookup(t.Name); t != nil {
   426  				c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
   427  				break
   428  			}
   429  		}
   430  	}
   431  
   432  	for i, t := range c.Models {
   433  		for j, m := range t.Model {
   434  			pkg, typename := code.PkgAndType(m)
   435  
   436  			// skip anything that looks like an import path
   437  			if strings.Contains(pkg, "/") {
   438  				continue
   439  			}
   440  
   441  			for _, p := range ps {
   442  				if p.Name != pkg {
   443  					continue
   444  				}
   445  				if t := p.Types.Scope().Lookup(typename); t != nil {
   446  					c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
   447  					break
   448  				}
   449  			}
   450  		}
   451  	}
   452  
   453  	return nil
   454  }
   455  
   456  func (c *Config) InjectBuiltins(s *ast.Schema) {
   457  	builtins := TypeMap{
   458  		"__Directive":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
   459  		"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   460  		"__Type":              {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
   461  		"__TypeKind":          {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   462  		"__Field":             {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
   463  		"__EnumValue":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
   464  		"__InputValue":        {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
   465  		"__Schema":            {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
   466  		"Float":               {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
   467  		"String":              {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
   468  		"Boolean":             {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
   469  		"Int": {Model: StringList{
   470  			"github.com/99designs/gqlgen/graphql.Int",
   471  			"github.com/99designs/gqlgen/graphql.Int32",
   472  			"github.com/99designs/gqlgen/graphql.Int64",
   473  		}},
   474  		"ID": {
   475  			Model: StringList{
   476  				"github.com/99designs/gqlgen/graphql.ID",
   477  				"github.com/99designs/gqlgen/graphql.IntID",
   478  			},
   479  		},
   480  	}
   481  
   482  	for typeName, entry := range builtins {
   483  		if !c.Models.Exists(typeName) {
   484  			c.Models[typeName] = entry
   485  		}
   486  	}
   487  
   488  	// These are additional types that are injected if defined in the schema as scalars.
   489  	extraBuiltins := TypeMap{
   490  		"Time":   {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
   491  		"Map":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
   492  		"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
   493  		"Any":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
   494  	}
   495  
   496  	for typeName, entry := range extraBuiltins {
   497  		if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
   498  			c.Models[typeName] = entry
   499  		}
   500  	}
   501  }
   502  
   503  func (c *Config) LoadSchema() (*ast.Schema, error) {
   504  	sources := append([]*ast.Source{}, c.AdditionalSources...)
   505  	for _, filename := range c.SchemaFilename {
   506  		filename = filepath.ToSlash(filename)
   507  		var err error
   508  		var schemaRaw []byte
   509  		schemaRaw, err = ioutil.ReadFile(filename)
   510  		if err != nil {
   511  			fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
   512  			os.Exit(1)
   513  		}
   514  		sources = append(sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
   515  	}
   516  	schema, err := gqlparser.LoadSchema(sources...)
   517  	if err != nil {
   518  		return nil, err
   519  	}
   520  	return schema, nil
   521  }
   522  
   523  func abs(path string) string {
   524  	absPath, err := filepath.Abs(path)
   525  	if err != nil {
   526  		panic(err)
   527  	}
   528  	return filepath.ToSlash(absPath)
   529  }