github.com/mariuspot/gqlgen@v0.7.2/codegen/codegen.go (about)

     1  package codegen
     2  
     3  import (
     4  	"log"
     5  	"os"
     6  	"path/filepath"
     7  	"regexp"
     8  	"syscall"
     9  
    10  	"github.com/99designs/gqlgen/codegen/templates"
    11  	"github.com/pkg/errors"
    12  	"github.com/vektah/gqlparser"
    13  	"github.com/vektah/gqlparser/ast"
    14  	"github.com/vektah/gqlparser/gqlerror"
    15  )
    16  
    17  func Generate(cfg Config) error {
    18  	if err := cfg.normalize(); err != nil {
    19  		return err
    20  	}
    21  
    22  	_ = syscall.Unlink(cfg.Exec.Filename)
    23  	_ = syscall.Unlink(cfg.Model.Filename)
    24  
    25  	modelsBuild, err := cfg.models()
    26  	if err != nil {
    27  		return errors.Wrap(err, "model plan failed")
    28  	}
    29  	if len(modelsBuild.Models) > 0 || len(modelsBuild.Enums) > 0 {
    30  		if err = templates.RenderToFile("models.gotpl", cfg.Model.Filename, modelsBuild); err != nil {
    31  			return err
    32  		}
    33  
    34  		for _, model := range modelsBuild.Models {
    35  			modelCfg := cfg.Models[model.GQLType]
    36  			modelCfg.Model = cfg.Model.ImportPath() + "." + model.GoType
    37  			cfg.Models[model.GQLType] = modelCfg
    38  		}
    39  
    40  		for _, enum := range modelsBuild.Enums {
    41  			modelCfg := cfg.Models[enum.GQLType]
    42  			modelCfg.Model = cfg.Model.ImportPath() + "." + enum.GoType
    43  			cfg.Models[enum.GQLType] = modelCfg
    44  		}
    45  	}
    46  
    47  	build, err := cfg.bind()
    48  	if err != nil {
    49  		return errors.Wrap(err, "exec plan failed")
    50  	}
    51  
    52  	if err := templates.RenderToFile("generated.gotpl", cfg.Exec.Filename, build); err != nil {
    53  		return err
    54  	}
    55  
    56  	if cfg.Resolver.IsDefined() {
    57  		if err := generateResolver(cfg); err != nil {
    58  			return errors.Wrap(err, "generating resolver failed")
    59  		}
    60  	}
    61  
    62  	if err := cfg.validate(); err != nil {
    63  		return errors.Wrap(err, "validation failed")
    64  	}
    65  
    66  	return nil
    67  }
    68  
    69  func GenerateServer(cfg Config, filename string) error {
    70  	if err := cfg.Exec.normalize(); err != nil {
    71  		return errors.Wrap(err, "exec")
    72  	}
    73  	if err := cfg.Resolver.normalize(); err != nil {
    74  		return errors.Wrap(err, "resolver")
    75  	}
    76  
    77  	serverFilename := abs(filename)
    78  	serverBuild := cfg.server(filepath.Dir(serverFilename))
    79  
    80  	if _, err := os.Stat(serverFilename); os.IsNotExist(errors.Cause(err)) {
    81  		err = templates.RenderToFile("server.gotpl", serverFilename, serverBuild)
    82  		if err != nil {
    83  			return errors.Wrap(err, "generate server failed")
    84  		}
    85  	} else {
    86  		log.Printf("Skipped server: %s already exists\n", serverFilename)
    87  	}
    88  	return nil
    89  }
    90  
    91  func generateResolver(cfg Config) error {
    92  	resolverBuild, err := cfg.resolver()
    93  	if err != nil {
    94  		return errors.Wrap(err, "resolver build failed")
    95  	}
    96  	filename := cfg.Resolver.Filename
    97  
    98  	if resolverBuild.ResolverFound {
    99  		log.Printf("Skipped resolver: %s.%s already exists\n", cfg.Resolver.ImportPath(), cfg.Resolver.Type)
   100  		return nil
   101  	}
   102  
   103  	if _, err := os.Stat(filename); os.IsNotExist(errors.Cause(err)) {
   104  		if err := templates.RenderToFile("resolver.gotpl", filename, resolverBuild); err != nil {
   105  			return err
   106  		}
   107  	} else {
   108  		log.Printf("Skipped resolver: %s already exists\n", filename)
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  func (cfg *Config) normalize() error {
   115  	if err := cfg.Model.normalize(); err != nil {
   116  		return errors.Wrap(err, "model")
   117  	}
   118  
   119  	if err := cfg.Exec.normalize(); err != nil {
   120  		return errors.Wrap(err, "exec")
   121  	}
   122  
   123  	if cfg.Resolver.IsDefined() {
   124  		if err := cfg.Resolver.normalize(); err != nil {
   125  			return errors.Wrap(err, "resolver")
   126  		}
   127  	}
   128  
   129  	builtins := TypeMap{
   130  		"__Directive":  {Model: "github.com/99designs/gqlgen/graphql/introspection.Directive"},
   131  		"__Type":       {Model: "github.com/99designs/gqlgen/graphql/introspection.Type"},
   132  		"__Field":      {Model: "github.com/99designs/gqlgen/graphql/introspection.Field"},
   133  		"__EnumValue":  {Model: "github.com/99designs/gqlgen/graphql/introspection.EnumValue"},
   134  		"__InputValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.InputValue"},
   135  		"__Schema":     {Model: "github.com/99designs/gqlgen/graphql/introspection.Schema"},
   136  		"Int":          {Model: "github.com/99designs/gqlgen/graphql.Int"},
   137  		"Float":        {Model: "github.com/99designs/gqlgen/graphql.Float"},
   138  		"String":       {Model: "github.com/99designs/gqlgen/graphql.String"},
   139  		"Boolean":      {Model: "github.com/99designs/gqlgen/graphql.Boolean"},
   140  		"ID":           {Model: "github.com/99designs/gqlgen/graphql.ID"},
   141  		"Time":         {Model: "github.com/99designs/gqlgen/graphql.Time"},
   142  		"Map":          {Model: "github.com/99designs/gqlgen/graphql.Map"},
   143  	}
   144  
   145  	if cfg.Models == nil {
   146  		cfg.Models = TypeMap{}
   147  	}
   148  	for typeName, entry := range builtins {
   149  		if !cfg.Models.Exists(typeName) {
   150  			cfg.Models[typeName] = entry
   151  		}
   152  	}
   153  
   154  	var sources []*ast.Source
   155  	for _, filename := range cfg.SchemaFilename {
   156  		sources = append(sources, &ast.Source{Name: filename, Input: cfg.SchemaStr[filename]})
   157  	}
   158  
   159  	var err *gqlerror.Error
   160  	cfg.schema, err = gqlparser.LoadSchema(sources...)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	return nil
   165  }
   166  
   167  var invalidPackageNameChar = regexp.MustCompile(`[^\w]`)
   168  
   169  func sanitizePackageName(pkg string) string {
   170  	return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_")
   171  }
   172  
   173  func abs(path string) string {
   174  	absPath, err := filepath.Abs(path)
   175  	if err != nil {
   176  		panic(err)
   177  	}
   178  	return filepath.ToSlash(absPath)
   179  }