github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/plugin/resolvergen/resolver.go (about)

     1  package resolvergen
     2  
     3  import (
     4  	_ "embed"
     5  	"errors"
     6  	"fmt"
     7  	"io/fs"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  
    12  	"github.com/spread-ai/gqlgen/codegen"
    13  	"github.com/spread-ai/gqlgen/codegen/config"
    14  	"github.com/spread-ai/gqlgen/codegen/templates"
    15  	"github.com/spread-ai/gqlgen/internal/rewrite"
    16  	"github.com/spread-ai/gqlgen/plugin"
    17  	"golang.org/x/text/cases"
    18  	"golang.org/x/text/language"
    19  )
    20  
// resolverTemplate holds the contents of resolver.gotpl, embedded at build
// time. It is the template rendered for every generated resolver file.
//
//go:embed resolver.gotpl
var resolverTemplate string
    23  
    24  func New() plugin.Plugin {
    25  	return &Plugin{}
    26  }
    27  
    28  type Plugin struct{}
    29  
    30  var _ plugin.CodeGenerator = &Plugin{}
    31  
    32  func (m *Plugin) Name() string {
    33  	return "resolvergen"
    34  }
    35  
    36  func (m *Plugin) GenerateCode(data *codegen.Data) error {
    37  	if !data.Config.Resolver.IsDefined() {
    38  		return nil
    39  	}
    40  
    41  	switch data.Config.Resolver.Layout {
    42  	case config.LayoutSingleFile:
    43  		return m.generateSingleFile(data)
    44  	case config.LayoutFollowSchema:
    45  		return m.generatePerSchema(data)
    46  	}
    47  
    48  	return nil
    49  }
    50  
    51  func (m *Plugin) generateSingleFile(data *codegen.Data) error {
    52  	file := File{}
    53  
    54  	if _, err := os.Stat(data.Config.Resolver.Filename); err == nil {
    55  		// file already exists and we dont support updating resolvers with layout = single so just return
    56  		return nil
    57  	}
    58  
    59  	for _, o := range data.Objects {
    60  		if o.HasResolvers() {
    61  			file.Objects = append(file.Objects, o)
    62  		}
    63  		for _, f := range o.Fields {
    64  			if !f.IsResolver {
    65  				continue
    66  			}
    67  
    68  			resolver := Resolver{o, f, "// foo", `panic("not implemented")`}
    69  			file.Resolvers = append(file.Resolvers, &resolver)
    70  		}
    71  	}
    72  
    73  	resolverBuild := &ResolverBuild{
    74  		File:         &file,
    75  		PackageName:  data.Config.Resolver.Package,
    76  		ResolverType: data.Config.Resolver.Type,
    77  		HasRoot:      true,
    78  	}
    79  
    80  	return templates.Render(templates.Options{
    81  		PackageName: data.Config.Resolver.Package,
    82  		FileNotice:  `// THIS CODE IS A STARTING POINT ONLY. IT WILL NOT BE UPDATED WITH SCHEMA CHANGES.`,
    83  		Filename:    data.Config.Resolver.Filename,
    84  		Data:        resolverBuild,
    85  		Packages:    data.Config.Packages,
    86  		Template:    resolverTemplate,
    87  	})
    88  }
    89  
    90  func (m *Plugin) generatePerSchema(data *codegen.Data) error {
    91  	rewriter, err := rewrite.New(data.Config.Resolver.Dir())
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	files := map[string]*File{}
    97  
    98  	objects := make(codegen.Objects, len(data.Objects)+len(data.Inputs))
    99  	copy(objects, data.Objects)
   100  	copy(objects[len(data.Objects):], data.Inputs)
   101  
   102  	for _, o := range objects {
   103  		if o.HasResolvers() {
   104  			fn := gqlToResolverName(data.Config.Resolver.Dir(), o.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
   105  			if files[fn] == nil {
   106  				files[fn] = &File{}
   107  			}
   108  
   109  			caser := cases.Title(language.English, cases.NoLower)
   110  			rewriter.MarkStructCopied(templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type))
   111  			rewriter.GetMethodBody(data.Config.Resolver.Type, caser.String(o.Name))
   112  			files[fn].Objects = append(files[fn].Objects, o)
   113  		}
   114  		for _, f := range o.Fields {
   115  			if !f.IsResolver {
   116  				continue
   117  			}
   118  
   119  			structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type)
   120  			implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
   121  			comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`))
   122  			if implementation == "" {
   123  				implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
   124  			}
   125  			if comment == "" {
   126  				comment = fmt.Sprintf("%v is the resolver for the %v field.", f.GoFieldName, f.Name)
   127  			}
   128  
   129  			resolver := Resolver{o, f, comment, implementation}
   130  			fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
   131  			if files[fn] == nil {
   132  				files[fn] = &File{}
   133  			}
   134  
   135  			files[fn].Resolvers = append(files[fn].Resolvers, &resolver)
   136  		}
   137  	}
   138  
   139  	for filename, file := range files {
   140  		file.imports = rewriter.ExistingImports(filename)
   141  		file.RemainingSource = rewriter.RemainingSource(filename)
   142  	}
   143  
   144  	for filename, file := range files {
   145  		resolverBuild := &ResolverBuild{
   146  			File:         file,
   147  			PackageName:  data.Config.Resolver.Package,
   148  			ResolverType: data.Config.Resolver.Type,
   149  		}
   150  
   151  		err := templates.Render(templates.Options{
   152  			PackageName: data.Config.Resolver.Package,
   153  			FileNotice: `
   154  				// This file will be automatically regenerated based on the schema, any resolver implementations
   155  				// will be copied through when generating and any unknown code will be moved to the end.`,
   156  			Filename: filename,
   157  			Data:     resolverBuild,
   158  			Packages: data.Config.Packages,
   159  			Template: resolverTemplate,
   160  		})
   161  		if err != nil {
   162  			return err
   163  		}
   164  	}
   165  
   166  	if _, err := os.Stat(data.Config.Resolver.Filename); errors.Is(err, fs.ErrNotExist) {
   167  		err := templates.Render(templates.Options{
   168  			PackageName: data.Config.Resolver.Package,
   169  			FileNotice: `
   170  				// This file will not be regenerated automatically.
   171  				//
   172  				// It serves as dependency injection for your app, add any dependencies you require here.`,
   173  			Template: `type {{.}} struct {}`,
   174  			Filename: data.Config.Resolver.Filename,
   175  			Data:     data.Config.Resolver.Type,
   176  			Packages: data.Config.Packages,
   177  		})
   178  		if err != nil {
   179  			return err
   180  		}
   181  	}
   182  	return nil
   183  }
   184  
