github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/plugin/resolvergen/resolver.go (about)

     1  package resolvergen
     2  
     3  import (
     4  	_ "embed"
     5  	"errors"
     6  	"fmt"
     7  	"go/ast"
     8  	"io/fs"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  
    13  	"golang.org/x/text/cases"
    14  	"golang.org/x/text/language"
    15  
    16  	"github.com/niko0xdev/gqlgen/codegen"
    17  	"github.com/niko0xdev/gqlgen/codegen/config"
    18  	"github.com/niko0xdev/gqlgen/codegen/templates"
    19  	"github.com/niko0xdev/gqlgen/graphql"
    20  	"github.com/niko0xdev/gqlgen/internal/rewrite"
    21  	"github.com/niko0xdev/gqlgen/plugin"
    22  )
    23  
// resolverTemplate is the built-in resolver template, embedded from
// resolver.gotpl. It is used unless Config.Resolver.ResolverTemplate names a
// custom template file.
//
//go:embed resolver.gotpl
var resolverTemplate string
    26  
    27  func New() plugin.Plugin {
    28  	return &Plugin{}
    29  }
    30  
    31  type Plugin struct{}
    32  
    33  var _ plugin.CodeGenerator = &Plugin{}
    34  
    35  func (m *Plugin) Name() string {
    36  	return "resolvergen"
    37  }
    38  
    39  func (m *Plugin) GenerateCode(data *codegen.Data) error {
    40  	if !data.Config.Resolver.IsDefined() {
    41  		return nil
    42  	}
    43  
    44  	switch data.Config.Resolver.Layout {
    45  	case config.LayoutSingleFile:
    46  		return m.generateSingleFile(data)
    47  	case config.LayoutFollowSchema:
    48  		return m.generatePerSchema(data)
    49  	}
    50  
    51  	return nil
    52  }
    53  
    54  func (m *Plugin) generateSingleFile(data *codegen.Data) error {
    55  	file := File{}
    56  
    57  	if _, err := os.Stat(data.Config.Resolver.Filename); err == nil {
    58  		// file already exists and we do not support updating resolvers with layout = single so just return
    59  		return nil
    60  	}
    61  
    62  	for _, o := range data.Objects {
    63  		if o.HasResolvers() {
    64  			file.Objects = append(file.Objects, o)
    65  		}
    66  		for _, f := range o.Fields {
    67  			if !f.IsResolver {
    68  				continue
    69  			}
    70  
    71  			implementFunc := buildImplementationFunc(o, f)
    72  			resolver := Resolver{o, f, nil, "", `panic("not implemented")`, nil, implementFunc}
    73  			file.Resolvers = append(file.Resolvers, &resolver)
    74  		}
    75  	}
    76  
    77  	resolverBuild := &ResolverBuild{
    78  		File:                &file,
    79  		PackageName:         data.Config.Resolver.Package,
    80  		ResolverType:        data.Config.Resolver.Type,
    81  		HasRoot:             true,
    82  		OmitTemplateComment: data.Config.Resolver.OmitTemplateComment,
    83  	}
    84  
    85  	newResolverTemplate := resolverTemplate
    86  	if data.Config.Resolver.ResolverTemplate != "" {
    87  		newResolverTemplate = readResolverTemplate(data.Config.Resolver.ResolverTemplate)
    88  	}
    89  
    90  	return templates.Render(templates.Options{
    91  		PackageName: data.Config.Resolver.Package,
    92  		FileNotice:  `// THIS CODE IS A STARTING POINT ONLY. IT WILL NOT BE UPDATED WITH SCHEMA CHANGES.`,
    93  		Filename:    data.Config.Resolver.Filename,
    94  		Data:        resolverBuild,
    95  		Packages:    data.Config.Packages,
    96  		Template:    newResolverTemplate,
    97  	})
    98  }
    99  
   100  func (m *Plugin) generatePerSchema(data *codegen.Data) error {
   101  	rewriter, err := rewrite.New(data.Config.Resolver.Dir())
   102  	if err != nil {
   103  		return err
   104  	}
   105  
   106  	files := map[string]*File{}
   107  
   108  	objects := make(codegen.Objects, len(data.Objects)+len(data.Inputs))
   109  	copy(objects, data.Objects)
   110  	copy(objects[len(data.Objects):], data.Inputs)
   111  
   112  	for _, o := range objects {
   113  		if o.HasResolvers() {
   114  			fnCase := gqlToResolverName(data.Config.Resolver.Dir(), o.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
   115  			fn := strings.ToLower(fnCase)
   116  			if files[fn] == nil {
   117  				files[fn] = &File{
   118  					name: fnCase,
   119  				}
   120  			}
   121  
   122  			caser := cases.Title(language.English, cases.NoLower)
   123  			rewriter.MarkStructCopied(templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type))
   124  			rewriter.GetMethodBody(data.Config.Resolver.Type, caser.String(o.Name))
   125  			files[fn].Objects = append(files[fn].Objects, o)
   126  		}
   127  		for _, f := range o.Fields {
   128  			if !f.IsResolver {
   129  				continue
   130  			}
   131  
   132  			structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type)
   133  			comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`))
   134  			implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
   135  			// if implementation == "" {
   136  			// 	// use default implementation, if no implementation was previously used
   137  			// 	implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
   138  			// }
   139  
   140  			implementFunc := buildImplementationFunc(o, f)
   141  			resolver := Resolver{o, f, rewriter.GetPrevDecl(structName, f.GoFieldName), comment, implementation, nil, implementFunc}
   142  			var implExists bool
   143  			for _, p := range data.Plugins {
   144  				rImpl, ok := p.(plugin.ResolverImplementer)
   145  				if !ok {
   146  					continue
   147  				}
   148  				if implExists {
   149  					return fmt.Errorf("multiple plugins implement ResolverImplementer")
   150  				}
   151  				implExists = true
   152  				resolver.ImplementationRender = rImpl.Implement
   153  			}
   154  			fnCase := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
   155  			fn := strings.ToLower(fnCase)
   156  			if files[fn] == nil {
   157  				files[fn] = &File{
   158  					name: fnCase,
   159  				}
   160  			}
   161  
   162  			files[fn].Resolvers = append(files[fn].Resolvers, &resolver)
   163  		}
   164  	}
   165  
   166  	for _, file := range files {
   167  		file.imports = rewriter.ExistingImports(file.name)
   168  		file.RemainingSource = rewriter.RemainingSource(file.name)
   169  	}
   170  	newResolverTemplate := resolverTemplate
   171  	if data.Config.Resolver.ResolverTemplate != "" {
   172  		newResolverTemplate = readResolverTemplate(data.Config.Resolver.ResolverTemplate)
   173  	}
   174  
   175  	for _, file := range files {
   176  		resolverBuild := &ResolverBuild{
   177  			File:                file,
   178  			PackageName:         data.Config.Resolver.Package,
   179  			ResolverType:        data.Config.Resolver.Type,
   180  			OmitTemplateComment: data.Config.Resolver.OmitTemplateComment,
   181  		}
   182  
   183  		var fileNotice strings.Builder
   184  		if !data.Config.OmitGQLGenFileNotice {
   185  			fileNotice.WriteString(`
   186  			// This file will be automatically regenerated based on the schema, any resolver implementations
   187  			// will be copied through when generating and any unknown code will be moved to the end.
   188  			// Code generated by github.com/niko0xdev/gqlgen`,
   189  			)
   190  			if !data.Config.OmitGQLGenVersionInFileNotice {
   191  				fileNotice.WriteString(` version `)
   192  				fileNotice.WriteString(graphql.GetVersion())
   193  			}
   194  		}
   195  
   196  		err := templates.Render(templates.Options{
   197  			PackageName: data.Config.Resolver.Package,
   198  			FileNotice:  fileNotice.String(),
   199  			Filename:    file.name,
   200  			Data:        resolverBuild,
   201  			Packages:    data.Config.Packages,
   202  			Template:    newResolverTemplate,
   203  		})
   204  		if err != nil {
   205  			return err
   206  		}
   207  	}
   208  
   209  	if _, err := os.Stat(data.Config.Resolver.Filename); errors.Is(err, fs.ErrNotExist) {
   210  		err := templates.Render(templates.Options{
   211  			PackageName: data.Config.Resolver.Package,
   212  			FileNotice: `
   213  				// This file will not be regenerated automatically.
   214  				//
   215  				// It serves as dependency injection for your app, add any dependencies you require here.`,
   216  			Template: `type {{.}} struct {}`,
   217  			Filename: data.Config.Resolver.Filename,
   218  			Data:     data.Config.Resolver.Type,
   219  			Packages: data.Config.Packages,
   220  		})
   221  		if err != nil {
   222  			return err
   223  		}
   224  	}
   225  	return nil
   226  }
   227  
