github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/codegen/templates/templates.go (about)

     1  package templates
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/types"
     7  	"io/fs"
     8  	"os"
     9  	"path/filepath"
    10  	"reflect"
    11  	"regexp"
    12  	"runtime"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"text/template"
    18  	"unicode"
    19  
    20  	"github.com/geneva/gqlgen/codegen/config"
    21  	"github.com/geneva/gqlgen/internal/code"
    22  	"github.com/geneva/gqlgen/internal/imports"
    23  )
    24  
    25  // CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
    26  // this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually.
    27  var CurrentImports *Imports
    28  
    29  // Options specify various parameters to rendering a template.
    30  type Options struct {
    31  	// PackageName is a helper that specifies the package header declaration.
    32  	// In other words, when you write the template you don't need to specify `package X`
    33  	// at the top of the file. By providing PackageName in the Options, the Render
    34  	// function will do that for you.
    35  	PackageName string
    36  	// Template is a string of the entire template that
    37  	// will be parsed and rendered. If it's empty,
    38  	// the plugin processor will look for .gotpl files
    39  	// in the same directory of where you wrote the plugin.
    40  	Template string
    41  
    42  	// Use the go:embed API to collect all the template files you want to pass into Render
    43  	// this is an alternative to passing the Template option
    44  	TemplateFS fs.FS
    45  
    46  	// Filename is the name of the file that will be
    47  	// written to the system disk once the template is rendered.
    48  	Filename        string
    49  	RegionTags      bool
    50  	GeneratedHeader bool
    51  	// PackageDoc is documentation written above the package line
    52  	PackageDoc string
    53  	// FileNotice is notice written below the package line
    54  	FileNotice string
    55  	// Data will be passed to the template execution.
    56  	Data  interface{}
    57  	Funcs template.FuncMap
    58  
    59  	// Packages cache, you can find me on config.Config
    60  	Packages *code.Packages
    61  }
    62  
    63  var (
    64  	modelNamesMu sync.Mutex
    65  	modelNames   = make(map[string]string, 0)
    66  	goNameRe     = regexp.MustCompile("[^a-zA-Z0-9_]")
    67  )
    68  
    69  // Render renders a gql plugin template from the given Options. Render is an
    70  // abstraction of the text/template package that makes it easier to write gqlgen
    71  // plugins. If Options.Template is empty, the Render function will look for `.gotpl`
    72  // files inside the directory where you wrote the plugin.
    73  func Render(cfg Options) error {
    74  	if CurrentImports != nil {
    75  		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
    76  	}
    77  	CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
    78  
    79  	funcs := Funcs()
    80  	for n, f := range cfg.Funcs {
    81  		funcs[n] = f
    82  	}
    83  
    84  	t := template.New("").Funcs(funcs)
    85  	t, err := parseTemplates(cfg, t)
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	roots := make([]string, 0, len(t.Templates()))
    91  	for _, template := range t.Templates() {
    92  		// templates that end with _.gotpl are special files we don't want to include
    93  		if strings.HasSuffix(template.Name(), "_.gotpl") ||
    94  			// filter out templates added with {{ template xxx }} syntax inside the template file
    95  			!strings.HasSuffix(template.Name(), ".gotpl") {
    96  			continue
    97  		}
    98  
    99  		roots = append(roots, template.Name())
   100  	}
   101  
   102  	// then execute all the important looking ones in order, adding them to the same file
   103  	sort.Slice(roots, func(i, j int) bool {
   104  		// important files go first
   105  		if strings.HasSuffix(roots[i], "!.gotpl") {
   106  			return true
   107  		}
   108  		if strings.HasSuffix(roots[j], "!.gotpl") {
   109  			return false
   110  		}
   111  		return roots[i] < roots[j]
   112  	})
   113  
   114  	var buf bytes.Buffer
   115  	for _, root := range roots {
   116  		if cfg.RegionTags {
   117  			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
   118  		}
   119  		err := t.Lookup(root).Execute(&buf, cfg.Data)
   120  		if err != nil {
   121  			return fmt.Errorf("%s: %w", root, err)
   122  		}
   123  		if cfg.RegionTags {
   124  			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
   125  		}
   126  	}
   127  
   128  	var result bytes.Buffer
   129  	if cfg.GeneratedHeader {
   130  		result.WriteString("// Code generated by github.com/geneva/gqlgen, DO NOT EDIT.\n\n")
   131  	}
   132  	if cfg.PackageDoc != "" {
   133  		result.WriteString(cfg.PackageDoc + "\n")
   134  	}
   135  	result.WriteString("package ")
   136  	result.WriteString(cfg.PackageName)
   137  	result.WriteString("\n\n")
   138  	if cfg.FileNotice != "" {
   139  		result.WriteString(cfg.FileNotice)
   140  		result.WriteString("\n\n")
   141  	}
   142  	result.WriteString("import (\n")
   143  	result.WriteString(CurrentImports.String())
   144  	result.WriteString(")\n")
   145  	_, err = buf.WriteTo(&result)
   146  	if err != nil {
   147  		return err
   148  	}
   149  	CurrentImports = nil
   150  
   151  	err = write(cfg.Filename, result.Bytes(), cfg.Packages)
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename)))
   157  	return nil
   158  }
   159  
   160  func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
   161  	if cfg.Template != "" {
   162  		var err error
   163  		t, err = t.New("template.gotpl").Parse(cfg.Template)
   164  		if err != nil {
   165  			return nil, fmt.Errorf("error with provided template: %w", err)
   166  		}
   167  		return t, nil
   168  	}
   169  
   170  	var fileSystem fs.FS
   171  	if cfg.TemplateFS != nil {
   172  		fileSystem = cfg.TemplateFS
   173  	} else {
   174  		// load path relative to calling source file
   175  		_, callerFile, _, _ := runtime.Caller(2)
   176  		rootDir := filepath.Dir(callerFile)
   177  		fileSystem = os.DirFS(rootDir)
   178  	}
   179  
   180  	t, err := t.ParseFS(fileSystem, "*.gotpl")
   181  	if err != nil {
   182  		return nil, fmt.Errorf("locating templates: %w", err)
   183  	}
   184  
   185  	return t, nil
   186  }
   187  
   188  func center(width int, pad string, s string) string {
   189  	if len(s)+2 > width {
   190  		return s
   191  	}
   192  	lpad := (width - len(s)) / 2
   193  	rpad := width - (lpad + len(s))
   194  	return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad)
   195  }
   196  
   197  func Funcs() template.FuncMap {
   198  	return template.FuncMap{
   199  		"ucFirst":            UcFirst,
   200  		"lcFirst":            LcFirst,
   201  		"quote":              strconv.Quote,
   202  		"rawQuote":           rawQuote,
   203  		"dump":               Dump,
   204  		"ref":                ref,
   205  		"ts":                 config.TypeIdentifier,
   206  		"call":               Call,
   207  		"prefixLines":        prefixLines,
   208  		"notNil":             notNil,
   209  		"reserveImport":      CurrentImports.Reserve,
   210  		"lookupImport":       CurrentImports.Lookup,
   211  		"go":                 ToGo,
   212  		"goPrivate":          ToGoPrivate,
   213  		"goModelName":        ToGoModelName,
   214  		"goPrivateModelName": ToGoPrivateModelName,
   215  		"add": func(a, b int) int {
   216  			return a + b
   217  		},
   218  		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
   219  			return render(resolveName(filename, 0), tpldata)
   220  		},
   221  	}
   222  }
   223  
   224  func UcFirst(s string) string {
   225  	if s == "" {
   226  		return ""
   227  	}
   228  	r := []rune(s)
   229  	r[0] = unicode.ToUpper(r[0])
   230  	return string(r)
   231  }
   232  
   233  func LcFirst(s string) string {
   234  	if s == "" {
   235  		return ""
   236  	}
   237  
   238  	r := []rune(s)
   239  	r[0] = unicode.ToLower(r[0])
   240  	return string(r)
   241  }
   242  
   243  func isDelimiter(c rune) bool {
   244  	return c == '-' || c == '_' || unicode.IsSpace(c)
   245  }
   246  
   247  func ref(p types.Type) string {
   248  	return CurrentImports.LookupType(p)
   249  }
   250  
   251  func Call(p *types.Func) string {
   252  	pkg := CurrentImports.Lookup(p.Pkg().Path())
   253  
   254  	if pkg != "" {
   255  		pkg += "."
   256  	}
   257  
   258  	if p.Type() != nil {
   259  		// make sure the returned type is listed in our imports.
   260  		ref(p.Type().(*types.Signature).Results().At(0).Type())
   261  	}
   262  
   263  	return pkg + p.Name()
   264  }
   265  
   266  func resetModelNames() {
   267  	modelNamesMu.Lock()
   268  	defer modelNamesMu.Unlock()
   269  	modelNames = make(map[string]string, 0)
   270  }
   271  
   272  func buildGoModelNameKey(parts []string) string {
   273  	const sep = ":"
   274  	return strings.Join(parts, sep)
   275  }
   276  
   277  func goModelName(primaryToGoFunc func(string) string, parts []string) string {
   278  	modelNamesMu.Lock()
   279  	defer modelNamesMu.Unlock()
   280  
   281  	var (
   282  		goNameKey string
   283  		partLen   int
   284  
   285  		nameExists = func(n string) bool {
   286  			for _, v := range modelNames {
   287  				if n == v {
   288  					return true
   289  				}
   290  			}
   291  			return false
   292  		}
   293  
   294  		applyToGoFunc = func(parts []string) string {
   295  			var out string
   296  			switch len(parts) {
   297  			case 0:
   298  				return ""
   299  			case 1:
   300  				return primaryToGoFunc(parts[0])
   301  			default:
   302  				out = primaryToGoFunc(parts[0])
   303  			}
   304  			for _, p := range parts[1:] {
   305  				out = fmt.Sprintf("%s%s", out, ToGo(p))
   306  			}
   307  			return out
   308  		}
   309  
   310  		applyValidGoName = func(parts []string) string {
   311  			var out string
   312  			for _, p := range parts {
   313  				out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p))
   314  			}
   315  			return out
   316  		}
   317  	)
   318  
   319  	// build key for this entity
   320  	goNameKey = buildGoModelNameKey(parts)
   321  
   322  	// determine if we've seen this entity before, and reuse if so
   323  	if goName, ok := modelNames[goNameKey]; ok {
   324  		return goName
   325  	}
   326  
   327  	// attempt first pass
   328  	if goName := applyToGoFunc(parts); !nameExists(goName) {
   329  		modelNames[goNameKey] = goName
   330  		return goName
   331  	}
   332  
   333  	// determine number of parts
   334  	partLen = len(parts)
   335  
   336  	// if there is only 1 part, append incrementing number until no conflict
   337  	if partLen == 1 {
   338  		base := applyToGoFunc(parts)
   339  		for i := 0; ; i++ {
   340  			tmp := fmt.Sprintf("%s%d", base, i)
   341  			if !nameExists(tmp) {
   342  				modelNames[goNameKey] = tmp
   343  				return tmp
   344  			}
   345  		}
   346  	}
   347  
   348  	// best effort "pretty" name
   349  	for i := partLen - 1; i >= 1; i-- {
   350  		tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:]))
   351  		if !nameExists(tmp) {
   352  			modelNames[goNameKey] = tmp
   353  			return tmp
   354  		}
   355  	}
   356  
   357  	// finally, fallback to just adding an incrementing number
   358  	base := applyToGoFunc(parts)
   359  	for i := 0; ; i++ {
   360  		tmp := fmt.Sprintf("%s%d", base, i)
   361  		if !nameExists(tmp) {
   362  			modelNames[goNameKey] = tmp
   363  			return tmp
   364  		}
   365  	}
   366  }
   367  
   368  func ToGoModelName(parts ...string) string {
   369  	return goModelName(ToGo, parts)
   370  }
   371  
   372  func ToGoPrivateModelName(parts ...string) string {
   373  	return goModelName(ToGoPrivate, parts)
   374  }
   375  
   376  func replaceInvalidCharacters(in string) string {
   377  	return goNameRe.ReplaceAllLiteralString(in, "_")
   378  }
   379  
   380  func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) {
   381  	return func(info *wordInfo) {
   382  		word := info.Word
   383  
   384  		switch {
   385  		case private && info.WordOffset == 0:
   386  			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
   387  				// ID → id, CAMEL → camel
   388  				word = strings.ToLower(info.Word)
   389  			} else {
   390  				// ITicket → iTicket
   391  				word = LcFirst(info.Word)
   392  			}
   393  
   394  		case info.MatchCommonInitial:
   395  			word = strings.ToUpper(word)
   396  
   397  		case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word):
   398  			// FOO or foo → Foo
   399  			// FOo → FOo
   400  			word = UcFirst(strings.ToLower(word))
   401  		}
   402  
   403  		*nameRunes = append(*nameRunes, []rune(word)...)
   404  	}
   405  }
   406  
   407  func ToGo(name string) string {
   408  	if name == "_" {
   409  		return "_"
   410  	}
   411  	runes := make([]rune, 0, len(name))
   412  
   413  	wordWalker(name, wordWalkerFunc(false, &runes))
   414  
   415  	return string(runes)
   416  }
   417  
   418  func ToGoPrivate(name string) string {
   419  	if name == "_" {
   420  		return "_"
   421  	}
   422  	runes := make([]rune, 0, len(name))
   423  
   424  	wordWalker(name, wordWalkerFunc(true, &runes))
   425  
   426  	return sanitizeKeywords(string(runes))
   427  }
   428  
   429  type wordInfo struct {
   430  	WordOffset         int
   431  	Word               string
   432  	MatchCommonInitial bool
   433  	HasCommonInitial   bool
   434  }
   435  
   436  // This function is based on the following code.
   437  // https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
   438  func wordWalker(str string, f func(*wordInfo)) {
   439  	runes := []rune(strings.TrimFunc(str, isDelimiter))
   440  	w, i, wo := 0, 0, 0 // index of start of word, scan, word offset
   441  	hasCommonInitial := false
   442  	for i+1 <= len(runes) {
   443  		eow := false // whether we hit the end of a word
   444  		switch {
   445  		case i+1 == len(runes):
   446  			eow = true
   447  		case isDelimiter(runes[i+1]):
   448  			// underscore; shift the remainder forward over any run of underscores
   449  			eow = true
   450  			n := 1
   451  			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
   452  				n++
   453  			}
   454  
   455  			// Leave at most one underscore if the underscore is between two digits
   456  			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
   457  				n--
   458  			}
   459  
   460  			copy(runes[i+1:], runes[i+n+1:])
   461  			runes = runes[:len(runes)-n]
   462  		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
   463  			// lower->non-lower
   464  			eow = true
   465  		}
   466  		i++
   467  
   468  		initialisms := config.GetInitialisms()
   469  		// [w,i) is a word.
   470  		word := string(runes[w:i])
   471  		if !eow && initialisms[word] && !unicode.IsLower(runes[i]) {
   472  			// through
   473  			// split IDFoo → ID, Foo
   474  			// but URLs → URLs
   475  		} else if !eow {
   476  			if initialisms[word] {
   477  				hasCommonInitial = true
   478  			}
   479  			continue
   480  		}
   481  
   482  		matchCommonInitial := false
   483  		upperWord := strings.ToUpper(word)
   484  		if initialisms[upperWord] {
   485  			// If the uppercase word (string(runes[w:i]) is "ID" or "IP"
   486  			// AND
   487  			// the word is the first two characters of the str
   488  			// AND
   489  			// that is not the end of the word
   490  			// AND
   491  			// the length of the string is greater than 3
   492  			// AND
   493  			// the third rune is an uppercase one
   494  			// THEN
   495  			// do NOT count this as an initialism.
   496  			switch upperWord {
   497  			case "ID", "IP":
   498  				if word == str[:2] && !eow && len(str) > 3 && unicode.IsUpper(runes[3]) {
   499  					continue
   500  				}
   501  			}
   502  			hasCommonInitial = true
   503  			matchCommonInitial = true
   504  		}
   505  
   506  		f(&wordInfo{
   507  			WordOffset:         wo,
   508  			Word:               word,
   509  			MatchCommonInitial: matchCommonInitial,
   510  			HasCommonInitial:   hasCommonInitial,
   511  		})
   512  		hasCommonInitial = false
   513  		w = i
   514  		wo++
   515  	}
   516  }
   517  
   518  var keywords = []string{
   519  	"break",
   520  	"default",
   521  	"func",
   522  	"interface",
   523  	"select",
   524  	"case",
   525  	"defer",
   526  	"go",
   527  	"map",
   528  	"struct",
   529  	"chan",
   530  	"else",
   531  	"goto",
   532  	"package",
   533  	"switch",
   534  	"const",
   535  	"fallthrough",
   536  	"if",
   537  	"range",
   538  	"type",
   539  	"continue",
   540  	"for",
   541  	"import",
   542  	"return",
   543  	"var",
   544  	"_",
   545  }
   546  
   547  // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions
   548  func sanitizeKeywords(name string) string {
   549  	for _, k := range keywords {
   550  		if name == k {
   551  			return name + "Arg"
   552  		}
   553  	}
   554  	return name
   555  }
   556  
   557  func rawQuote(s string) string {
   558  	return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`"
   559  }
   560  
   561  func notNil(field string, data interface{}) bool {
   562  	v := reflect.ValueOf(data)
   563  
   564  	if v.Kind() == reflect.Ptr {
   565  		v = v.Elem()
   566  	}
   567  	if v.Kind() != reflect.Struct {
   568  		return false
   569  	}
   570  	val := v.FieldByName(field)
   571  
   572  	return val.IsValid() && !val.IsNil()
   573  }
   574  
   575  func Dump(val interface{}) string {
   576  	switch val := val.(type) {
   577  	case int:
   578  		return strconv.Itoa(val)
   579  	case int64:
   580  		return fmt.Sprintf("%d", val)
   581  	case float64:
   582  		return fmt.Sprintf("%f", val)
   583  	case string:
   584  		return strconv.Quote(val)
   585  	case bool:
   586  		return strconv.FormatBool(val)
   587  	case nil:
   588  		return "nil"
   589  	case []interface{}:
   590  		var parts []string
   591  		for _, part := range val {
   592  			parts = append(parts, Dump(part))
   593  		}
   594  		return "[]interface{}{" + strings.Join(parts, ",") + "}"
   595  	case map[string]interface{}:
   596  		buf := bytes.Buffer{}
   597  		buf.WriteString("map[string]interface{}{")
   598  		var keys []string
   599  		for key := range val {
   600  			keys = append(keys, key)
   601  		}
   602  		sort.Strings(keys)
   603  
   604  		for _, key := range keys {
   605  			data := val[key]
   606  
   607  			buf.WriteString(strconv.Quote(key))
   608  			buf.WriteString(":")
   609  			buf.WriteString(Dump(data))
   610  			buf.WriteString(",")
   611  		}
   612  		buf.WriteString("}")
   613  		return buf.String()
   614  	default:
   615  		panic(fmt.Errorf("unsupported type %T", val))
   616  	}
   617  }
   618  
   619  func prefixLines(prefix, s string) string {
   620  	return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix)
   621  }
   622  
   623  func resolveName(name string, skip int) string {
   624  	if name[0] == '.' {
   625  		// load path relative to calling source file
   626  		_, callerFile, _, _ := runtime.Caller(skip + 1)
   627  		return filepath.Join(filepath.Dir(callerFile), name[1:])
   628  	}
   629  
   630  	// load path relative to this directory
   631  	_, callerFile, _, _ := runtime.Caller(0)
   632  	return filepath.Join(filepath.Dir(callerFile), name)
   633  }
   634  
   635  func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
   636  	t := template.New("").Funcs(Funcs())
   637  
   638  	b, err := os.ReadFile(filename)
   639  	if err != nil {
   640  		return nil, err
   641  	}
   642  
   643  	t, err = t.New(filepath.Base(filename)).Parse(string(b))
   644  	if err != nil {
   645  		panic(err)
   646  	}
   647  
   648  	buf := &bytes.Buffer{}
   649  	return buf, t.Execute(buf, tpldata)
   650  }
   651  
   652  func write(filename string, b []byte, packages *code.Packages) error {
   653  	err := os.MkdirAll(filepath.Dir(filename), 0o755)
   654  	if err != nil {
   655  		return fmt.Errorf("failed to create directory: %w", err)
   656  	}
   657  
   658  	formatted, err := imports.Prune(filename, b, packages)
   659  	if err != nil {
   660  		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
   661  		formatted = b
   662  	}
   663  
   664  	err = os.WriteFile(filename, formatted, 0o644)
   665  	if err != nil {
   666  		return fmt.Errorf("failed to write %s: %w", filename, err)
   667  	}
   668  
   669  	return nil
   670  }