github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/codegen/config/config.go (about)

     1  package config
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/HaswinVidanage/gqlgen/internal/code"
    13  	"github.com/pkg/errors"
    14  	"github.com/vektah/gqlparser"
    15  	"github.com/vektah/gqlparser/ast"
    16  	yaml "gopkg.in/yaml.v2"
    17  )
    18  
// Config is the parsed form of a gqlgen YAML config file (see LoadConfig).
// Each field is populated from the top-level YAML key named in its tag.
type Config struct {
	// SchemaFilename lists schema files or glob patterns; LoadConfig expands
	// the patterns. In YAML it may be written as a single string or a list.
	SchemaFilename StringList    `yaml:"schema,omitempty"`
	// Exec configures the main generated output file (default "generated.go").
	Exec           PackageConfig `yaml:"exec"`
	// Model configures the generated models file (default "models_gen.go").
	Model          PackageConfig `yaml:"model"`
	// Resolver is optional: when its Filename is empty, Check and normalize
	// skip validating/normalizing it.
	Resolver       PackageConfig `yaml:"resolver,omitempty"`
	// Models binds GraphQL type names to Go types (see TypeMap).
	Models         TypeMap       `yaml:"models,omitempty"`
	// StructTag is read from the struct_tag key; it is not used in this file.
	StructTag      string        `yaml:"struct_tag,omitempty"`
}
    27  
    28  var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
    29  
    30  // DefaultConfig creates a copy of the default config
    31  func DefaultConfig() *Config {
    32  	return &Config{
    33  		SchemaFilename: StringList{"schema.graphql"},
    34  		Model:          PackageConfig{Filename: "models_gen.go"},
    35  		Exec:           PackageConfig{Filename: "generated.go"},
    36  	}
    37  }
    38  
    39  // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
    40  // walking up the tree. The closest config file will be returned.
    41  func LoadConfigFromDefaultLocations() (*Config, error) {
    42  	cfgFile, err := findCfg()
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	err = os.Chdir(filepath.Dir(cfgFile))
    48  	if err != nil {
    49  		return nil, errors.Wrap(err, "unable to enter config dir")
    50  	}
    51  	return LoadConfig(cfgFile)
    52  }
    53  
    54  // LoadConfig reads the gqlgen.yml config file
    55  func LoadConfig(filename string) (*Config, error) {
    56  	config := DefaultConfig()
    57  
    58  	b, err := ioutil.ReadFile(filename)
    59  	if err != nil {
    60  		return nil, errors.Wrap(err, "unable to read config")
    61  	}
    62  
    63  	if err := yaml.UnmarshalStrict(b, config); err != nil {
    64  		return nil, errors.Wrap(err, "unable to parse config")
    65  	}
    66  
    67  	preGlobbing := config.SchemaFilename
    68  	config.SchemaFilename = StringList{}
    69  	for _, f := range preGlobbing {
    70  		matches, err := filepath.Glob(f)
    71  		if err != nil {
    72  			return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
    73  		}
    74  
    75  		for _, m := range matches {
    76  			if config.SchemaFilename.Has(m) {
    77  				continue
    78  			}
    79  			config.SchemaFilename = append(config.SchemaFilename, m)
    80  		}
    81  	}
    82  
    83  	return config, nil
    84  }
    85  
// PackageConfig describes a Go output file and the package it belongs to.
type PackageConfig struct {
	// Filename is the Go file path; Check requires a ".go" suffix when it is
	// set, and normalize requires it and converts it to an absolute,
	// slash-separated path.
	Filename string `yaml:"filename,omitempty"`
	// Package is the Go package name; when empty, normalize derives it from
	// the import path of Filename's directory.
	Package  string `yaml:"package,omitempty"`
	// Type is read from the config but not referenced in this file.
	// NOTE(review): presumably names a Go type (e.g. the resolver type) — confirm with callers.
	Type     string `yaml:"type,omitempty"`
}
    91  