// ResolverBuild is the data handed to the resolver template when rendering one
// output file. The embedded *File supplies that file's objects, resolvers and
// leftover source.
type ResolverBuild struct {
	*File
	// HasRoot is true for the single-file layout and left false by the
	// follow-schema layout.
	HasRoot             bool
	PackageName         string // copied from Config.Resolver.Package.
	ResolverType        string // root resolver type name, from Config.Resolver.Type.
	OmitTemplateComment bool   // copied from Config.Resolver.OmitTemplateComment.
}
   235  
// File describes one resolver output file.
type File struct {
	// name is the output path; it is only set by the follow-schema layout
	// (the single-file layout renders to Config.Resolver.Filename instead).
	name string
	// Objects and Resolvers are kept separate because the type definition of a
	// resolver object may live in a different file from its resolver method
	// implementations, for example when a type is extended in a different
	// GraphQL schema file.
	Objects   []*codegen.Object
	Resolvers []*Resolver
	// imports are the imports the rewriter found in the existing file; they
	// are reserved by Imports.
	imports []rewrite.Import
	// RemainingSource is the leftover existing code returned by
	// rewriter.RemainingSource for this file.
	RemainingSource string
}
   245  
   246  func (f *File) Imports() string {
   247  	for _, imp := range f.imports {
   248  		if imp.Alias == "" {
   249  			_, _ = templates.CurrentImports.Reserve(imp.ImportPath)
   250  		} else {
   251  			_, _ = templates.CurrentImports.Reserve(imp.ImportPath, imp.Alias)
   252  		}
   253  	}
   254  	return ""
   255  }
   256  
