
     1  package templates
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/types"
     7  	"io/ioutil"
     8  	"os"
     9  	"path/filepath"
    10  	"reflect"
    11  	"runtime"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"text/template"
    16  	"unicode"
    18  	""
    20  	""
    21  )
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because subtemplates currently get called in functions. Let's aim to remove this eventually.
// Render assigns it at the start of a render and panics if it is already non-nil,
// so renders must not be nested or run concurrently.
var CurrentImports *Imports
// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty, Render
	// loads every .gotpl file found by walking the directory
	// (subdirectories included) of the source file that called
	// Render, skipping files whose names end in "_.gotpl".
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	// Its directory is also used as the destination for import resolution.
	Filename        string
	// RegionTags wraps the output of each rendered template in
	// "// region" / "// endregion" marker comments.
	RegionTags      bool
	// GeneratedHeader prepends a "Code generated ..., DO NOT EDIT." comment.
	GeneratedHeader bool
	// PackageDoc is documentation written above the package line
	PackageDoc string
	// FileNotice is notice written below the package line
	FileNotice string
	// Data will be passed to the template execution.
	Data  interface{}
	// Funcs are extra template functions merged over the defaults from
	// Funcs(); an entry with the same name replaces the default.
	Funcs template.FuncMap

	// Packages cache, you can find me on config.Config. It is also passed to
	// import pruning when the file is written, and the output package is
	// evicted from it after a successful render.
	Packages *code.Packages
}
    56  // Render renders a gql plugin template from the given Options. Render is an
    57  // abstraction of the text/template package that makes it easier to write gqlgen
    58  // plugins. If Options.Template is empty, the Render function will look for `.gotpl`
    59  // files inside the directory where you wrote the plugin.
    60  func Render(cfg Options) error {
    61  	if CurrentImports != nil {
    62  		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
    63  	}
    64  	CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
    66  	// load path relative to calling source file
    67  	_, callerFile, _, _ := runtime.Caller(1)
    68  	rootDir := filepath.Dir(callerFile)
    70  	funcs := Funcs()
    71  	for n, f := range cfg.Funcs {
    72  		funcs[n] = f
    73  	}
    74  	t := template.New("").Funcs(funcs)
    76  	var roots []string
    77  	if cfg.Template != "" {
    78  		var err error
    79  		t, err = t.New("template.gotpl").Parse(cfg.Template)
    80  		if err != nil {
    81  			return fmt.Errorf("error with provided template: %w", err)
    82  		}
    83  		roots = append(roots, "template.gotpl")
    84  	} else {
    85  		// load all the templates in the directory
    86  		err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
    87  			if err != nil {
    88  				return err
    89  			}
    90  			name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
    91  			if !strings.HasSuffix(info.Name(), ".gotpl") {
    92  				return nil
    93  			}
    94  			// omit any templates with "_" at the end of their name, which are meant for specific contexts only
    95  			if strings.HasSuffix(info.Name(), "_.gotpl") {
    96  				return nil
    97  			}
    98  			b, err := ioutil.ReadFile(path)
    99  			if err != nil {
   100  				return err
   101  			}
   103  			t, err = t.New(name).Parse(string(b))
   104  			if err != nil {
   105  				return fmt.Errorf("%s: %w", cfg.Filename, err)
   106  			}
   108  			roots = append(roots, name)
   110  			return nil
   111  		})
   112  		if err != nil {
   113  			return fmt.Errorf("locating templates: %w", err)
   114  		}
   115  	}
   117  	// then execute all the important looking ones in order, adding them to the same file
   118  	sort.Slice(roots, func(i, j int) bool {
   119  		// important files go first
   120  		if strings.HasSuffix(roots[i], "!.gotpl") {
   121  			return true
   122  		}
   123  		if strings.HasSuffix(roots[j], "!.gotpl") {
   124  			return false
   125  		}
   126  		return roots[i] < roots[j]
   127  	})
   128  	var buf bytes.Buffer
   129  	for _, root := range roots {
   130  		if cfg.RegionTags {
   131  			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
   132  		}
   133  		err := t.Lookup(root).Execute(&buf, cfg.Data)
   134  		if err != nil {
   135  			return fmt.Errorf("%s: %w", root, err)
   136  		}
   137  		if cfg.RegionTags {
   138  			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
   139  		}
   140  	}
   142  	var result bytes.Buffer
   143  	if cfg.GeneratedHeader {
   144  		result.WriteString("// Code generated by, DO NOT EDIT.\n\n")
   145  	}
   146  	if cfg.PackageDoc != "" {
   147  		result.WriteString(cfg.PackageDoc + "\n")
   148  	}
   149  	result.WriteString("package ")
   150  	result.WriteString(cfg.PackageName)
   151  	result.WriteString("\n\n")
   152  	if cfg.FileNotice != "" {
   153  		result.WriteString(cfg.FileNotice)
   154  		result.WriteString("\n\n")
   155  	}
   156  	result.WriteString("import (\n")
   157  	result.WriteString(CurrentImports.String())
   158  	result.WriteString(")\n")
   159  	_, err := buf.WriteTo(&result)
   160  	if err != nil {
   161  		return err
   162  	}
   163  	CurrentImports = nil
   165  	err = write(cfg.Filename, result.Bytes(), cfg.Packages)
   166  	if err != nil {
   167  		return err
   168  	}
   170  	cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename)))
   171  	return nil
   172  }
   174  func center(width int, pad string, s string) string {
   175  	if len(s)+2 > width {
   176  		return s
   177  	}
   178  	lpad := (width - len(s)) / 2
   179  	rpad := width - (lpad + len(s))
   180  	return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad)
   181  }
   183  func Funcs() template.FuncMap {
   184  	return template.FuncMap{
   185  		"ucFirst":       UcFirst,
   186  		"lcFirst":       LcFirst,
   187  		"quote":         strconv.Quote,
   188  		"rawQuote":      rawQuote,
   189  		"dump":          Dump,
   190  		"ref":           ref,
   191  		"ts":            TypeIdentifier,
   192  		"call":          Call,
   193  		"prefixLines":   prefixLines,
   194  		"notNil":        notNil,
   195  		"reserveImport": CurrentImports.Reserve,
   196  		"lookupImport":  CurrentImports.Lookup,
   197  		"go":            ToGo,
   198  		"goPrivate":     ToGoPrivate,
   199  		"add": func(a, b int) int {
   200  			return a + b
   201  		},
   202  		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
   203  			return render(resolveName(filename, 0), tpldata)
   204  		},
   205  	}
   206  }
   208  func UcFirst(s string) string {
   209  	if s == "" {
   210  		return ""
   211  	}
   212  	r := []rune(s)
   213  	r[0] = unicode.ToUpper(r[0])
   214  	return string(r)
   215  }
   217  func LcFirst(s string) string {
   218  	if s == "" {
   219  		return ""
   220  	}
   222  	r := []rune(s)
   223  	r[0] = unicode.ToLower(r[0])
   224  	return string(r)
   225  }
   227  func isDelimiter(c rune) bool {
   228  	return c == '-' || c == '_' || unicode.IsSpace(c)
   229  }
   231  func ref(p types.Type) string {
   232  	return CurrentImports.LookupType(p)
   233  }
   235  var pkgReplacer = strings.NewReplacer(
   236  	"/", "ᚋ",
   237  	".", "ᚗ",
   238  	"-", "ᚑ",
   239  	"~", "א",
   240  )
   242  func TypeIdentifier(t types.Type) string {
   243  	res := ""
   244  	for {
   245  		switch it := t.(type) {
   246  		case *types.Pointer:
   247  			t.Underlying()
   248  			res += "ᚖ"
   249  			t = it.Elem()
   250  		case *types.Slice:
   251  			res += "ᚕ"
   252  			t = it.Elem()
   253  		case *types.Named:
   254  			res += pkgReplacer.Replace(it.Obj().Pkg().Path())
   255  			res += "ᚐ"
   256  			res += it.Obj().Name()
   257  			return res
   258  		case *types.Basic:
   259  			res += it.Name()
   260  			return res
   261  		case *types.Map:
   262  			res += "map"
   263  			return res
   264  		case *types.Interface:
   265  			res += "interface"
   266  			return res
   267  		default:
   268  			panic(fmt.Errorf("unexpected type %T", it))
   269  		}
   270  	}
   271  }
   273  func Call(p *types.Func) string {
   274  	pkg := CurrentImports.Lookup(p.Pkg().Path())
   276  	if pkg != "" {
   277  		pkg += "."
   278  	}
   280  	if p.Type() != nil {
   281  		// make sure the returned type is listed in our imports.
   282  		ref(p.Type().(*types.Signature).Results().At(0).Type())
   283  	}
   285  	return pkg + p.Name()
   286  }
   288  func ToGo(name string) string {
   289  	if name == "_" {
   290  		return "_"
   291  	}
   292  	runes := make([]rune, 0, len(name))
   294  	wordWalker(name, func(info *wordInfo) {
   295  		word := info.Word
   296  		if info.MatchCommonInitial {
   297  			word = strings.ToUpper(word)
   298  		} else if !info.HasCommonInitial {
   299  			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
   300  				// FOO or foo → Foo
   301  				// FOo → FOo
   302  				word = UcFirst(strings.ToLower(word))
   303  			}
   304  		}
   305  		runes = append(runes, []rune(word)...)
   306  	})
   308  	return string(runes)
   309  }
   311  func ToGoPrivate(name string) string {
   312  	if name == "_" {
   313  		return "_"
   314  	}
   315  	runes := make([]rune, 0, len(name))
   317  	first := true
   318  	wordWalker(name, func(info *wordInfo) {
   319  		word := info.Word
   320  		switch {
   321  		case first:
   322  			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
   323  				// ID → id, CAMEL → camel
   324  				word = strings.ToLower(info.Word)
   325  			} else {
   326  				// ITicket → iTicket
   327  				word = LcFirst(info.Word)
   328  			}
   329  			first = false
   330  		case info.MatchCommonInitial:
   331  			word = strings.ToUpper(word)
   332  		case !info.HasCommonInitial:
   333  			word = UcFirst(strings.ToLower(word))
   334  		}
   335  		runes = append(runes, []rune(word)...)
   336  	})
   338  	return sanitizeKeywords(string(runes))
   339  }
