github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/plugin/resolvergen/resolver.go (about) 1 package resolvergen 2 3 import ( 4 _ "embed" 5 "errors" 6 "fmt" 7 "go/ast" 8 "io/fs" 9 "os" 10 "path/filepath" 11 "strings" 12 13 "golang.org/x/text/cases" 14 "golang.org/x/text/language" 15 16 "github.com/geneva/gqlgen/codegen" 17 "github.com/geneva/gqlgen/codegen/config" 18 "github.com/geneva/gqlgen/codegen/templates" 19 "github.com/geneva/gqlgen/graphql" 20 "github.com/geneva/gqlgen/internal/rewrite" 21 "github.com/geneva/gqlgen/plugin" 22 ) 23 24 //go:embed resolver.gotpl 25 var resolverTemplate string 26 27 func New() plugin.Plugin { 28 return &Plugin{} 29 } 30 31 type Plugin struct{} 32 33 var _ plugin.CodeGenerator = &Plugin{} 34 35 func (m *Plugin) Name() string { 36 return "resolvergen" 37 } 38 39 func (m *Plugin) GenerateCode(data *codegen.Data) error { 40 if !data.Config.Resolver.IsDefined() { 41 return nil 42 } 43 44 switch data.Config.Resolver.Layout { 45 case config.LayoutSingleFile: 46 return m.generateSingleFile(data) 47 case config.LayoutFollowSchema: 48 return m.generatePerSchema(data) 49 } 50 51 return nil 52 } 53 54 func (m *Plugin) generateSingleFile(data *codegen.Data) error { 55 file := File{} 56 57 if _, err := os.Stat(data.Config.Resolver.Filename); err == nil { 58 // file already exists and we do not support updating resolvers with layout = single so just return 59 return nil 60 } 61 62 for _, o := range data.Objects { 63 if o.HasResolvers() { 64 file.Objects = append(file.Objects, o) 65 } 66 for _, f := range o.Fields { 67 if !f.IsResolver { 68 continue 69 } 70 71 resolver := Resolver{o, f, nil, "", `panic("not implemented")`} 72 file.Resolvers = append(file.Resolvers, &resolver) 73 } 74 } 75 76 resolverBuild := &ResolverBuild{ 77 File: &file, 78 PackageName: data.Config.Resolver.Package, 79 ResolverType: data.Config.Resolver.Type, 80 HasRoot: true, 81 OmitTemplateComment: data.Config.Resolver.OmitTemplateComment, 82 } 83 84 newResolverTemplate := resolverTemplate 85 if data.Config.Resolver.ResolverTemplate != "" { 86 newResolverTemplate = readResolverTemplate(data.Config.Resolver.ResolverTemplate) 87 } 88 89 return templates.Render(templates.Options{ 90 PackageName: data.Config.Resolver.Package, 91 FileNotice: `// THIS CODE IS A STARTING POINT ONLY. IT WILL NOT BE UPDATED WITH SCHEMA CHANGES.`, 92 Filename: data.Config.Resolver.Filename, 93 Data: resolverBuild, 94 Packages: data.Config.Packages, 95 Template: newResolverTemplate, 96 }) 97 } 98 99 func (m *Plugin) generatePerSchema(data *codegen.Data) error { 100 rewriter, err := rewrite.New(data.Config.Resolver.Dir()) 101 if err != nil { 102 return err 103 } 104 105 files := map[string]*File{} 106 107 objects := make(codegen.Objects, len(data.Objects)+len(data.Inputs)) 108 copy(objects, data.Objects) 109 copy(objects[len(data.Objects):], data.Inputs) 110 111 for _, o := range objects { 112 if o.HasResolvers() { 113 fn := gqlToResolverName(data.Config.Resolver.Dir(), o.Position.Src.Name, data.Config.Resolver.FilenameTemplate) 114 if files[fn] == nil { 115 files[fn] = &File{} 116 } 117 118 caser := cases.Title(language.English, cases.NoLower) 119 rewriter.MarkStructCopied(templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type)) 120 rewriter.GetMethodBody(data.Config.Resolver.Type, caser.String(o.Name)) 121 files[fn].Objects = append(files[fn].Objects, o) 122 } 123 for _, f := range o.Fields { 124 if !f.IsResolver { 125 continue 126 } 127 128 structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type) 129 comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`)) 130 131 implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName)) 132 if implementation == "" { 133 // Check for Implementer Plugin 134 var resolver_implementer plugin.ResolverImplementer 135 var exists bool 136 for _, p := range data.Plugins { 137 if p_cast, ok := p.(plugin.ResolverImplementer); ok { 138 resolver_implementer = p_cast 139 exists = true 140 break 141 } 142 } 143 144 if exists { 145 implementation = resolver_implementer.Implement(f) 146 } else { 147 implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name) 148 } 149 150 } 151 152 resolver := Resolver{o, f, rewriter.GetPrevDecl(structName, f.GoFieldName), comment, implementation} 153 fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate) 154 if files[fn] == nil { 155 files[fn] = &File{} 156 } 157 158 files[fn].Resolvers = append(files[fn].Resolvers, &resolver) 159 } 160 } 161 162 for filename, file := range files { 163 file.imports = rewriter.ExistingImports(filename) 164 file.RemainingSource = rewriter.RemainingSource(filename) 165 } 166 newResolverTemplate := resolverTemplate 167 if data.Config.Resolver.ResolverTemplate != "" { 168 newResolverTemplate = readResolverTemplate(data.Config.Resolver.ResolverTemplate) 169 } 170 171 for filename, file := range files { 172 resolverBuild := &ResolverBuild{ 173 File: file, 174 PackageName: data.Config.Resolver.Package, 175 ResolverType: data.Config.Resolver.Type, 176 OmitTemplateComment: data.Config.Resolver.OmitTemplateComment, 177 } 178 179 var fileNotice strings.Builder 180 if !data.Config.OmitGQLGenFileNotice { 181 fileNotice.WriteString(` 182 // This file will be automatically regenerated based on the schema, any resolver implementations 183 // will be copied through when generating and any unknown code will be moved to the end. 184 // Code generated by github.com/geneva/gqlgen`, 185 ) 186 if !data.Config.OmitGQLGenVersionInFileNotice { 187 fileNotice.WriteString(` version `) 188 fileNotice.WriteString(graphql.Version) 189 } 190 } 191 192 err := templates.Render(templates.Options{ 193 PackageName: data.Config.Resolver.Package, 194 FileNotice: fileNotice.String(), 195 Filename: filename, 196 Data: resolverBuild, 197 Packages: data.Config.Packages, 198 Template: newResolverTemplate, 199 }) 200 if err != nil { 201 return err 202 } 203 } 204 205 if _, err := os.Stat(data.Config.Resolver.Filename); errors.Is(err, fs.ErrNotExist) { 206 err := templates.Render(templates.Options{ 207 PackageName: data.Config.Resolver.Package, 208 FileNotice: ` 209 // This file will not be regenerated automatically. 210 // 211 // It serves as dependency injection for your app, add any dependencies you require here.`, 212 Template: `type {{.}} struct {}`, 213 Filename: data.Config.Resolver.Filename, 214 Data: data.Config.Resolver.Type, 215 Packages: data.Config.Packages, 216 }) 217 if err != nil { 218 return err 219 } 220 } 221 return nil 222 } 223 224 type ResolverBuild struct { 225 *File 226 HasRoot bool 227 PackageName string 228 ResolverType string 229 OmitTemplateComment bool 230 } 231 232 type File struct { 233 // These are separated because the type definition of the resolver object may live in a different file from the 234 // resolver method implementations, for example when extending a type in a different graphql schema file 235 Objects []*codegen.Object 236 Resolvers []*Resolver 237 imports []rewrite.Import 238 RemainingSource string 239 } 240 241 func (f *File) Imports() string { 242 for _, imp := range f.imports { 243 if imp.Alias == "" { 244 _, _ = templates.CurrentImports.Reserve(imp.ImportPath) 245 } else { 246 _, _ = templates.CurrentImports.Reserve(imp.ImportPath, imp.Alias) 247 } 248 } 249 return "" 250 } 251 252 type Resolver struct { 253 Object *codegen.Object 254 Field *codegen.Field 255 PrevDecl *ast.FuncDecl 256 Comment string 257 Implementation string 258 } 259 260 func gqlToResolverName(base string, gqlname, filenameTmpl string) string { 261 gqlname = filepath.Base(gqlname) 262 ext := filepath.Ext(gqlname) 263 if filenameTmpl == "" { 264 filenameTmpl = "{name}.resolvers.go" 265 } 266 filename := strings.ReplaceAll(filenameTmpl, "{name}", strings.TrimSuffix(gqlname, ext)) 267 return filepath.Join(base, filename) 268 } 269 270 func readResolverTemplate(customResolverTemplate string) string { 271 contentBytes, err := os.ReadFile(customResolverTemplate) 272 if err != nil { 273 panic(err) 274 } 275 return string(contentBytes) 276 }