// ResolverType is the kind of resolver that detectResolverType infers from a
// field's GraphQL name and arguments.
type ResolverType string

const (
	GetOne  ResolverType = "GET_ONE"  // no action suffix; field has an "id" argument.
	GetList ResolverType = "GET_LIST" // no action suffix, no "id" argument, name looks plural.
	Create  ResolverType = "CREATE"   // GraphQL field name ends in "Create".
	Update  ResolverType = "UPDATE"   // GraphQL field name ends in "Update".
	Delete  ResolverType = "DELETE"   // GraphQL field name ends in "Delete".
	NA      ResolverType = "NA"       // none of the above.
)
   267  
// ResolverFuncFieldMap presumably pairs a model field with a DTO field for
// generated CRUD code.
//
// NOTE(review): nothing in this file populates it (buildImplementationFunc
// always emits an empty list); confirm how the resolver template uses it.
type ResolverFuncFieldMap struct {
	ModelField string
	DtoField   string
}
   272  
// ResolverImplementationFunc is the CRUD scaffolding metadata that
// buildImplementationFunc computes for a resolver.
type ResolverImplementationFunc struct {
	Type   ResolverType           // kind reported by detectResolverType.
	Model  string                 // field's Go name, minus the action word for Create/Update/Delete.
	Return string                 // field.Type.NamedType of the resolved field.
	Fields []ResolverFuncFieldMap // left empty by buildImplementationFunc.
}
   279  
