github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/codegen/generate.go (about)

     1  package codegen
     2  
import (
	"embed"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strings"

	"github.com/geneva/gqlgen/codegen/config"
	"github.com/geneva/gqlgen/codegen/templates"
	"github.com/vektah/gqlparser/v2/ast"
)
    16  
// codegenTemplates holds every *.gotpl file from this package's directory,
// embedded at build time. It is passed as the TemplateFS option to each
// templates.Render call in this file.
//
//go:embed *.gotpl
var codegenTemplates embed.FS
    19  
    20  func GenerateCode(data *Data) error {
    21  	if !data.Config.Exec.IsDefined() {
    22  		return fmt.Errorf("missing exec config")
    23  	}
    24  
    25  	switch data.Config.Exec.Layout {
    26  	case config.ExecLayoutSingleFile:
    27  		return generateSingleFile(data)
    28  	case config.ExecLayoutFollowSchema:
    29  		return generatePerSchema(data)
    30  	}
    31  
    32  	return fmt.Errorf("unrecognized exec layout %s", data.Config.Exec.Layout)
    33  }
    34  
    35  func generateSingleFile(data *Data) error {
    36  	return templates.Render(templates.Options{
    37  		PackageName:     data.Config.Exec.Package,
    38  		Filename:        data.Config.Exec.Filename,
    39  		Data:            data,
    40  		RegionTags:      true,
    41  		GeneratedHeader: true,
    42  		Packages:        data.Config.Packages,
    43  		TemplateFS:      codegenTemplates,
    44  	})
    45  }
    46  
    47  func generatePerSchema(data *Data) error {
    48  	err := generateRootFile(data)
    49  	if err != nil {
    50  		return err
    51  	}
    52  
    53  	builds := map[string]*Data{}
    54  
    55  	err = addObjects(data, &builds)
    56  	if err != nil {
    57  		return err
    58  	}
    59  
    60  	err = addInputs(data, &builds)
    61  	if err != nil {
    62  		return err
    63  	}
    64  
    65  	err = addInterfaces(data, &builds)
    66  	if err != nil {
    67  		return err
    68  	}
    69  
    70  	err = addReferencedTypes(data, &builds)
    71  	if err != nil {
    72  		return err
    73  	}
    74  
    75  	for filename, build := range builds {
    76  		if filename == "" {
    77  			continue
    78  		}
    79  
    80  		dir := data.Config.Exec.DirName
    81  		path := filepath.Join(dir, filename)
    82  
    83  		err = templates.Render(templates.Options{
    84  			PackageName:     data.Config.Exec.Package,
    85  			Filename:        path,
    86  			Data:            build,
    87  			RegionTags:      true,
    88  			GeneratedHeader: true,
    89  			Packages:        data.Config.Packages,
    90  			TemplateFS:      codegenTemplates,
    91  		})
    92  		if err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  func filename(p *ast.Position, config *config.Config) string {
   101  	name := "common!"
   102  	if p != nil && p.Src != nil {
   103  		gqlname := filepath.Base(p.Src.Name)
   104  		ext := filepath.Ext(p.Src.Name)
   105  		name = strings.TrimSuffix(gqlname, ext)
   106  	}
   107  
   108  	filenameTempl := config.Exec.FilenameTemplate
   109  	if filenameTempl == "" {
   110  		filenameTempl = "{name}.generated.go"
   111  	}
   112  
   113  	return strings.ReplaceAll(filenameTempl, "{name}", name)
   114  }
   115  
   116  func addBuild(filename string, p *ast.Position, data *Data, builds *map[string]*Data) {
   117  	buildConfig := *data.Config
   118  	if p != nil {
   119  		buildConfig.Sources = []*ast.Source{p.Src}
   120  	}
   121  
   122  	(*builds)[filename] = &Data{
   123  		Config:           &buildConfig,
   124  		QueryRoot:        data.QueryRoot,
   125  		MutationRoot:     data.MutationRoot,
   126  		SubscriptionRoot: data.SubscriptionRoot,
   127  		AllDirectives:    data.AllDirectives,
   128  	}
   129  }
   130  
   131  // Root file contains top-level definitions that should not be duplicated across the generated
   132  // files for each schema file.
   133  func generateRootFile(data *Data) error {
   134  	dir := data.Config.Exec.DirName
   135  	path := filepath.Join(dir, "root_.generated.go")
   136  
   137  	_, thisFile, _, _ := runtime.Caller(0)
   138  	rootDir := filepath.Dir(thisFile)
   139  	templatePath := filepath.Join(rootDir, "root_.gotpl")
   140  	templateBytes, err := os.ReadFile(templatePath)
   141  	if err != nil {
   142  		return err
   143  	}
   144  	template := string(templateBytes)
   145  
   146  	return templates.Render(templates.Options{
   147  		PackageName:     data.Config.Exec.Package,
   148  		Template:        template,
   149  		Filename:        path,
   150  		Data:            data,
   151  		RegionTags:      false,
   152  		GeneratedHeader: true,
   153  		Packages:        data.Config.Packages,
   154  		TemplateFS:      codegenTemplates,
   155  	})
   156  }
   157  
   158  func addObjects(data *Data, builds *map[string]*Data) error {
   159  	for _, o := range data.Objects {
   160  		filename := filename(o.Position, data.Config)
   161  		if (*builds)[filename] == nil {
   162  			addBuild(filename, o.Position, data, builds)
   163  		}
   164  
   165  		(*builds)[filename].Objects = append((*builds)[filename].Objects, o)
   166  	}
   167  	return nil
   168  }
   169  
   170  func addInputs(data *Data, builds *map[string]*Data) error {
   171  	for _, in := range data.Inputs {
   172  		filename := filename(in.Position, data.Config)
   173  		if (*builds)[filename] == nil {
   174  			addBuild(filename, in.Position, data, builds)
   175  		}
   176  
   177  		(*builds)[filename].Inputs = append((*builds)[filename].Inputs, in)
   178  	}
   179  	return nil
   180  }
   181  
   182  func addInterfaces(data *Data, builds *map[string]*Data) error {
   183  	for k, inf := range data.Interfaces {
   184  		filename := filename(inf.Position, data.Config)
   185  		if (*builds)[filename] == nil {
   186  			addBuild(filename, inf.Position, data, builds)
   187  		}
   188  		build := (*builds)[filename]
   189  
   190  		if build.Interfaces == nil {
   191  			build.Interfaces = map[string]*Interface{}
   192  		}
   193  		if build.Interfaces[k] != nil {
   194  			return errors.New("conflicting interface keys")
   195  		}
   196  
   197  		build.Interfaces[k] = inf
   198  	}
   199  	return nil
   200  }
   201  
   202  func addReferencedTypes(data *Data, builds *map[string]*Data) error {
   203  	for k, rt := range data.ReferencedTypes {
   204  		filename := filename(rt.Definition.Position, data.Config)
   205  		if (*builds)[filename] == nil {
   206  			addBuild(filename, rt.Definition.Position, data, builds)
   207  		}
   208  		build := (*builds)[filename]
   209  
   210  		if build.ReferencedTypes == nil {
   211  			build.ReferencedTypes = map[string]*config.TypeReference{}
   212  		}
   213  		if build.ReferencedTypes[k] != nil {
   214  			return errors.New("conflicting referenced type keys")
   215  		}
   216  
   217  		build.ReferencedTypes[k] = rt
   218  	}
   219  	return nil
   220  }