github.com/fortexxx/gqlgen@v0.10.3-0.20191216030626-ca5ea8b21ead/codegen/templates/templates.go (about)

     1  package templates
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/types"
     7  	"io/ioutil"
     8  	"os"
     9  	"path/filepath"
    10  	"reflect"
    11  	"runtime"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"text/template"
    16  	"unicode"
    17  
    18  	"github.com/99designs/gqlgen/internal/imports"
    19  	"github.com/pkg/errors"
    20  )
    21  
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because subtemplates currently get called in functions. Let's aim to remove this eventually.
//
// Render assigns it at the start of a run and panics if it is already non-nil,
// so only one Render can be in progress at a time.
var CurrentImports *Imports
    25  
// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// the plugin processor will look for .gotpl files
	// in the same directory of where you wrote the plugin.
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	// Its directory is also used as the destination directory of the
	// Imports tracker that Render creates.
	Filename        string
	RegionTags      bool // RegionTags wraps each template's output in "// region" and "// endregion" marker comments.
	GeneratedHeader bool // GeneratedHeader puts a "Code generated ... DO NOT EDIT." comment at the top of the output.
	// Data will be passed to the template execution.
	Data  interface{}
	Funcs template.FuncMap // Funcs are merged over the defaults from Funcs(); an entry with the same name replaces the default.
}
    47  
    48  // Render renders a gql plugin template from the given Options. Render is an
    49  // abstraction of the text/template package that makes it easier to write gqlgen
    50  // plugins. If Options.Template is empty, the Render function will look for `.gotpl`
    51  // files inside the directory where you wrote the plugin.
    52  func Render(cfg Options) error {
    53  	if CurrentImports != nil {
    54  		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
    55  	}
    56  	CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
    57  
    58  	// load path relative to calling source file
    59  	_, callerFile, _, _ := runtime.Caller(1)
    60  	rootDir := filepath.Dir(callerFile)
    61  
    62  	funcs := Funcs()
    63  	for n, f := range cfg.Funcs {
    64  		funcs[n] = f
    65  	}
    66  	t := template.New("").Funcs(funcs)
    67  
    68  	var roots []string
    69  	if cfg.Template != "" {
    70  		var err error
    71  		t, err = t.New("template.gotpl").Parse(cfg.Template)
    72  		if err != nil {
    73  			return errors.Wrap(err, "error with provided template")
    74  		}
    75  		roots = append(roots, "template.gotpl")
    76  	} else {
    77  		// load all the templates in the directory
    78  		err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
    79  			if err != nil {
    80  				return err
    81  			}
    82  			name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
    83  			if !strings.HasSuffix(info.Name(), ".gotpl") {
    84  				return nil
    85  			}
    86  			b, err := ioutil.ReadFile(path)
    87  			if err != nil {
    88  				return err
    89  			}
    90  
    91  			t, err = t.New(name).Parse(string(b))
    92  			if err != nil {
    93  				return errors.Wrap(err, cfg.Filename)
    94  			}
    95  
    96  			roots = append(roots, name)
    97  
    98  			return nil
    99  		})
   100  		if err != nil {
   101  			return errors.Wrap(err, "locating templates")
   102  		}
   103  	}
   104  
   105  	// then execute all the important looking ones in order, adding them to the same file
   106  	sort.Slice(roots, func(i, j int) bool {
   107  		// important files go first
   108  		if strings.HasSuffix(roots[i], "!.gotpl") {
   109  			return true
   110  		}
   111  		if strings.HasSuffix(roots[j], "!.gotpl") {
   112  			return false
   113  		}
   114  		return roots[i] < roots[j]
   115  	})
   116  	var buf bytes.Buffer
   117  	for _, root := range roots {
   118  		if cfg.RegionTags {
   119  			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
   120  		}
   121  		err := t.Lookup(root).Execute(&buf, cfg.Data)
   122  		if err != nil {
   123  			return errors.Wrap(err, root)
   124  		}
   125  		if cfg.RegionTags {
   126  			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
   127  		}
   128  	}
   129  
   130  	var result bytes.Buffer
   131  	if cfg.GeneratedHeader {
   132  		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
   133  	}
   134  	result.WriteString("package ")
   135  	result.WriteString(cfg.PackageName)
   136  	result.WriteString("\n\n")
   137  	result.WriteString("import (\n")
   138  	result.WriteString(CurrentImports.String())
   139  	result.WriteString(")\n")
   140  	_, err := buf.WriteTo(&result)
   141  	if err != nil {
   142  		return err
   143  	}
   144  	CurrentImports = nil
   145  
   146  	return write(cfg.Filename, result.Bytes())
   147  }
   148  
   149  func center(width int, pad string, s string) string {
   150  	if len(s)+2 > width {
   151  		return s
   152  	}
   153  	lpad := (width - len(s)) / 2
   154  	rpad := width - (lpad + len(s))
   155  	return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad)
   156  }
   157  
   158  func Funcs() template.FuncMap {
   159  	return template.FuncMap{
   160  		"ucFirst":       ucFirst,
   161  		"lcFirst":       lcFirst,
   162  		"quote":         strconv.Quote,
   163  		"rawQuote":      rawQuote,
   164  		"dump":          Dump,
   165  		"ref":           ref,
   166  		"ts":            TypeIdentifier,
   167  		"call":          Call,
   168  		"prefixLines":   prefixLines,
   169  		"notNil":        notNil,
   170  		"reserveImport": CurrentImports.Reserve,
   171  		"lookupImport":  CurrentImports.Lookup,
   172  		"go":            ToGo,
   173  		"goPrivate":     ToGoPrivate,
   174  		"title":         strings.Title,
   175  		"add": func(a, b int) int {
   176  			return a + b
   177  		},
   178  		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
   179  			return render(resolveName(filename, 0), tpldata)
   180  		},
   181  	}
   182  }
   183  
   184  func ucFirst(s string) string {
   185  	if s == "" {
   186  		return ""
   187  	}
   188  	r := []rune(s)
   189  	r[0] = unicode.ToUpper(r[0])
   190  	return string(r)
   191  }
   192  
   193  func lcFirst(s string) string {
   194  	if s == "" {
   195  		return ""
   196  	}
   197  
   198  	r := []rune(s)
   199  	r[0] = unicode.ToLower(r[0])
   200  	return string(r)
   201  }
   202  
   203  func isDelimiter(c rune) bool {
   204  	return c == '-' || c == '_' || unicode.IsSpace(c)
   205  }
   206  
   207  func ref(p types.Type) string {
   208  	return CurrentImports.LookupType(p)
   209  }
   210  
   211  var pkgReplacer = strings.NewReplacer(
   212  	"/", "ᚋ",
   213  	".", "ᚗ",
   214  	"-", "ᚑ",
   215  )
   216  
   217  func TypeIdentifier(t types.Type) string {
   218  	res := ""
   219  	for {
   220  		switch it := t.(type) {
   221  		case *types.Pointer:
   222  			t.Underlying()
   223  			res += "ᚖ"
   224  			t = it.Elem()
   225  		case *types.Slice:
   226  			res += "ᚕ"
   227  			t = it.Elem()
   228  		case *types.Named:
   229  			res += pkgReplacer.Replace(it.Obj().Pkg().Path())
   230  			res += "ᚐ"
   231  			res += it.Obj().Name()
   232  			return res
   233  		case *types.Basic:
   234  			res += it.Name()
   235  			return res
   236  		case *types.Map:
   237  			res += "map"
   238  			return res
   239  		case *types.Interface:
   240  			res += "interface"
   241  			return res
   242  		default:
   243  			panic(fmt.Errorf("unexpected type %T", it))
   244  		}
   245  	}
   246  }
   247  
   248  func Call(p *types.Func) string {
   249  	pkg := CurrentImports.Lookup(p.Pkg().Path())
   250  
   251  	if pkg != "" {
   252  		pkg += "."
   253  	}
   254  
   255  	if p.Type() != nil {
   256  		// make sure the returned type is listed in our imports.
   257  		ref(p.Type().(*types.Signature).Results().At(0).Type())
   258  	}
   259  
   260  	return pkg + p.Name()
   261  }
   262  
   263  func ToGo(name string) string {
   264  	runes := make([]rune, 0, len(name))
   265  
   266  	wordWalker(name, func(info *wordInfo) {
   267  		word := info.Word
   268  		if info.MatchCommonInitial {
   269  			word = strings.ToUpper(word)
   270  		} else if !info.HasCommonInitial {
   271  			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
   272  				// FOO or foo → Foo
   273  				// FOo → FOo
   274  				word = ucFirst(strings.ToLower(word))
   275  			}
   276  		}
   277  		runes = append(runes, []rune(word)...)
   278  	})
   279  
   280  	return string(runes)
   281  }
   282  
   283  func ToGoPrivate(name string) string {
   284  	runes := make([]rune, 0, len(name))
   285  
   286  	first := true
   287  	wordWalker(name, func(info *wordInfo) {
   288  		word := info.Word
   289  		switch {
   290  		case first:
   291  			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
   292  				// ID → id, CAMEL → camel
   293  				word = strings.ToLower(info.Word)
   294  			} else {
   295  				// ITicket → iTicket
   296  				word = lcFirst(info.Word)
   297  			}
   298  			first = false
   299  		case info.MatchCommonInitial:
   300  			word = strings.ToUpper(word)
   301  		case !info.HasCommonInitial:
   302  			word = ucFirst(strings.ToLower(word))
   303  		}
   304  		runes = append(runes, []rune(word)...)
   305  	})
   306  
   307  	return sanitizeKeywords(string(runes))
   308  }
   309  