// Resolver is one resolver method to be rendered by the template.
//
// NOTE(review): keep the field order stable; any unkeyed composite literal of
// this type breaks silently if fields are reordered.
type Resolver struct {
	Object *codegen.Object // object (or input object) that owns the field.
	Field  *codegen.Field  // field being resolved.
	// PrevDecl is the previous declaration of this method returned by the
	// rewriter; nil in the single-file layout.
	PrevDecl *ast.FuncDecl
	// Comment is the method comment carried over from an existing file.
	Comment string
	// ImplementationStr is the method body to emit. Implementation falls back
	// to a "not implemented" panic when it is empty.
	ImplementationStr string
	// ImplementationRender, when non-nil, is a plugin-supplied renderer
	// (plugin.ResolverImplementer.Implement) that takes precedence over
	// ImplementationStr.
	ImplementationRender func(r *codegen.Field) string
	// ImplementationFunc is the CRUD metadata from buildImplementationFunc.
	ImplementationFunc *ResolverImplementationFunc
}
   289  
   290  func (r *Resolver) Implementation() string {
   291  	if r.ImplementationRender != nil {
   292  		return r.ImplementationRender(r.Field)
   293  	}
   294  
   295  	implStr := r.ImplementationStr
   296  	if implStr == "" {
   297  		implStr = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", r.Field.GoFieldName, r.Field.Name)
   298  	}
   299  
   300  	return implStr
   301  }
   302  
   303  func gqlToResolverName(base string, gqlname, filenameTmpl string) string {
   304  	gqlname = filepath.Base(gqlname)
   305  	ext := filepath.Ext(gqlname)
   306  	if filenameTmpl == "" {
   307  		filenameTmpl = "{name}.resolvers.go"
   308  	}
   309  	filename := strings.ReplaceAll(filenameTmpl, "{name}", strings.TrimSuffix(gqlname, ext))
   310  	return filepath.Join(base, filename)
   311  }
   312  
   313  func readResolverTemplate(customResolverTemplate string) string {
   314  	contentBytes, err := os.ReadFile(customResolverTemplate)
   315  	if err != nil {
   316  		panic(err)
   317  	}
   318  	return string(contentBytes)
   319  }
   320  
   321  func isPlural(word string) bool {
   322  	suffixes := []string{"s", "es"}
   323  
   324  	for _, suffix := range suffixes {
   325  		if strings.HasSuffix(word, suffix) {
   326  			return true
   327  		}
   328  	}
   329  
   330  	return false
   331  }
   332  
   333  func buildImplementationFunc(object *codegen.Object, field *codegen.Field) *ResolverImplementationFunc {
   334  	resolverType := detectResolverType(object, field)
   335  	modelName := field.GoFieldName
   336  	returnType := field.Type.NamedType
   337  
   338  	if resolverType == Delete {
   339  		modelName = strings.Replace(field.GoFieldName, "Delete", "", -1)
   340  	}
   341  
   342  	if resolverType == Create {
   343  		modelName = strings.Replace(field.GoFieldName, "Create", "", -1)
   344  	}
   345  
   346  	if resolverType == Update {
   347  		modelName = strings.Replace(field.GoFieldName, "Update", "", -1)
   348  	}
   349  
   350  	return &ResolverImplementationFunc{
   351  		Type:   resolverType,
   352  		Model:  modelName,
   353  		Return: returnType,
   354  		Fields: []ResolverFuncFieldMap{},
   355  	}
   356  }
   357  
   358  func detectResolverType(object *codegen.Object, field *codegen.Field) ResolverType {
   359  	resolverName := field.Name
   360  	hasId := false
   361  
   362  	for _, arg := range field.Args {
   363  		if arg.Name == "id" {
   364  			hasId = true
   365  		}
   366  	}
   367  
   368  	if strings.HasSuffix(resolverName, "Create") {
   369  		return Create
   370  	}
   371  
   372  	if strings.HasSuffix(resolverName, "Update") {
   373  		return Update
   374  	}
   375  
   376  	if strings.HasSuffix(resolverName, "Delete") {
   377  		return Delete
   378  	}
   379  
   380  	if !hasId && isPlural(resolverName) {
   381  		return GetList
   382  	}
   383  
   384  	if hasId {
   385  		return GetOne
   386  	}
   387  
   388  	return NA
   389  }