// wordInfo describes one word emitted by wordWalker.
type wordInfo struct {
	// Word is the word's text, taken from the (delimiter-trimmed) input.
	Word               string
	// MatchCommonInitial is true when the upper-cased Word is itself a key of
	// commonInitialisms (e.g. "id", "URL").
	MatchCommonInitial bool
	// HasCommonInitial is true when MatchCommonInitial is true, or when a
	// shorter prefix of the word was an initialism while scanning (e.g. "URLs").
	HasCommonInitial   bool
}
// wordWalker splits str into words and calls f once per word, in order.
// NOTE(review): the original comment said this is "based on the following code",
// but the reference was lost. It looks like an adaptation of golint's lintName.
// TODO: restore the link.
//
// Leading and trailing delimiters (see isDelimiter) are trimmed before scanning.
// A word ends:
//   - at the end of the input;
//   - before a run of delimiters, which is removed. If the run sits between two
//     digits, one delimiter is kept and becomes the first rune of the next word;
//   - at a lower-case → non-lower-case transition (fooBar → foo, Bar);
//   - after a prefix that is a common initialism when the next rune is not
//     lower-case (IDFoo → ID, Foo; but URLs stays URLs).
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(strings.TrimFunc(str, isDelimiter))
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		switch {
		case i+1 == len(runes):
			eow = true
		case isDelimiter(runes[i+1]):
			// delimiter (underscore, hyphen or space): end the word here and
			// shift the remainder forward over the whole run of delimiters
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Leave at most one delimiter if the run is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
			// lower->non-lower
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// fall through and emit the initialism as its own word:
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// the word continues; remember that a prefix of it was an initialism
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		hasCommonInitial = false
		w = i
	}
}
   408  var keywords = []string{
   409  	"break",
   410  	"default",
   411  	"func",
   412  	"interface",
   413  	"select",
   414  	"case",
   415  	"defer",
   416  	"go",
   417  	"map",
   418  	"struct",
   419  	"chan",
   420  	"else",
   421  	"goto",
   422  	"package",
   423  	"switch",
   424  	"const",
   425  	"fallthrough",
   426  	"if",
   427  	"range",
   428  	"type",
   429  	"continue",
   430  	"for",
   431  	"import",
   432  	"return",
   433  	"var",
   434  	"_",
   435  }
   437  // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions
   438  func sanitizeKeywords(name string) string {
   439  	for _, k := range keywords {
   440  		if name == k {
   441  			return name + "Arg"
   442  		}
   443  	}
   444  	return name
   445  }
// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
// Keys must be upper-case: wordWalker looks up each finished word by its
// upper-cased form, and each in-progress prefix exactly as written.
var commonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"CSV":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ICMP":  true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"KVK":   true,
	"LHS":   true,
	"PDF":   true,
	"PGP":   true,
	"QPS":   true,
	"QR":    true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"SVG":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"UUID":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
}
   498  func rawQuote(s string) string {
   499  	return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`"
   500  }
   502  func notNil(field string, data interface{}) bool {
   503  	v := reflect.ValueOf(data)
   505  	if v.Kind() == reflect.Ptr {
   506  		v = v.Elem()
   507  	}
   508  	if v.Kind() != reflect.Struct {
   509  		return false
   510  	}
   511  	val := v.FieldByName(field)
   513  	return val.IsValid() && !val.IsNil()
   514  }
   516  func Dump(val interface{}) string {
   517  	switch val := val.(type) {
   518  	case int:
   519  		return strconv.Itoa(val)
   520  	case int64:
   521  		return fmt.Sprintf("%d", val)
   522  	case float64:
   523  		return fmt.Sprintf("%f", val)
   524  	case string:
   525  		return strconv.Quote(val)
   526  	case bool:
   527  		return strconv.FormatBool(val)
   528  	case nil:
   529  		return "nil"
   530  	case []interface{}:
   531  		var parts []string
   532  		for _, part := range val {
   533  			parts = append(parts, Dump(part))
   534  		}
   535  		return "[]interface{}{" + strings.Join(parts, ",") + "}"
   536  	case map[string]interface{}:
   537  		buf := bytes.Buffer{}
   538  		buf.WriteString("map[string]interface{}{")
   539  		var keys []string
   540  		for key := range val {
   541  			keys = append(keys, key)
   542  		}
   543  		sort.Strings(keys)
   545  		for _, key := range keys {
   546  			data := val[key]
   548  			buf.WriteString(strconv.Quote(key))
   549  			buf.WriteString(":")
   550  			buf.WriteString(Dump(data))
   551  			buf.WriteString(",")
   552  		}
   553  		buf.WriteString("}")
   554  		return buf.String()
   555  	default:
   556  		panic(fmt.Errorf("unsupported type %T", val))
   557  	}
   558  }
   560  func prefixLines(prefix, s string) string {
   561  	return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix)
   562  }
   564  func resolveName(name string, skip int) string {
   565  	if name[0] == '.' {
   566  		// load path relative to calling source file
   567  		_, callerFile, _, _ := runtime.Caller(skip + 1)
   568  		return filepath.Join(filepath.Dir(callerFile), name[1:])
   569  	}
   571  	// load path relative to this directory
   572  	_, callerFile, _, _ := runtime.Caller(0)
   573  	return filepath.Join(filepath.Dir(callerFile), name)
   574  }
   576  func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
   577  	t := template.New("").Funcs(Funcs())
   579  	b, err := ioutil.ReadFile(filename)
   580  	if err != nil {
   581  		return nil, err
   582  	}
   584  	t, err = t.New(filepath.Base(filename)).Parse(string(b))
   585  	if err != nil {
   586  		panic(err)
   587  	}
   589  	buf := &bytes.Buffer{}
   590  	return buf, t.Execute(buf, tpldata)
   591  }
   593  func write(filename string, b []byte, packages *code.Packages) error {
   594  	err := os.MkdirAll(filepath.Dir(filename), 0o755)
   595  	if err != nil {
   596  		return fmt.Errorf("failed to create directory: %w", err)
   597  	}
   599  	formatted, err := imports.Prune(filename, b, packages)
   600  	if err != nil {
   601  		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
   602  		formatted = b
   603  	}
   605  	err = ioutil.WriteFile(filename, formatted, 0o644)
   606  	if err != nil {
   607  		return fmt.Errorf("failed to write %s: %w", filename, err)
   608  	}
   610  	return nil
   611  }