github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/codegen/generate.go (about)

     1  package codegen
     2  
import (
	"embed"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strings"

	"github.com/vektah/gqlparser/v2/ast"

	"github.com/niko0xdev/gqlgen/codegen/config"
	"github.com/niko0xdev/gqlgen/codegen/templates"
)
    17  
// codegenTemplates holds every *.gotpl template in this package's directory,
// embedded into the binary at build time. It is supplied as TemplateFS to the
// templates.Render calls in this file.
//
//go:embed *.gotpl
var codegenTemplates embed.FS
    20  
    21  func GenerateCode(data *Data) error {
    22  	if !data.Config.Exec.IsDefined() {
    23  		return fmt.Errorf("missing exec config")
    24  	}
    25  
    26  	switch data.Config.Exec.Layout {
    27  	case config.ExecLayoutSingleFile:
    28  		return generateSingleFile(data)
    29  	case config.ExecLayoutFollowSchema:
    30  		return generatePerSchema(data)
    31  	}
    32  
    33  	return fmt.Errorf("unrecognized exec layout %s", data.Config.Exec.Layout)
    34  }
    35  
    36  func generateSingleFile(data *Data) error {
    37  	return templates.Render(templates.Options{
    38  		PackageName:     data.Config.Exec.Package,
    39  		Filename:        data.Config.Exec.Filename,
    40  		Data:            data,
    41  		RegionTags:      true,
    42  		GeneratedHeader: true,
    43  		Packages:        data.Config.Packages,
    44  		TemplateFS:      codegenTemplates,
    45  	})
    46  }
    47  
    48  func generatePerSchema(data *Data) error {
    49  	err := generateRootFile(data)
    50  	if err != nil {
    51  		return err
    52  	}
    53  
    54  	builds := map[string]*Data{}
    55  
    56  	err = addObjects(data, &builds)
    57  	if err != nil {
    58  		return err
    59  	}
    60  
    61  	err = addInputs(data, &builds)
    62  	if err != nil {
    63  		return err
    64  	}
    65  
    66  	err = addInterfaces(data, &builds)
    67  	if err != nil {
    68  		return err
    69  	}
    70  
    71  	err = addReferencedTypes(data, &builds)
    72  	if err != nil {
    73  		return err
    74  	}
    75  
    76  	for filename, build := range builds {
    77  		if filename == "" {
    78  			continue
    79  		}
    80  
    81  		dir := data.Config.Exec.DirName
    82  		path := filepath.Join(dir, filename)
    83  
    84  		err = templates.Render(templates.Options{
    85  			PackageName:     data.Config.Exec.Package,
    86  			Filename:        path,
    87  			Data:            build,
    88  			RegionTags:      true,
    89  			GeneratedHeader: true,
    90  			Packages:        data.Config.Packages,
    91  			TemplateFS:      codegenTemplates,
    92  		})
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func filename(p *ast.Position, config *config.Config) string {
   102  	name := "common!"
   103  	if p != nil && p.Src != nil {
   104  		gqlname := filepath.Base(p.Src.Name)
   105  		ext := filepath.Ext(p.Src.Name)
   106  		name = strings.TrimSuffix(gqlname, ext)
   107  	}
   108  
   109  	filenameTempl := config.Exec.FilenameTemplate
   110  	if filenameTempl == "" {
   111  		filenameTempl = "{name}.generated.go"
   112  	}
   113  
   114  	return strings.ReplaceAll(filenameTempl, "{name}", name)
   115  }
   116  
   117  func addBuild(filename string, p *ast.Position, data *Data, builds *map[string]*Data) {
   118  	buildConfig := *data.Config
   119  	if p != nil {
   120  		buildConfig.Sources = []*ast.Source{p.Src}
   121  	}
   122  
   123  	(*builds)[filename] = &Data{
   124  		Config:           &buildConfig,
   125  		QueryRoot:        data.QueryRoot,
   126  		MutationRoot:     data.MutationRoot,
   127  		SubscriptionRoot: data.SubscriptionRoot,
   128  		AllDirectives:    data.AllDirectives,
   129  	}
   130  }
   131  
   132  // Root file contains top-level definitions that should not be duplicated across the generated
   133  // files for each schema file.
   134  func generateRootFile(data *Data) error {
   135  	dir := data.Config.Exec.DirName
   136  	path := filepath.Join(dir, "root_.generated.go")
   137  
   138  	_, thisFile, _, _ := runtime.Caller(0)
   139  	rootDir := filepath.Dir(thisFile)
   140  	templatePath := filepath.Join(rootDir, "root_.gotpl")
   141  	templateBytes, err := os.ReadFile(templatePath)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	template := string(templateBytes)
   146  
   147  	return templates.Render(templates.Options{
   148  		PackageName:     data.Config.Exec.Package,
   149  		Template:        template,
   150  		Filename:        path,
   151  		Data:            data,
   152  		RegionTags:      false,
   153  		GeneratedHeader: true,
   154  		Packages:        data.Config.Packages,
   155  		TemplateFS:      codegenTemplates,
   156  	})
   157  }
   158  
   159  func addObjects(data *Data, builds *map[string]*Data) error {
   160  	for _, o := range data.Objects {
   161  		filename := filename(o.Position, data.Config)
   162  		if (*builds)[filename] == nil {
   163  			addBuild(filename, o.Position, data, builds)
   164  		}
   165  
   166  		(*builds)[filename].Objects = append((*builds)[filename].Objects, o)
   167  	}
   168  	return nil
   169  }
   170  
   171  func addInputs(data *Data, builds *map[string]*Data) error {
   172  	for _, in := range data.Inputs {
   173  		filename := filename(in.Position, data.Config)
   174  		if (*builds)[filename] == nil {
   175  			addBuild(filename, in.Position, data, builds)
   176  		}
   177  
   178  		(*builds)[filename].Inputs = append((*builds)[filename].Inputs, in)
   179  	}
   180  	return nil
   181  }
   182  
   183  func addInterfaces(data *Data, builds *map[string]*Data) error {
   184  	for k, inf := range data.Interfaces {
   185  		filename := filename(inf.Position, data.Config)
   186  		if (*builds)[filename] == nil {
   187  			addBuild(filename, inf.Position, data, builds)
   188  		}
   189  		build := (*builds)[filename]
   190  
   191  		if build.Interfaces == nil {
   192  			build.Interfaces = map[string]*Interface{}
   193  		}
   194  		if build.Interfaces[k] != nil {
   195  			return errors.New("conflicting interface keys")
   196  		}
   197  
   198  		build.Interfaces[k] = inf
   199  	}
   200  	return nil
   201  }
   202  
   203  func addReferencedTypes(data *Data, builds *map[string]*Data) error {
   204  	for k, rt := range data.ReferencedTypes {
   205  		filename := filename(rt.Definition.Position, data.Config)
   206  		if (*builds)[filename] == nil {
   207  			addBuild(filename, rt.Definition.Position, data, builds)
   208  		}
   209  		build := (*builds)[filename]
   210  
   211  		if build.ReferencedTypes == nil {
   212  			build.ReferencedTypes = map[string]*config.TypeReference{}
   213  		}
   214  		if build.ReferencedTypes[k] != nil {
   215  			return errors.New("conflicting referenced type keys")
   216  		}
   217  
   218  		build.ReferencedTypes[k] = rt
   219  	}
   220  	return nil
   221  }