// TypeMapEntry is the configuration attached to one GraphQL type name in a
// TypeMap.
type TypeMapEntry struct {
	// Model lists the Go type specifiers the GraphQL type binds to, written as
	// "import/path.TypeName" (TypeMap.Check rejects bare package paths).
	Model  StringList              `yaml:"model"`
	// Fields holds per-field options; presumably keyed by GraphQL field name — confirm with callers.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
    96  
// TypeMapField holds the options configured for a single field of a mapped
// type.
type TypeMapField struct {
	// Resolver is read from the "resolver" key.
	// NOTE(review): presumably forces a resolver method for this field — confirm in codegen.
	Resolver  bool   `yaml:"resolver"`
	// FieldName is read from the "fieldName" key.
	// NOTE(review): presumably names the Go struct field to bind to — confirm in codegen.
	FieldName string `yaml:"fieldName"`
}
   101  
   102  type StringList []string
   103  
   104  func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
   105  	var single string
   106  	err := unmarshal(&single)
   107  	if err == nil {
   108  		*a = []string{single}
   109  		return nil
   110  	}
   111  
   112  	var multi []string
   113  	err = unmarshal(&multi)
   114  	if err != nil {
   115  		return err
   116  	}
   117  
   118  	*a = multi
   119  	return nil
   120  }
   121  
   122  func (a StringList) Has(file string) bool {
   123  	for _, existing := range a {
   124  		if existing == file {
   125  			return true
   126  		}
   127  	}
   128  	return false
   129  }
   130  
   131  func (c *PackageConfig) normalize() error {
   132  	if c.Filename == "" {
   133  		return errors.New("Filename is required")
   134  	}
   135  	c.Filename = abs(c.Filename)
   136  	// If Package is not set, first attempt to load the package at the output dir. If that fails
   137  	// fallback to just the base dir name of the output filename.
   138  	if c.Package == "" {
   139  		c.Package = code.NameForPackage(c.ImportPath())
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  func (c *PackageConfig) ImportPath() string {
   146  	return code.ImportPathForDir(c.Dir())
   147  }
   148  
   149  func (c *PackageConfig) Dir() string {
   150  	return filepath.Dir(c.Filename)
   151  }
   152  
   153  func (c *PackageConfig) Check() error {
   154  	if strings.ContainsAny(c.Package, "./\\") {
   155  		return fmt.Errorf("package should be the output package name only, do not include the output filename")
   156  	}
   157  	if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
   158  		return fmt.Errorf("filename should be path to a go source file")
   159  	}
   160  
   161  	return c.normalize()
   162  }
   163  
   164  func (c *PackageConfig) Pkg() *types.Package {
   165  	return types.NewPackage(c.ImportPath(), c.Dir())
   166  }
   167  
   168  func (c *PackageConfig) IsDefined() bool {
   169  	return c.Filename != ""
   170  }
   171  
   172  func (c *Config) Check() error {
   173  	if err := c.Models.Check(); err != nil {
   174  		return errors.Wrap(err, "config.models")
   175  	}
   176  	if err := c.Exec.Check(); err != nil {
   177  		return errors.Wrap(err, "config.exec")
   178  	}
   179  	if err := c.Model.Check(); err != nil {
   180  		return errors.Wrap(err, "config.model")
   181  	}
   182  	if c.Resolver.IsDefined() {
   183  		if err := c.Resolver.Check(); err != nil {
   184  			return errors.Wrap(err, "config.resolver")
   185  		}
   186  	}
   187  
   188  	// check packages names against conflict, if present in the same dir
   189  	// and check filenames for uniqueness
   190  	packageConfigList := []PackageConfig{
   191  		c.Model,
   192  		c.Exec,
   193  		c.Resolver,
   194  	}
   195  	filesMap := make(map[string]bool)
   196  	pkgConfigsByDir := make(map[string]PackageConfig)
   197  	for _, current := range packageConfigList {
   198  		_, fileFound := filesMap[current.Filename]
   199  		if fileFound {
   200  			return fmt.Errorf("filename %s defined more than once", current.Filename)
   201  		}
   202  		filesMap[current.Filename] = true
   203  		previous, inSameDir := pkgConfigsByDir[current.Dir()]
   204  		if inSameDir && current.Package != previous.Package {
   205  			return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(current.Filename), stripPath(previous.Filename))
   206  		}
   207  		pkgConfigsByDir[current.Dir()] = current
   208  	}
   209  
   210  	return c.normalize()
   211  }
   212  
   213  func stripPath(path string) string {
   214  	return filepath.Base(path)
   215  }
   216  