// wordInfo describes one word handed to the callback of wordWalker.
//
// Word is the text of the word. MatchCommonInitial is true when the
// upper-cased word is a key of commonInitialisms. HasCommonInitial is true
// when MatchCommonInitial is, or when a leading part of the word matched
// commonInitialisms while the word was being scanned.
type wordInfo struct {
	Word               string
	MatchCommonInitial bool
	HasCommonInitial   bool
}
   315  
// wordWalker splits str into words and calls f once for every word, in order.
//
// A word ends at the end of the input, before a run of delimiters (see
// isDelimiter; the run itself is dropped), where a lower-case rune is followed
// by a non-lower-case one, or right after a prefix that is a common initialism
// and is followed by a non-lower-case rune. Only the rune after the current
// position is tested for being a delimiter, so the very first rune of str is
// never treated as one.
//
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(str)
	w, i := 0, 0 // index of start of word, scan
	// hasCommonInitial records that some prefix of the word being scanned
	// was a common initialism; it is reset after each word is reported.
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		switch {
		case i+1 == len(runes):
			// last rune of the input
			eow = true
		case isDelimiter(runes[i+1]):
			// delimiter; shift the remainder forward over any run of delimiters
			eow = true
			// n counts the delimiters that follow runes[i]
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Leave at most one delimiter if the run is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			// drop n delimiters in place by moving the tail left
			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
			// lower->non-lower
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// fall through and report the initialism as its own word:
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// still inside a word; remember if what we have so far is an initialism
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		// the whole word, compared case-insensitively, is an initialism
		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		hasCommonInitial = false
		w = i
	}
}
   376  
   377  var keywords = []string{
   378  	"break",
   379  	"default",
   380  	"func",
   381  	"interface",
   382  	"select",
   383  	"case",
   384  	"defer",
   385  	"go",
   386  	"map",
   387  	"struct",
   388  	"chan",
   389  	"else",
   390  	"goto",
   391  	"package",
   392  	"switch",
   393  	"const",
   394  	"fallthrough",
   395  	"if",
   396  	"range",
   397  	"type",
   398  	"continue",
   399  	"for",
   400  	"import",
   401  	"return",
   402  	"var",
   403  	"_",
   404  }
   405  
   406  // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions
   407  func sanitizeKeywords(name string) string {
   408  	for _, k := range keywords {
   409  		if name == k {
   410  			return name + "Arg"
   411  		}
   412  	}
   413  	return name
   414  }
   415  
// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
//
// Keys are upper case; wordWalker looks words up both verbatim and after
// upper-casing them.
var commonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"LHS":   true,
	"QPS":   true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"UUID":  true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
}
   459  
   460  func rawQuote(s string) string {
   461  	return "`" + strings.Replace(s, "`", "`+\"`\"+`", -1) + "`"
   462  }
   463  
   464  func notNil(field string, data interface{}) bool {
   465  	v := reflect.ValueOf(data)
   466  
   467  	if v.Kind() == reflect.Ptr {
   468  		v = v.Elem()
   469  	}
   470  	if v.Kind() != reflect.Struct {
   471  		return false
   472  	}
   473  	val := v.FieldByName(field)
   474  
   475  	return val.IsValid() && !val.IsNil()
   476  }
   477  
   478  func Dump(val interface{}) string {
   479  	switch val := val.(type) {
   480  	case int:
   481  		return strconv.Itoa(val)
   482  	case int64:
   483  		return fmt.Sprintf("%d", val)
   484  	case float64:
   485  		return fmt.Sprintf("%f", val)
   486  	case string:
   487  		return strconv.Quote(val)
   488  	case bool:
   489  		return strconv.FormatBool(val)
   490  	case nil:
   491  		return "nil"
   492  	case []interface{}:
   493  		var parts []string
   494  		for _, part := range val {
   495  			parts = append(parts, Dump(part))
   496  		}
   497  		return "[]interface{}{" + strings.Join(parts, ",") + "}"
   498  	case map[string]interface{}:
   499  		buf := bytes.Buffer{}
   500  		buf.WriteString("map[string]interface{}{")
   501  		var keys []string
   502  		for key := range val {
   503  			keys = append(keys, key)
   504  		}
   505  		sort.Strings(keys)
   506  
   507  		for _, key := range keys {
   508  			data := val[key]
   509  
   510  			buf.WriteString(strconv.Quote(key))
   511  			buf.WriteString(":")
   512  			buf.WriteString(Dump(data))
   513  			buf.WriteString(",")
   514  		}
   515  		buf.WriteString("}")
   516  		return buf.String()
   517  	default:
   518  		panic(fmt.Errorf("unsupported type %T", val))
   519  	}
   520  }
   521  
   522  func prefixLines(prefix, s string) string {
   523  	return prefix + strings.Replace(s, "\n", "\n"+prefix, -1)
   524  }
   525  
   526  func resolveName(name string, skip int) string {
   527  	if name[0] == '.' {
   528  		// load path relative to calling source file
   529  		_, callerFile, _, _ := runtime.Caller(skip + 1)
   530  		return filepath.Join(filepath.Dir(callerFile), name[1:])
   531  	}
   532  
   533  	// load path relative to this directory
   534  	_, callerFile, _, _ := runtime.Caller(0)
   535  	return filepath.Join(filepath.Dir(callerFile), name)
   536  }
   537  
   538  func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
   539  	t := template.New("").Funcs(Funcs())
   540  
   541  	b, err := ioutil.ReadFile(filename)
   542  	if err != nil {
   543  		return nil, err
   544  	}
   545  
   546  	t, err = t.New(filepath.Base(filename)).Parse(string(b))
   547  	if err != nil {
   548  		panic(err)
   549  	}
   550  
   551  	buf := &bytes.Buffer{}
   552  	return buf, t.Execute(buf, tpldata)
   553  }
   554  
   555  func write(filename string, b []byte) error {
   556  	err := os.MkdirAll(filepath.Dir(filename), 0755)
   557  	if err != nil {
   558  		return errors.Wrap(err, "failed to create directory")
   559  	}
   560  
   561  	formatted, err := imports.Prune(filename, b)
   562  	if err != nil {
   563  		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
   564  		formatted = b
   565  	}
   566  
   567  	err = ioutil.WriteFile(filename, formatted, 0644)
   568  	if err != nil {
   569  		return errors.Wrapf(err, "failed to write %s", filename)
   570  	}
   571  
   572  	return nil
   573  }