// ResolverBuild is the Data value passed to resolverTemplate when a resolver
// file is rendered. It embeds the *File being rendered, so the template can
// reach that File's fields and methods directly. The other fields are:
//   - HasRoot is set to true only by the single-file layout; the follow-schema
//     layout leaves it false. NOTE(review): presumably tells the template to
//     emit the root resolver definitions too — confirm in resolver.gotpl.
//   - PackageName is the Go package of the generated file (Resolver.Package).
//   - ResolverType is the configured root resolver type name (Resolver.Type).
type ResolverBuild struct {
	*File
	HasRoot      bool
	PackageName  string
	ResolverType string
}
   191  
// File collects everything rendered into one generated resolver file:
//   - Objects: the objects whose resolver structs are declared in this file.
//   - Resolvers: the resolver methods placed in this file.
//   - imports: the imports found in the existing file on disk. It is
//     unexported and consumed only through the Imports method.
//   - RemainingSource: code from the existing file that the rewriter reports
//     as left over. The follow-schema file notice says such code is moved to
//     the end of the regenerated file.
type File struct {
	// These are separated because the type definition of the resolver object may live in a different file from the
	// resolver method implementations, for example when extending a type in a different graphql schema file
	Objects         []*codegen.Object
	Resolvers       []*Resolver
	imports         []rewrite.Import
	RemainingSource string
}
   200  
   201  func (f *File) Imports() string {
   202  	for _, imp := range f.imports {
   203  		if imp.Alias == "" {
   204  			_, _ = templates.CurrentImports.Reserve(imp.ImportPath)
   205  		} else {
   206  			_, _ = templates.CurrentImports.Reserve(imp.ImportPath, imp.Alias)
   207  		}
   208  	}
   209  	return ""
   210  }
   211  
// Resolver describes one resolver method to generate:
//   - Object: the object that owns the method.
//   - Field: the field it resolves.
//   - Comment: the doc comment text for the method.
//   - Implementation: the Go source used as the method body.
//
// Values are built with positional composite literals in this package, so the
// field order below is significant.
type Resolver struct {
	Object         *codegen.Object
	Field          *codegen.Field
	Comment        string
	Implementation string
}
   218  
   219  func gqlToResolverName(base string, gqlname, filenameTmpl string) string {
   220  	gqlname = filepath.Base(gqlname)
   221  	ext := filepath.Ext(gqlname)
   222  	if filenameTmpl == "" {
   223  		filenameTmpl = "{name}.resolvers.go"
   224  	}
   225  	filename := strings.ReplaceAll(filenameTmpl, "{name}", strings.TrimSuffix(gqlname, ext))
   226  	return filepath.Join(base, filename)
   227  }