// TypeMap maps GraphQL type names to their configured Go bindings and
// per-field options. Methods that write to it (Add) require a non-nil map.
type TypeMap map[string]TypeMapEntry
   218  
   219  func (tm TypeMap) Exists(typeName string) bool {
   220  	_, ok := tm[typeName]
   221  	return ok
   222  }
   223  
   224  func (tm TypeMap) UserDefined(typeName string) bool {
   225  	m, ok := tm[typeName]
   226  	return ok && len(m.Model) > 0
   227  }
   228  
   229  func (tm TypeMap) Check() error {
   230  	for typeName, entry := range tm {
   231  		for _, model := range entry.Model {
   232  			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
   233  				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
   234  			}
   235  		}
   236  	}
   237  	return nil
   238  }
   239  
   240  func (tm TypeMap) ReferencedPackages() []string {
   241  	var pkgs []string
   242  
   243  	for _, typ := range tm {
   244  		for _, model := range typ.Model {
   245  			if model == "map[string]interface{}" || model == "interface{}" {
   246  				continue
   247  			}
   248  			pkg, _ := code.PkgAndType(model)
   249  			if pkg == "" || inStrSlice(pkgs, pkg) {
   250  				continue
   251  			}
   252  			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
   253  		}
   254  	}
   255  
   256  	sort.Slice(pkgs, func(i, j int) bool {
   257  		return pkgs[i] > pkgs[j]
   258  	})
   259  	return pkgs
   260  }
   261  
   262  func (tm TypeMap) Add(Name string, goType string) {
   263  	modelCfg := tm[Name]
   264  	modelCfg.Model = append(modelCfg.Model, goType)
   265  	tm[Name] = modelCfg
   266  }
   267  
   268  func inStrSlice(haystack []string, needle string) bool {
   269  	for _, v := range haystack {
   270  		if needle == v {
   271  			return true
   272  		}
   273  	}
   274  
   275  	return false
   276  }
   277  
   278  // findCfg searches for the config file in this directory and all parents up the tree
   279  // looking for the closest match
   280  func findCfg() (string, error) {
   281  	dir, err := os.Getwd()
   282  	if err != nil {
   283  		return "", errors.Wrap(err, "unable to get working dir to findCfg")
   284  	}
   285  
   286  	cfg := findCfgInDir(dir)
   287  
   288  	for cfg == "" && dir != filepath.Dir(dir) {
   289  		dir = filepath.Dir(dir)
   290  		cfg = findCfgInDir(dir)
   291  	}
   292  
   293  	if cfg == "" {
   294  		return "", os.ErrNotExist
   295  	}
   296  
   297  	return cfg, nil
   298  }
   299  
   300  func findCfgInDir(dir string) string {
   301  	for _, cfgName := range cfgFilenames {
   302  		path := filepath.Join(dir, cfgName)
   303  		if _, err := os.Stat(path); err == nil {
   304  			return path
   305  		}
   306  	}
   307  	return ""
   308  }
   309  
   310  func (c *Config) normalize() error {
   311  	if err := c.Model.normalize(); err != nil {
   312  		return errors.Wrap(err, "model")
   313  	}
   314  
   315  	if err := c.Exec.normalize(); err != nil {
   316  		return errors.Wrap(err, "exec")
   317  	}
   318  
   319  	if c.Resolver.IsDefined() {
   320  		if err := c.Resolver.normalize(); err != nil {
   321  			return errors.Wrap(err, "resolver")
   322  		}
   323  	}
   324  
   325  	if c.Models == nil {
   326  		c.Models = TypeMap{}
   327  	}
   328  
   329  	return nil
   330  }
   331  
   332  func (c *Config) InjectBuiltins(s *ast.Schema) {
   333  	builtins := TypeMap{
   334  		"__Directive":         {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql/introspection.Directive"}},
   335  		"__DirectiveLocation": {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.String"}},
   336  		"__Type":              {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql/introspection.Type"}},
   337  		"__TypeKind":          {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.String"}},
   338  		"__Field":             {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql/introspection.Field"}},
   339  		"__EnumValue":         {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql/introspection.EnumValue"}},
   340  		"__InputValue":        {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql/introspection.InputValue"}},
   341  		"__Schema":            {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql/introspection.Schema"}},
   342  		"Float":               {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.Float"}},
   343  		"String":              {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.String"}},
   344  		"Boolean":             {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.Boolean"}},
   345  		"Time":                {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.Time"}},
   346  		"Map":                 {Model: StringList{"github.com/HaswinVidanage/gqlgen/graphql.Map"}},
   347  		"Int": {Model: StringList{
   348  			"github.com/HaswinVidanage/gqlgen/graphql.Int",
   349  			"github.com/HaswinVidanage/gqlgen/graphql.Int32",
   350  			"github.com/HaswinVidanage/gqlgen/graphql.Int64",
   351  		}},
   352  		"ID": {
   353  			Model: StringList{
   354  				"github.com/HaswinVidanage/gqlgen/graphql.ID",
   355  				"github.com/HaswinVidanage/gqlgen/graphql.IntID",
   356  			},
   357  		},
   358  	}
   359  
   360  	for typeName, entry := range builtins {
   361  		if !c.Models.Exists(typeName) {
   362  			c.Models[typeName] = entry
   363  		}
   364  	}
   365  }
   366  
   367  func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
   368  	schemaStrings := map[string]string{}
   369  
   370  	var sources []*ast.Source
   371  
   372  	for _, filename := range c.SchemaFilename {
   373  		filename = filepath.ToSlash(filename)
   374  		var err error
   375  		var schemaRaw []byte
   376  		schemaRaw, err = ioutil.ReadFile(filename)
   377  		if err != nil {
   378  			fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
   379  			os.Exit(1)
   380  		}
   381  		schemaStrings[filename] = string(schemaRaw)
   382  		sources = append(sources, &ast.Source{Name: filename, Input: schemaStrings[filename]})
   383  	}
   384  
   385  	schema, err := gqlparser.LoadSchema(sources...)
   386  	if err != nil {
   387  		return nil, nil, err
   388  	}
   389  	return schema, schemaStrings, nil
   390  }
   391  
   392  func abs(path string) string {
   393  	absPath, err := filepath.Abs(path)
   394  	if err != nil {
   395  		panic(err)
   396  	}
   397  	return filepath.ToSlash(absPath)
   398  }